polars_io/cloud/polars_object_store.rs

use std::ops::Range;

use bytes::Bytes;
use futures::{StreamExt, TryStreamExt};
use hashbrown::hash_map::RawEntryMut;
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};
use polars_core::prelude::{InitHashMaps, PlHashMap};
use polars_error::{PolarsError, PolarsResult};
use polars_utils::mmap::MemSlice;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};

use crate::pl_async::{
    self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
    tune_with_concurrency_budget, with_concurrency_budget,
};

mod inner {
    use std::future::Future;
    use std::sync::Arc;

    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::PolarsResult;
    use polars_utils::relaxed_cell::RelaxedCell;

    use crate::cloud::PolarsObjectStoreBuilder;

    #[derive(Debug)]
    struct Inner {
        store: tokio::sync::Mutex<Arc<dyn ObjectStore>>,
        builder: PolarsObjectStoreBuilder,
    }

    /// Polars wrapper around [`ObjectStore`] functionality. This struct is cheaply cloneable.
    #[derive(Clone, Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
        /// Avoid contending the Mutex `lock()` until the first re-build.
        initial_store: std::sync::Arc<dyn ObjectStore>,
        /// Used for interior mutability. Doesn't need to be shared with other threads so it's not
        /// inside `Arc<>`.
        rebuilt: RelaxedCell<bool>,
    }

    impl PolarsObjectStore {
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::Mutex::new(store),
                    builder,
                }),
                initial_store,
                rebuilt: RelaxedCell::from(false),
            }
        }

        /// Gets the underlying [`ObjectStore`] implementation.
        pub async fn to_dyn_object_store(&self) -> Arc<dyn ObjectStore> {
            if !self.rebuilt.load() {
                self.initial_store.clone()
            } else {
                self.inner.store.lock().await.clone()
            }
        }

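        /// Re-creates the inner [`ObjectStore`] from the stored builder, unless another thread
        /// has already re-built it. Called after a request error, e.g. when cached credentials
        /// may have expired.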
        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            let mut current_store = self.inner.store.lock().await;

            // If this does not eq, then `inner` was already re-built by another thread.
            if Arc::ptr_eq(&*current_store, from_version) {
                *current_store =
                    self.inner
                        .builder
                        .clone()
                        .build_impl(true)
                        .await
                        .map_err(|e| {
                            e.wrap_msg(|e| format!("attempt to rebuild object store failed: {e}"))
                        })?;
            }

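            // From now on, `to_dyn_object_store` on this handle reads the store from behind the
            // mutex rather than returning the cached `initial_store`.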
            self.rebuilt.store(true);

            Ok((*current_store).clone())
        }

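        /// Runs `func` with the current object store. If it returns an error, the inner object
        /// store is re-built and `func` is retried once with the new store.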
        pub async fn try_exec_rebuild_on_err<Fn, Fut, O>(&self, mut func: Fn) -> PolarsResult<O>
        where
            Fn: FnMut(&Arc<dyn ObjectStore>) -> Fut,
            Fut: Future<Output = PolarsResult<O>>,
        {
            let store = self.to_dyn_object_store().await;

            let out = func(&store).await;

            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };

            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will attempt re-build",
                    &orig_err
                );
            }

            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{e}; original error: {orig_err}")))?;

            func(&store).await.map_err(|e| {
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    // Note: This error is intended for Python audiences. The logic for retrieving
                    // these keys exists only on the Python side.
                    e.wrap_msg(|e| {
                        format!(
                            "{e}; note: if you are using Python, consider setting \
POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
and use the storage account keys from Azure CLI to authenticate"
                        )
                    })
                } else {
                    e
                }
            })
        }
    }
}

pub use inner::PolarsObjectStore;

pub type ObjectStorePath = object_store::path::Path;

impl PolarsObjectStore {
    /// Returns a buffered stream that downloads concurrently up to the concurrency limit.
    fn get_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
        store: &'a dyn ObjectStore,
        path: &'a Path,
        ranges: T,
    ) -> impl StreamExt<Item = PolarsResult<Bytes>>
    + TryStreamExt<Ok = Bytes, Error = PolarsError, Item = PolarsResult<Bytes>>
    + use<'a, T> {
        futures::stream::iter(ranges.map(move |range| async move {
            if range.is_empty() {
                return Ok(Bytes::new());
            }

            let out = store
                .get_range(path, range.start as u64..range.end as u64)
                .await?;
            Ok(out)
        }))
        // Add a limit locally as this gets run inside a single `tune_with_concurrency_budget`.
        .buffered(get_concurrency_limit() as usize)
    }

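    /// Fetches a single byte range. Large ranges may be split into multiple concurrent requests
    /// and the parts reassembled into a single buffer.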
    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Bytes> {
        if range.is_empty() {
            return Ok(Bytes::new());
        }

        self.try_exec_rebuild_on_err(move |store| {
            let range = range.clone();
            let st = store.clone();

            async move {
                let store = st;
                let parts = split_range(range.clone());

                if parts.len() == 1 {
                    let out = tune_with_concurrency_budget(1, move || async move {
                        store
                            .get_range(path, range.start as u64..range.end as u64)
                            .await
                    })
                    .await?;

                    Ok(out)
                } else {
                    let parts = tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || {
                            Self::get_buffered_ranges_stream(&store, path, parts)
                                .try_collect::<Vec<Bytes>>()
                        },
                    )
                    .await?;

                    let mut combined = Vec::with_capacity(range.len());

                    for part in parts {
                        combined.extend_from_slice(&part)
                    }

                    assert_eq!(combined.len(), range.len());

                    PolarsResult::Ok(Bytes::from(combined))
                }
            }
        })
        .await
    }

    /// Fetch byte ranges into a HashMap keyed by the range start. This will mutably sort the
    /// `ranges` slice for coalescing.
    ///
    /// # Panics
    /// Panics if the same range start is used by more than 1 range.
    pub async fn get_ranges_sort(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<usize, MemSlice>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                let store = st;
                let mut out = PlHashMap::with_capacity(ranges_len);

                let mut stream =
                    Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned());

                tune_with_concurrency_budget(
                    merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                    || async {
                        let mut len = 0;
                        let mut current_offset = 0;
                        let mut ends_iter = merged_ends.iter();

                        let mut splitted_parts = vec![];

                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            let end = *ends_iter.next().unwrap();

                            if end == 0 {
                                splitted_parts.push(bytes);
                                continue;
                            }

                            let full_range = ranges[current_offset..end]
                                .iter()
                                .cloned()
                                .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                                .unwrap();

                            let bytes = if splitted_parts.is_empty() {
                                bytes
                            } else {
                                let mut out = Vec::with_capacity(full_range.len());

                                for x in splitted_parts.drain(..) {
                                    out.extend_from_slice(&x);
                                }

                                out.extend_from_slice(&bytes);
                                Bytes::from(out)
                            };

                            assert_eq!(bytes.len(), full_range.len());

                            let bytes = MemSlice::from_bytes(bytes);

                            for range in &ranges[current_offset..end] {
                                let mem_slice = bytes.slice(
                                    range.start - full_range.start..range.end - full_range.start,
                                );

                                match out.raw_entry_mut().from_key(&range.start) {
                                    RawEntryMut::Vacant(slot) => {
                                        slot.insert(range.start, mem_slice);
                                    },
                                    RawEntryMut::Occupied(mut slot) => {
                                        if slot.get_mut().len() < mem_slice.len() {
                                            *slot.get_mut() = mem_slice;
                                        }
                                    },
                                }
                            }

                            current_offset = end;
                        }

                        assert!(splitted_parts.is_empty());

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    },
                )
                .await?;

                Ok(out)
            }
        })
        .await
    }

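    /// Downloads the object at `path` into `file`. If the object size is known and large enough
    /// to split, it is downloaded as multiple concurrent range requests; otherwise the whole
    /// object is streamed.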
    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        let opt_size = self.head(path).await.ok().map(|x| x.size);

        let initial_pos = file.stream_position().await?;

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            // Workaround for "can't move captured variable".
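            // Note: this duplicates the `&mut` reference so it can be moved into the async block.
            // The closure may run a second time after a store re-build, but the two futures never
            // run concurrently, so the copies are never used at the same time.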
            let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) };

            async {
                file.set_len(initial_pos).await?; // Reset if this function was called again.

                let store = st;
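                // Only use concurrent range requests when the size is known and splits into more
                // than one part; otherwise stream the whole object (the `else` branch below).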
                let parts = opt_size
                    .map(|x| split_range(0..x as usize))
                    .filter(|x| x.len() > 1);

                if let Some(parts) = parts {
                    tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || async {
                            let mut stream = Self::get_buffered_ranges_stream(&store, path, parts);
                            let mut len = 0;
                            while let Some(bytes) = stream.try_next().await? {
                                len += bytes.len();
                                file.write_all(&bytes).await?;
                            }

                            assert_eq!(len, opt_size.unwrap() as usize);

                            PolarsResult::Ok(pl_async::Size::from(len as u64))
                        },
                    )
                    .await?
                } else {
                    tune_with_concurrency_budget(1, || async {
                        let mut stream = store.get(path).await?.into_stream();

                        let mut len = 0;
                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            file.write_all(&bytes).await?;
                        }

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    })
                    .await?
                };

                // Dropping is delayed for tokio async files so we need to explicitly
                // flush here (https://github.com/tokio-rs/tokio/issues/2307#issuecomment-596336451).
                file.sync_all().await.map_err(PolarsError::from)?;

                Ok(())
            }
        })
        .await
    }

    /// Fetch the metadata of the object, do not memoize it.
    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                with_concurrency_budget(1, || async {
                    let store = st;
                    let head_result = store.head(path).await;

                    if head_result.is_err() {
                        // Pre-signed URLs forbid the HEAD method, but we can still retrieve the header
                        // information with a range 0-0 request.
                        let get_range_0_0_result = store
                            .get_opts(
                                path,
                                object_store::GetOptions {
                                    range: Some((0..1).into()),
                                    ..Default::default()
                                },
                            )
                            .await;

                        if let Ok(v) = get_range_0_0_result {
                            return Ok(v.meta);
                        }
                    }

                    let out = head_result?;

                    Ok(out)
                })
                .await
            }
        })
        .await
    }
}

/// Splits a single range into multiple smaller ranges, which can be downloaded concurrently for
/// much higher throughput.
fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
    let chunk_size = get_download_chunk_size();

    // Calculate n_parts such that we are as close as possible to the `chunk_size`.
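    // E.g. with the default 64 MiB chunk size and a 100 MiB range: `div_ceil` gives 2 parts of
    // 50 MiB, plain division gives 1 part of 100 MiB; 50 MiB is closer to 64 MiB, so 2 parts win.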
    let n_parts = [
        (range.len().div_ceil(chunk_size)).max(1),
        (range.len() / chunk_size).max(1),
    ]
    .into_iter()
    .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
    .unwrap();

    let chunk_size = (range.len() / n_parts).max(1);

    assert_eq!(n_parts, (range.len() / chunk_size).max(1));
    let bytes_rem = range.len() % chunk_size;

    (0..n_parts).map(move |part_no| {
        let (start, end) = if part_no == 0 {
            // Download remainder length in the first chunk since it starts downloading first.
            let end = range.start + chunk_size + bytes_rem;
            let end = if end > range.end { range.end } else { end };
            (range.start, end)
        } else {
            let start = bytes_rem + range.start + part_no * chunk_size;
            (start, start + chunk_size)
        };

        start..end
    })
}

/// Note: For optimal performance, `ranges` should be sorted. More generally, ranges that are
/// adjacent in the slice should also be close to each other in value.
///
/// # Returns
/// `[(range1, end1), (range2, end2)]`, where:
/// * `range1` contains bytes for the ranges from `ranges[0..end1]`
/// * `range2` contains bytes for the ranges from `ranges[end1..end2]`
/// * etc.
///
/// Note that if an end value is 0, it means the range is a split part and should be combined.
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    // Number of fetched bytes excluding excess.
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
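            // Look ahead one range: `current_idx` points at the next range, and we decide whether
            // it should be merged into the running `current_merged_range`.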
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
                // No more items - flush current state.
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                // E.g.:
                // |--------|
                //  oo        // range1
                //       oo   // range2
                //    ^^^     // distance = 3, is_overlapping = false
                // E.g.:
                // |--------|
                //  ooooo     // range1
                //     ooooo  // range2
                //     ^^     // distance = 2, is_overlapping = true
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

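                // Merge if the ranges overlap, or if merging keeps the merged length at least as
                // close to the target chunk size and the gap is within the tolerance (12.5% of
                // the larger side, clamped to between 1 MiB and 8 MiB).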
                let should_merge = is_overlapping || {
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    // Merge to existing range
                    current_merged_range = new_merged;
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            // Split large individual ranges within the list of ranges.
            let (range, end) = x;
            let split = split_range(range);
            let len = split.len();

            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}

#[cfg(test)]
mod tests {

    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        #[allow(clippy::single_range_in_vec_init)]
        {
            // Round-trip empty ranges.
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

        // Threshold to start splitting to 2 ranges
        //
        // n - chunk_size == chunk_size - n / 2
        // n + n / 2 == 2 * chunk_size
        // 3 * n == 4 * chunk_size
        // n = 4 * chunk_size / 3
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

        // Threshold to start splitting to 3 ranges
        //
        // n / 2 - chunk_size == chunk_size - n / 3
        // n / 2 + n / 3 == 2 * chunk_size
        // 5 * n == 12 * chunk_size
        // n == 12 * chunk_size / 5
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        // Round-trip empty slice
        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

        // We have 1 tiny request followed by 1 huge request. They are combined because merging
        // reduces the `abs_diff()` from the `chunk_size`, and the merged range is afterwards
        // split into 2 evenly sized requests.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

        // <= 1MiB gap, merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

        // > 1MiB gap, do not merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

        // <= 12.5% gap, merge
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // <= 12.5% gap relative to RHS, merge
        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // Overlapping range, merge
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}