polars_io/cloud/
polars_object_store.rs

1use std::ops::Range;
2
3use bytes::Bytes;
4use futures::{StreamExt, TryStreamExt};
5use object_store::path::Path;
6use object_store::{ObjectMeta, ObjectStore};
7use polars_core::prelude::{InitHashMaps, PlHashMap};
8use polars_error::{PolarsError, PolarsResult, to_compute_err};
9use tokio::io::{AsyncSeekExt, AsyncWriteExt};
10
11use crate::pl_async::{
12    self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
13    tune_with_concurrency_budget, with_concurrency_budget,
14};
15
mod inner {
    use std::future::Future;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;

    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::PolarsResult;

    use crate::cloud::PolarsObjectStoreBuilder;

    /// Shared state: the current store behind an async mutex (so it can be swapped out on
    /// re-build), plus the builder used to construct replacement stores.
    #[derive(Debug)]
    struct Inner {
        store: tokio::sync::Mutex<Arc<dyn ObjectStore>>,
        builder: PolarsObjectStoreBuilder,
    }

    /// Polars wrapper around [`ObjectStore`] functionality. This struct is cheaply cloneable.
    #[derive(Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
        /// Avoid contending the Mutex `lock()` until the first re-build.
        initial_store: std::sync::Arc<dyn ObjectStore>,
        /// Used for interior mutability. Doesn't need to be shared with other threads so it's not
        /// inside `Arc<>`.
        rebuilt: AtomicBool,
    }

    impl Clone for PolarsObjectStore {
        // Manual impl because `AtomicBool` is not `Clone`. The clone snapshots the current
        // `rebuilt` flag; afterwards each clone tracks its own flag independently (the shared
        // `inner` is what actually holds the current store).
        fn clone(&self) -> Self {
            Self {
                inner: self.inner.clone(),
                initial_store: self.initial_store.clone(),
                rebuilt: AtomicBool::new(self.rebuilt.load(std::sync::atomic::Ordering::Relaxed)),
            }
        }
    }

    impl PolarsObjectStore {
        /// Wraps an already-built store together with the builder that produced it, so the
        /// store can be re-built later if a request fails (e.g. expired credentials).
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::Mutex::new(store),
                    builder,
                }),
                initial_store,
                rebuilt: AtomicBool::new(false),
            }
        }

        /// Gets the underlying [`ObjectStore`] implementation.
        pub async fn to_dyn_object_store(&self) -> Arc<dyn ObjectStore> {
            // Fast path: until the first re-build, the cached `initial_store` is still
            // current, so the mutex never needs to be locked.
            if !self.rebuilt.load(std::sync::atomic::Ordering::Relaxed) {
                self.initial_store.clone()
            } else {
                self.inner.store.lock().await.clone()
            }
        }

        /// Re-builds the inner store via the stored builder, returning the (possibly
        /// already re-built) current store.
        ///
        /// `from_version` is the store the caller observed failing; comparing pointers
        /// against it avoids a redundant re-build when another task already replaced it.
        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            let mut current_store = self.inner.store.lock().await;

            // From this point on `to_dyn_object_store()` must read through the mutex.
            self.rebuilt
                .store(true, std::sync::atomic::Ordering::Relaxed);

            // If this does not eq, then `inner` was already re-built by another thread.
            if Arc::ptr_eq(&*current_store, from_version) {
                *current_store = self.inner.builder.clone().build_impl().await.map_err(|e| {
                    e.wrap_msg(|e| format!("attempt to rebuild object store failed: {}", e))
                })?;
            }

            Ok((*current_store).clone())
        }

        /// Runs `func` against the current store; if it errors, re-builds the store once
        /// and retries `func` against the fresh store. The second failure is returned to
        /// the caller (with an Azure-specific hint appended where applicable).
        pub async fn try_exec_rebuild_on_err<Fn, Fut, O>(&self, mut func: Fn) -> PolarsResult<O>
        where
            Fn: FnMut(&Arc<dyn ObjectStore>) -> Fut,
            Fut: Future<Output = PolarsResult<O>>,
        {
            let store = self.to_dyn_object_store().await;

            let out = func(&store).await;

            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };

            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will attempt re-build",
                    &orig_err
                );
            }

            // If even the re-build fails, surface both the re-build and original errors.
            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{}; original error: {}", e, orig_err)))?;

            func(&store).await.map_err(|e| {
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    // Note: This error is intended for Python audiences. The logic for retrieving
                    // these keys exist only on the Python side.
                    e.wrap_msg(|e| {
                        format!(
                            "{}; note: if you are using Python, consider setting \
POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
and use the storage account keys from Azure CLI to authenticate",
                            e
                        )
                    })
                } else {
                    e
                }
            })
        }
    }
}
146
pub use inner::PolarsObjectStore;

