Skip to main content

polars_io/cloud/
polars_object_store.rs

1use std::fmt::Display;
2use std::ops::Range;
3use std::sync::Arc;
4
5use futures::{Stream, StreamExt as _, TryStreamExt as _};
6use hashbrown::hash_map::RawEntryMut;
7use object_store::path::Path;
8use object_store::{ObjectMeta, ObjectStore, ObjectStoreExt};
9use polars_buffer::Buffer;
10use polars_core::prelude::{InitHashMaps, PlHashMap};
11use polars_error::{PolarsError, PolarsResult};
12use polars_utils::pl_path::PlRefPath;
13use tokio::io::AsyncWriteExt;
14
15use crate::pl_async::{
16    self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
17    tune_with_concurrency_budget, with_concurrency_budget,
18};
19
/// An [`object_store::Error`] paired with the path it relates to.
///
/// The `Display` implementation for this type prints `base_url` under the label `path`.
#[derive(Debug)]
pub struct PolarsObjectStoreError {
    /// Path attached to the error for context.
    pub base_url: PlRefPath,
    /// The underlying error, also exposed through `std::error::Error::source`.
    pub source: object_store::Error,
}
25
26impl PolarsObjectStoreError {
27    pub fn from_url(base_url: &PlRefPath) -> impl FnOnce(object_store::Error) -> Self {
28        |error| Self {
29            base_url: base_url.clone(),
30            source: error,
31        }
32    }
33}
34
35impl Display for PolarsObjectStoreError {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        write!(
38            f,
39            "object-store error: {} (path: {})",
40            self.source, &self.base_url
41        )
42    }
43}
44
45impl std::error::Error for PolarsObjectStoreError {
46    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
47        Some(&self.source)
48    }
49}
50
51impl From<PolarsObjectStoreError> for std::io::Error {
52    fn from(value: PolarsObjectStoreError) -> Self {
53        std::io::Error::other(value)
54    }
55}
56
57impl From<PolarsObjectStoreError> for PolarsError {
58    fn from(value: PolarsObjectStoreError) -> Self {
59        PolarsError::IO {
60            error: Arc::new(value.into()),
61            msg: None,
62        }
63    }
64}
65
66mod inner {
67
68    use std::borrow::Cow;
69    use std::future::Future;
70    use std::sync::Arc;
71
72    use object_store::ObjectStore;
73    use polars_core::config;
74    use polars_error::{PolarsError, PolarsResult};
75    use polars_utils::relaxed_cell::RelaxedCell;
76
77    use crate::cloud::{ObjectStoreErrorContext, PolarsObjectStoreBuilder};
78    use crate::metrics::{IOMetrics, OptIOMetrics};
79
    /// State shared between all clones of a `PolarsObjectStore`.
    #[derive(Debug)]
    struct Inner {
        /// The current store. `rebuild_inner` replaces it while holding the write lock.
        store: tokio::sync::RwLock<Arc<dyn ObjectStore>>,
        /// Builder that `rebuild_inner` clones to construct a replacement store.
        builder: PolarsObjectStoreBuilder,
        /// Set to `true` by `rebuild_inner` once it completes successfully; checked by
        /// `to_dyn_object_store` to decide whether `store` needs to be read at all.
        rebuilt: RelaxedCell<bool>,
    }
86
    /// Polars wrapper around [`ObjectStore`] functionality. This struct is cheaply cloneable.
    #[derive(Clone, Debug)]
    pub struct PolarsObjectStore {
        /// Shared state: the lock-protected current store, its builder and the `rebuilt` flag.
        inner: Arc<Inner>,
        /// The store this wrapper was created with. `to_dyn_object_store` hands this out
        /// directly, without acquiring the `RwLock` in `inner`, until the first re-build.
        initial_store: std::sync::Arc<dyn ObjectStore>,
        /// Optional I/O metrics sink; `OptIOMetrics(None)` until `set_io_metrics` is called.
        io_metrics: OptIOMetrics,
    }
95
96    impl PolarsObjectStore {
97        pub(crate) fn new_from_inner(
98            store: Arc<dyn ObjectStore>,
99            builder: PolarsObjectStoreBuilder,
100        ) -> Self {
101            let initial_store = store.clone();
102            Self {
103                inner: Arc::new(Inner {
104                    store: tokio::sync::RwLock::new(store),
105                    builder,
106                    rebuilt: RelaxedCell::from(false),
107                }),
108                initial_store,
109                io_metrics: OptIOMetrics(None),
110            }
111        }
112
113        pub fn set_io_metrics(&mut self, io_metrics: Option<Arc<IOMetrics>>) -> &mut Self {
114            self.io_metrics = OptIOMetrics(io_metrics);
115            self
116        }
117
118        pub fn io_metrics(&self) -> &OptIOMetrics {
119            &self.io_metrics
120        }
121
122        /// Gets the underlying [`ObjectStore`] implementation.
123        pub async fn to_dyn_object_store(&self) -> Cow<'_, Arc<dyn ObjectStore>> {
124            if !self.inner.rebuilt.load() {
125                Cow::Borrowed(&self.initial_store)
126            } else {
127                Cow::Owned(self.inner.store.read().await.clone())
128            }
129        }
130
131        pub async fn rebuild_inner(
132            &self,
133            from_version: &Arc<dyn ObjectStore>,
134        ) -> PolarsResult<Arc<dyn ObjectStore>> {
135            let mut current_store = self.inner.store.write().await;
136
137            // If this does not eq, then `inner` was already re-built by another thread.
138            if Arc::ptr_eq(&*current_store, from_version) {
139                *current_store =
140                    self.inner
141                        .builder
142                        .clone()
143                        .build_impl(true)
144                        .await
145                        .map_err(|e| {
146                            e.wrap_msg(|e| format!("attempt to rebuild object store failed: {e}"))
147                        })?;
148            }
149
150            self.inner.rebuilt.store(true);
151
152            Ok((*current_store).clone())
153        }
154
155        pub async fn exec_with_rebuild_retry_on_err<'s, 'f, Fn, Fut, O>(
156            &'s self,
157            mut func: Fn,
158        ) -> PolarsResult<O>
159        where
160            Fn: FnMut(Cow<'s, Arc<dyn ObjectStore>>) -> Fut + 'f,
161            Fut: Future<Output = object_store::Result<O>>,
162        {
163            let store = self.to_dyn_object_store().await;
164
165            let out = func(store.clone()).await;
166
167            let orig_err = match out {
168                Ok(v) => return Ok(v),
169                Err(e) => e,
170            };
171
172            if config::verbose() {
173                eprintln!(
174                    "[PolarsObjectStore]: got error: {}, will rebuild store and retry",
175                    &orig_err
176                );
177            }
178
179            let store = self
180                .rebuild_inner(&store)
181                .await
182                .map_err(|e| e.wrap_msg(|e| format!("{e}; original error: {orig_err}")))?;
183
184            func(Cow::Owned(store)).await.map_err(|e| {
185                let e: PolarsError = self.error_context().attach_err_info(e).into();
186
187                if self.inner.builder.is_azure()
188                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
189                        != Ok("1")
190                {
191                    // Note: This error is intended for Python audiences. The logic for retrieving
192                    // these keys exist only on the Python side.
193                    e.wrap_msg(|e| {
194                        format!(
195                            "{e}; note: if you are using Python, consider setting \
196POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
197and use the storage account keys from Azure CLI to authenticate"
198                        )
199                    })
200                } else {
201                    e
202                }
203            })
204        }
205
206        pub fn error_context(&self) -> ObjectStoreErrorContext {
207            ObjectStoreErrorContext::new(self.inner.builder.path().clone())
208        }
209    }
210}
211
/// Context that can be attached to an [`object_store::Error`] to turn it into a
/// [`PolarsObjectStoreError`].
#[derive(Clone)]
pub struct ObjectStoreErrorContext {
    /// Path that ends up as `base_url` of the resulting error.
    path: PlRefPath,
}
216
217impl ObjectStoreErrorContext {
218    pub fn new(path: PlRefPath) -> Self {
219        Self { path }
220    }
221
222    pub fn attach_err_info(self, err: object_store::Error) -> PolarsObjectStoreError {
223        let ObjectStoreErrorContext { path } = self;
224
225        PolarsObjectStoreError {
226            base_url: path,
227            source: err,
228        }
229    }
230}
231
232pub use inner::PolarsObjectStore;
233
/// Alias for [`object_store::path::Path`].
pub type ObjectStorePath = object_store::path::Path;
235
236impl PolarsObjectStore {
237    pub fn build_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
238        &'a self,
239        path: &'a Path,
240        ranges: T,
241    ) -> impl Stream<Item = PolarsResult<Buffer<u8>>> + use<'a, T> {
242        futures::stream::iter(ranges.map(move |range| async move {
243            if range.is_empty() {
244                return Ok(Buffer::new());
245            }
246
247            let out = self
248                .io_metrics()
249                .record_io_read(
250                    range.len() as u64,
251                    self.exec_with_rebuild_retry_on_err(|s| async move {
252                        s.get_range(path, range.start as u64..range.end as u64)
253                            .await
254                    }),
255                )
256                .await?;
257
258            Ok(Buffer::from_owner(out))
259        }))
260        // Add a limit locally as this gets run inside a single `tune_with_concurrency_budget`.
261        .buffered(get_concurrency_limit() as usize)
262    }
263
264    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Buffer<u8>> {
265        if range.is_empty() {
266            return Ok(Buffer::new());
267        }
268
269        let parts = split_range(range.clone());
270
271        if parts.len() == 1 {
272            let out = tune_with_concurrency_budget(1, move || async move {
273                let bytes = self
274                    .io_metrics()
275                    .record_io_read(
276                        range.len() as u64,
277                        self.exec_with_rebuild_retry_on_err(|s| async move {
278                            s.get_range(path, range.start as u64..range.end as u64)
279                                .await
280                        }),
281                    )
282                    .await?;
283
284                PolarsResult::Ok(Buffer::from_owner(bytes))
285            })
286            .await?;
287
288            Ok(out)
289        } else {
290            let parts = tune_with_concurrency_budget(
291                parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
292                || {
293                    self.build_buffered_ranges_stream(path, parts)
294                        .try_collect::<Vec<Buffer<u8>>>()
295                },
296            )
297            .await?;
298
299            let mut combined = Vec::with_capacity(range.len());
300
301            for part in parts {
302                combined.extend_from_slice(&part)
303            }
304
305            assert_eq!(combined.len(), range.len());
306
307            PolarsResult::Ok(Buffer::from_vec(combined))
308        }
309    }
310
    /// Fetch byte ranges into a HashMap keyed by the range start. This will mutably sort the
    /// `ranges` slice (by range start) for coalescing.
    ///
    /// The sorted ranges are grouped by `merge_ranges` into larger requests, fetched through
    /// `build_buffered_ranges_stream`, and every input range is then returned as a slice of
    /// the bytes fetched for its group.
    ///
    /// If several ranges share the same start, the map keeps the longest slice seen for that
    /// start. NOTE(review): this doc previously stated that duplicate starts cause a panic;
    /// the code below does not do that - confirm which behavior callers rely on.
    ///
    /// # Panics
    /// Panics if one of the internal consistency checks fails: the bytes assembled for a group
    /// not matching the length of the span the group covers, or split parts being left over
    /// once the stream is exhausted.
    pub async fn get_ranges_sort(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<usize, Buffer<u8>>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        // `merged_ranges[i]` is the request to issue, `merged_ends[i]` the exclusive index into
        // `ranges` that the request completes (0 for a non-final piece of a split request).
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        let mut out = PlHashMap::with_capacity(ranges_len);

        let mut stream = self.build_buffered_ranges_stream(path, merged_ranges.iter().cloned());

        tune_with_concurrency_budget(
            merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
            || async {
                // Total number of bytes received, reported back through `pl_async::Size`.
                let mut len = 0;
                // Index into `ranges` of the first range not yet emitted into `out`.
                let mut current_offset = 0;
                let mut ends_iter = merged_ends.iter();

                // Pieces of a split request that are waiting for their final piece.
                let mut splitted_parts = vec![];

                while let Some(bytes) = stream.try_next().await? {
                    len += bytes.len();
                    let end = *ends_iter.next().unwrap();

                    // An end of 0 marks a non-final piece of a split request: stash it.
                    if end == 0 {
                        splitted_parts.push(bytes);
                        continue;
                    }

                    // Smallest span covering every input range of this group.
                    let full_range = ranges[current_offset..end]
                        .iter()
                        .cloned()
                        .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                        .unwrap();

                    // Re-assemble split requests; unsplit requests are used as-is.
                    let bytes = if splitted_parts.is_empty() {
                        bytes
                    } else {
                        let mut out = Vec::with_capacity(full_range.len());

                        for x in splitted_parts.drain(..) {
                            out.extend_from_slice(&x);
                        }

                        out.extend_from_slice(&bytes);
                        Buffer::from(out)
                    };

                    assert_eq!(bytes.len(), full_range.len());

                    for range in &ranges[current_offset..end] {
                        // Offsets are relative to the start of the group's span.
                        let slice = bytes
                            .clone()
                            .sliced(range.start - full_range.start..range.end - full_range.start);

                        match out.raw_entry_mut().from_key(&range.start) {
                            RawEntryMut::Vacant(slot) => {
                                slot.insert(range.start, slice);
                            },
                            RawEntryMut::Occupied(mut slot) => {
                                // Same start seen before: keep whichever slice is longer.
                                if slot.get_mut().len() < slice.len() {
                                    *slot.get_mut() = slice;
                                }
                            },
                        }
                    }

                    current_offset = end;
                }

                assert!(splitted_parts.is_empty());

                PolarsResult::Ok(pl_async::Size::from(len as u64))
            },
        )
        .await?;

        Ok(out)
    }