/// Convenience alias for the `object_store` path type used by this module's APIs.
pub type ObjectStorePath = object_store::path::Path;
150
impl PolarsObjectStore {
    /// Returns a buffered stream that downloads concurrently up to the concurrency limit.
    ///
    /// Items are yielded in the same order as `ranges` (buffered streams preserve order).
    fn get_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
        store: &'a dyn ObjectStore,
        path: &'a Path,
        ranges: T,
    ) -> impl StreamExt<Item = PolarsResult<Bytes>>
    + TryStreamExt<Ok = Bytes, Error = PolarsError, Item = PolarsResult<Bytes>>
    + use<'a, T> {
        futures::stream::iter(
            ranges
                .map(|range| async { store.get_range(path, range).await.map_err(to_compute_err) }),
        )
        // Add a limit locally as this gets run inside a single `tune_with_concurrency_budget`.
        .buffered(get_concurrency_limit() as usize)
    }

    /// Fetches a single byte range. Large ranges are transparently split (see
    /// `split_range`), downloaded concurrently, and re-assembled in order.
    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Bytes> {
        self.try_exec_rebuild_on_err(move |store| {
            // Clone captured state into the future: the closure may be invoked twice
            // (initial attempt + retry after a store re-build).
            let range = range.clone();
            let st = store.clone();

            async {
                let store = st;
                let parts = split_range(range.clone());

                if parts.len() == 1 {
                    // Small enough for a single request.
                    tune_with_concurrency_budget(1, || async { store.get_range(path, range).await })
                        .await
                        .map_err(to_compute_err)
                } else {
                    let parts = tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || {
                            Self::get_buffered_ranges_stream(&store, path, parts)
                                .try_collect::<Vec<Bytes>>()
                        },
                    )
                    .await?;

                    // Re-assemble; `try_collect` preserved the part order.
                    let mut combined = Vec::with_capacity(range.len());

                    for part in parts {
                        combined.extend_from_slice(&part)
                    }

                    assert_eq!(combined.len(), range.len());

                    PolarsResult::Ok(Bytes::from(combined))
                }
            }
        })
        .await
    }

    /// Fetch byte ranges into a HashMap keyed by the range start. This will mutably sort the
    /// `ranges` slice for coalescing.
    ///
    /// Adjacent/overlapping ranges are merged into larger requests (see `merge_ranges`);
    /// each requested range is then sliced back out of the merged download.
    ///
    /// # Panics
    /// Panics if the same range start is used by more than 1 range.
    pub async fn get_ranges_sort<
        K: TryFrom<usize, Error = impl std::fmt::Debug> + std::hash::Hash + Eq,
        T: From<Bytes>,
    >(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<K, T>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                let store = st;
                let mut out = PlHashMap::with_capacity(ranges_len);

                let mut stream =
                    Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned());

                tune_with_concurrency_budget(
                    merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                    || async {
                        let mut len = 0;
                        // Index into `ranges` of the first range not yet emitted.
                        let mut current_offset = 0;
                        let mut ends_iter = merged_ends.iter();

                        // Parts of an oversized merged range that `merge_ranges` re-split
                        // (marked by `end == 0`); buffered until the final part arrives.
                        let mut splitted_parts = vec![];

                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            let end = *ends_iter.next().unwrap();

                            if end == 0 {
                                splitted_parts.push(bytes);
                                continue;
                            }

                            // Full extent covered by `ranges[current_offset..end]`.
                            let full_range = ranges[current_offset..end]
                                .iter()
                                .cloned()
                                .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                                .unwrap();

                            // Re-join any buffered split parts with this final part.
                            let bytes = if splitted_parts.is_empty() {
                                bytes
                            } else {
                                let mut out = Vec::with_capacity(full_range.len());

                                for x in splitted_parts.drain(..) {
                                    out.extend_from_slice(&x);
                                }

                                out.extend_from_slice(&bytes);
                                Bytes::from(out)
                            };

                            assert_eq!(bytes.len(), full_range.len());

                            // Slice each original range back out of the merged download.
                            // `Bytes::slice` is zero-copy.
                            for range in &ranges[current_offset..end] {
                                let v = out.insert(
                                    K::try_from(range.start).unwrap(),
                                    T::from(bytes.slice(
                                        range.start - full_range.start
                                            ..range.end - full_range.start,
                                    )),
                                );

                                assert!(v.is_none()); // duplicate range start
                            }

                            current_offset = end;
                        }

                        assert!(splitted_parts.is_empty());

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    },
                )
                .await?;

                Ok(out)
            }
        })
        .await
    }

    /// Downloads the object at `path` into `file`, starting at the file's current
    /// stream position. When the object size is known, it is downloaded as multiple
    /// concurrent ranged requests; otherwise as a single streamed GET.
    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        // Best-effort size probe; `None` falls back to a single streaming GET below.
        let opt_size = self.head(path).await.ok().map(|x| x.size);

        let initial_pos = file.stream_position().await?;

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            // Workaround for "can't move captured variable".
            // SAFETY(review): `transmute_copy` duplicates the `&mut File` borrow. This
            // appears to rely on `try_exec_rebuild_on_err` awaiting each attempt to
            // completion before invoking the closure again, so the two copies are never
            // used concurrently — confirm; this is UB territory if that ever changes.
            let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) };

            async {
                file.set_len(initial_pos).await?; // Reset if this function was called again.

                let store = st;
                // Only use ranged parallel download when the size is known and splitting
                // actually produces more than one part.
                let parts = opt_size.map(|x| split_range(0..x)).filter(|x| x.len() > 1);

                if let Some(parts) = parts {
                    tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || async {
                            let mut stream = Self::get_buffered_ranges_stream(&store, path, parts);
                            let mut len = 0;
                            while let Some(bytes) = stream.try_next().await? {
                                len += bytes.len();
                                file.write_all(&bytes).await.map_err(to_compute_err)?;
                            }

                            assert_eq!(len, opt_size.unwrap());

                            PolarsResult::Ok(pl_async::Size::from(len as u64))
                        },
                    )
                    .await?
                } else {
                    tune_with_concurrency_budget(1, || async {
                        let mut stream =
                            store.get(path).await.map_err(to_compute_err)?.into_stream();

                        let mut len = 0;
                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            file.write_all(&bytes).await.map_err(to_compute_err)?;
                        }

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    })
                    .await?
                };

                // Dropping is delayed for tokio async files so we need to explicitly
                // flush here (https://github.com/tokio-rs/tokio/issues/2307#issuecomment-596336451).
                file.sync_all().await.map_err(PolarsError::from)?;

                Ok(())
            }
        })
        .await
    }

    /// Fetch the metadata of the parquet file, do not memoize it.
    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                with_concurrency_budget(1, || async {
                    let store = st;
                    let head_result = store.head(path).await;

                    if head_result.is_err() {
                        // Pre-signed URLs forbid the HEAD method, but we can still retrieve the
                        // header information with a 1-byte ranged GET (bytes 0..1).
                        let get_range_0_0_result = store
                            .get_opts(
                                path,
                                object_store::GetOptions {
                                    range: Some((0..1).into()),
                                    ..Default::default()
                                },
                            )
                            .await;

                        if let Ok(v) = get_range_0_0_result {
                            return Ok(v.meta);
                        }
                    }

                    // Fall through to the original HEAD result (and its error) if the
                    // ranged-GET fallback also failed.
                    head_result
                })
                .await
                .map_err(to_compute_err)
            }
        })
        .await
    }
}
402
403/// Splits a single range into multiple smaller ranges, which can be downloaded concurrently for
404/// much higher throughput.
405fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
406    let chunk_size = get_download_chunk_size();
407
408    // Calculate n_parts such that we are as close as possible to the `chunk_size`.
409    let n_parts = [
410        (range.len().div_ceil(chunk_size)).max(1),
411        (range.len() / chunk_size).max(1),
412    ]
413    .into_iter()
414    .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
415    .unwrap();
416
417    let chunk_size = (range.len() / n_parts).max(1);
418
419    assert_eq!(n_parts, (range.len() / chunk_size).max(1));
420    let bytes_rem = range.len() % chunk_size;
421
422    (0..n_parts).map(move |part_no| {
423        let (start, end) = if part_no == 0 {
424            // Download remainder length in the first chunk since it starts downloading first.
425            let end = range.start + chunk_size + bytes_rem;
426            let end = if end > range.end { range.end } else { end };
427            (range.start, end)
428        } else {
429            let start = bytes_rem + range.start + part_no * chunk_size;
430            (start, start + chunk_size)
431        };
432
433        start..end
434    })
435}
436
/// Coalesces nearby/overlapping ranges into larger single requests, then re-splits any
/// merged request that grew past the optimal chunk size.
///
/// Note: For optimal performance, `ranges` should be sorted. More generally,
/// ranges placed next to each other should also be close in range value.
///
/// # Returns
/// `[(range1, end1), (range2, end2)]`, where:
/// * `range1` contains bytes for the ranges from `ranges[0..end1]`
/// * `range2` contains bytes for the ranges from `ranges[end1..end2]`
/// * etc..
///
/// Note that if an end value is 0, it means the range is a splitted part and should be combined.
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    // Accumulator: the merged range currently being built, seeded with `ranges[0]`
    // (or `0..0` for an empty input, in which case the iterator below is empty anyway).
    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    // Number of fetched bytes excluding excess.
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
            // `ranges[0]` is already in the accumulator, so look at the NEXT range; the
            // final iteration (current_idx == ranges.len()) only flushes the last group.
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
                // No more items - flush current state.
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                // The merged extent if we were to absorb `range`.
                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                // E.g.:
                // |--------|
                //  oo        // range1
                //       oo   // range2
                //    ^^^     // distance = 3, is_overlapping = false
                // E.g.:
                // |--------|
                //  ooooo     // range1
                //     ooooo  // range2
                //     ^^     // distance = 2, is_overlapping = true
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

                // Merge when overlapping, or when merging both (a) moves the merged length
                // no further from the optimal chunk size and (b) the gap of wasted bytes is
                // small relative to the payload (12.5%, clamped to [1 MiB, 8 MiB]).
                let should_merge = is_overlapping || {
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    // Merge to existing range
                    current_merged_range = new_merged;
                    // For overlaps, only count the non-overlapping tail as newly fetched.
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    // Emit the finished group and restart the accumulator at `range`.
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            // Split large individual ranges within the list of ranges.
            // Only the last piece keeps the real `end` index; earlier pieces get 0 so the
            // consumer knows to buffer and re-combine them.
            let (range, end) = x;
            let split = split_range(range.clone());
            let len = split.len();

            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}