402
403    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
404        let size = self.head(path).await?.size;
405        let parts = split_range(0..size as usize);
406
407        tune_with_concurrency_budget(
408            parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
409            || async {
410                let mut stream = self.build_buffered_ranges_stream(path, parts);
411                let mut len = 0;
412                while let Some(bytes) = stream.try_next().await? {
413                    len += bytes.len();
414                    file.write_all(&bytes).await?;
415                }
416
417                assert_eq!(len, size as usize);
418
419                PolarsResult::Ok(pl_async::Size::from(len as u64))
420            },
421        )
422        .await?;
423
424        // Dropping is delayed for tokio async files so we need to explicitly
425        // flush here (https://github.com/tokio-rs/tokio/issues/2307#issuecomment-596336451).
426        file.sync_all().await.map_err(PolarsError::from)?;
427
428        Ok(())
429    }
430
431    /// Fetch the metadata of the parquet file, do not memoize it.
432    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
433        with_concurrency_budget(1, || {
434            self.exec_with_rebuild_retry_on_err(|s| {
435                async move {
436                    let head_result = self.io_metrics().record_io_read(0, s.head(path)).await;
437
438                    if head_result.is_err() {
439                        // Pre-signed URLs forbid the HEAD method, but we can still retrieve the header
440                        // information with a range 0-1 request.
441                        let get_range_0_1_result = self
442                            .io_metrics()
443                            .record_io_read(
444                                0,
445                                s.get_opts(
446                                    path,
447                                    object_store::GetOptions {
448                                        range: Some((0..1).into()),
449                                        ..Default::default()
450                                    },
451                                ),
452                            )
453                            .await;
454
455                        if let Ok(v) = get_range_0_1_result {
456                            return Ok(v.meta);
457                        }
458                    }
459
460                    let out = head_result?;
461
462                    Ok(out)
463                }
464            })
465        })
466        .await
467    }
468}
469
470/// Splits a single range into multiple smaller ranges, which can be downloaded concurrently for
471/// much higher throughput.
472fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
473    let chunk_size = get_download_chunk_size();
474
475    // Calculate n_parts such that we are as close as possible to the `chunk_size`.
476    let n_parts = [
477        (range.len().div_ceil(chunk_size)).max(1),
478        (range.len() / chunk_size).max(1),
479    ]
480    .into_iter()
481    .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
482    .unwrap();
483
484    let chunk_size = (range.len() / n_parts).max(1);
485
486    assert_eq!(n_parts, (range.len() / chunk_size).max(1));
487    let bytes_rem = range.len() % chunk_size;
488
489    (0..n_parts).map(move |part_no| {
490        let (start, end) = if part_no == 0 {
491            // Download remainder length in the first chunk since it starts downloading first.
492            let end = range.start + chunk_size + bytes_rem;
493            let end = if end > range.end { range.end } else { end };
494            (range.start, end)
495        } else {
496            let start = bytes_rem + range.start + part_no * chunk_size;
497            (start, start + chunk_size)
498        };
499
500        start..end
501    })
502}
503
/// Groups neighbouring `ranges` into larger requests, then splits requests that are too large.
///
/// Note: For optimal performance, `ranges` should be sorted. More generally,
/// ranges placed next to each other should also be close in range value.
///
/// A range is merged into the current group when it overlaps the group, or when the gap to it
/// is within the gap tolerance and merging does not move the group's length further away from
/// `get_download_chunk_size()`.
///
/// # Returns
/// `[(range1, end1), (range2, end2)]`, where:
/// * `range1` contains bytes for the ranges from `ranges[0..end1]`
/// * `range2` contains bytes for the ranges from `ranges[end1..end2]`
/// * etc..
///
/// Note that if an end value is 0, it means the range is one part of a split range and should
/// be combined with the parts that follow it, up to and including the next part with a
/// non-zero end.
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    // The group currently being built; starts out as the first range (if any).
    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    // Number of fetched bytes excluding excess.
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
            // Look one element ahead: the group already contains `ranges[..current_idx]`.
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
                // No more items - flush current state.
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                // What the group would cover if `range` were merged into it.
                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                // E.g.:
                // |--------|
                //  oo        // range1
                //       oo   // range2
                //    ^^^     // distance = 3, is_overlapping = false
                // E.g.:
                // |--------|
                //  ooooo     // range1
                //     ooooo  // range2
                //     ^^     // distance = 2, is_overlapping = true
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

                let should_merge = is_overlapping || {
                    // Merging must not take the group length further from `chunk_size`.
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    // 1/8 of the larger of (bytes in group, incoming range length), kept
                    // within 1 MiB ..= 8 MiB.
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    // Merge to existing range
                    current_merged_range = new_merged;
                    // When overlapping, `distance` is the size of the overlap, which is
                    // already counted.
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    // Emit the finished group and start a new one from `range`.
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            // Split large individual ranges within the list of ranges.
            let (range, end) = x;
            let split = split_range(range);
            let len = split.len();

            // Only the last part carries the real end index; earlier parts are marked with 0.
            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}