521
#[cfg(test)]
mod tests {

    /// Exercises `split_range`: empty ranges round-trip unchanged, and the expected
    /// part counts flip exactly at the derived thresholds below.
    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        // These expected offsets are only valid for the default 64 MiB chunk size.
        assert_eq!(chunk_size, 64 * 1024 * 1024);

        #[allow(clippy::single_range_in_vec_init)]
        {
            // Round-trip empty ranges.
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

        // Threshold to start splitting to 2 ranges
        //
        // n - chunk_size == chunk_size - n / 2
        // n + n / 2 == 2 * chunk_size
        // 3 * n == 4 * chunk_size
        // n = 4 * chunk_size / 3
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

        // Threshold to start splitting to 3 ranges
        //
        // n / 2 - chunk_size == chunk_size - n / 3
        // n / 2 + n / 3 == 2 * chunk_size
        // 5 * n == 12 * chunk_size
        // n == 12 * chunk_size / 5
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    /// Exercises `merge_ranges`: gap-tolerance merging, the overlap rule, and the
    /// re-splitting of merged requests that exceed the chunk size.
    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        // These expected offsets are only valid for the default 64 MiB chunk size.
        assert_eq!(chunk_size, 64 * 1024 * 1024);

        // Round-trip empty slice
        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

        // We have 1 tiny request followed by 1 huge request. They are combined as it reduces the
        // `abs_diff()` to the `chunk_size`, but afterwards they are split to 2 evenly sized
        // requests. Note the `end == 0` marker on the first split part.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

        // <= 1MiB gap, merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

        // > 1MiB gap, do not merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

        // <= 12.5% gap, merge
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // <= 12.5% gap relative to RHS, merge
        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // Overlapping range, merge
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}