588
/// Unit tests for `split_range` and `merge_ranges`. Both tests assert that the download chunk
/// size is 64 MiB, since the expected values below are computed for that size.
#[cfg(test)]
mod tests {

    /// Checks `split_range` on empty ranges and right at the thresholds where the number of
    /// parts goes from 1 to 2 and from 2 to 3.
    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        #[allow(clippy::single_range_in_vec_init)]
        {
            // Round-trip empty ranges.
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

        // Threshold to start splitting to 2 ranges
        //
        // n - chunk_size == chunk_size - n / 2
        // n + n / 2 == 2 * chunk_size
        // 3 * n == 4 * chunk_size
        // n = 4 * chunk_size / 3
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

        // Threshold to start splitting to 3 ranges
        //
        // n / 2 - chunk_size == chunk_size - n / 3
        // n / 2 + n / 3 == 2 * chunk_size
        // 5 * n == 12 * chunk_size
        // n == 12 * chunk_size / 5
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    /// Checks the merge decisions of `merge_ranges`: gap tolerance, overlap handling and the
    /// splitting of an oversized merged range.
    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        // Round-trip empty slice
        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

        // We have 1 tiny request followed by 1 huge request. They are combined as it reduces the
        // `abs_diff()` to the `chunk_size`, but afterwards they are split to 2 evenly sized
        // requests.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

        // <= 1MiB gap, merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

        // > 1MiB gap, do not merge
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

        // <= 12.5% gap, merge
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // <= 12.5% gap relative to RHS, merge
        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // Overlapping range, merge
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}