1use std::ops::Range;
2
3use bytes::Bytes;
4use futures::{StreamExt, TryStreamExt};
5use object_store::path::Path;
6use object_store::{ObjectMeta, ObjectStore};
7use polars_core::prelude::{InitHashMaps, PlHashMap};
8use polars_error::{PolarsError, PolarsResult, to_compute_err};
9use tokio::io::{AsyncSeekExt, AsyncWriteExt};
10
11use crate::pl_async::{
12 self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
13 tune_with_concurrency_budget, with_concurrency_budget,
14};
15
mod inner {
    use std::future::Future;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;

    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::PolarsResult;

    use crate::cloud::PolarsObjectStoreBuilder;

    /// Shared state: the current store behind an async mutex (so it can be
    /// swapped out on rebuild), plus the builder used to re-create it.
    #[derive(Debug)]
    struct Inner {
        store: tokio::sync::Mutex<Arc<dyn ObjectStore>>,
        builder: PolarsObjectStoreBuilder,
    }

    /// Wrapper around a `dyn ObjectStore` that can rebuild the underlying
    /// store from its builder when an operation fails (see
    /// [`PolarsObjectStore::try_exec_rebuild_on_err`]).
    #[derive(Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
        // Cached copy of the store from construction time; lets
        // `to_dyn_object_store` skip the mutex until the first rebuild.
        initial_store: std::sync::Arc<dyn ObjectStore>,
        // Set to `true` once `rebuild_inner` has run; after that the
        // mutex-guarded store is authoritative instead of `initial_store`.
        rebuilt: AtomicBool,
    }

    // Manual `Clone` because `AtomicBool` is not `Clone`; the flag's current
    // value is carried over so a clone of a rebuilt store also bypasses the
    // stale `initial_store` cache.
    impl Clone for PolarsObjectStore {
        fn clone(&self) -> Self {
            Self {
                inner: self.inner.clone(),
                initial_store: self.initial_store.clone(),
                rebuilt: AtomicBool::new(self.rebuilt.load(std::sync::atomic::Ordering::Relaxed)),
            }
        }
    }

    impl PolarsObjectStore {
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::Mutex::new(store),
                    builder,
                }),
                initial_store,
                rebuilt: AtomicBool::new(false),
            }
        }

        /// Returns the current underlying [`ObjectStore`].
        pub async fn to_dyn_object_store(&self) -> Arc<dyn ObjectStore> {
            if !self.rebuilt.load(std::sync::atomic::Ordering::Relaxed) {
                // Fast path: no rebuild has happened yet, avoid the mutex.
                self.initial_store.clone()
            } else {
                self.inner.store.lock().await.clone()
            }
        }

        /// Rebuilds the underlying store from the stored builder.
        ///
        /// `from_version` is the store the caller observed when its request
        /// failed; the rebuild only happens if it is still the current one
        /// (pointer equality), so several concurrent failures trigger a
        /// single rebuild rather than one each.
        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            let mut current_store = self.inner.store.lock().await;

            // Flip the flag while holding the lock so readers switch to the
            // mutex-guarded store.
            self.rebuilt
                .store(true, std::sync::atomic::Ordering::Relaxed);

            if Arc::ptr_eq(&*current_store, from_version) {
                *current_store = self.inner.builder.clone().build_impl().await.map_err(|e| {
                    e.wrap_msg(|e| format!("attempt to rebuild object store failed: {}", e))
                })?;
            }

            Ok((*current_store).clone())
        }

        /// Runs `func` against the current store; on error, rebuilds the
        /// store once and retries. The retry's error is augmented with the
        /// original error message (and, for Azure, a hint about enabling
        /// storage-account-key authentication).
        pub async fn try_exec_rebuild_on_err<Fn, Fut, O>(&self, mut func: Fn) -> PolarsResult<O>
        where
            Fn: FnMut(&Arc<dyn ObjectStore>) -> Fut,
            Fut: Future<Output = PolarsResult<O>>,
        {
            let store = self.to_dyn_object_store().await;

            let out = func(&store).await;

            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };

            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will attempt re-build",
                    &orig_err
                );
            }

            // Pass the observed store version so concurrent failures don't
            // each force a rebuild.
            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{}; original error: {}", e, orig_err)))?;

            func(&store).await.map_err(|e| {
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    e.wrap_msg(|e| {
                        format!(
                            "{}; note: if you are using Python, consider setting \
POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
and use the storage account keys from Azure CLI to authenticate",
                            e
                        )
                    })
                } else {
                    e
                }
            })
        }
    }
}
146
pub use inner::PolarsObjectStore;

/// Convenience alias for the `object_store` path type.
pub type ObjectStorePath = object_store::path::Path;
150
151impl PolarsObjectStore {
152 fn get_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
154 store: &'a dyn ObjectStore,
155 path: &'a Path,
156 ranges: T,
157 ) -> impl StreamExt<Item = PolarsResult<Bytes>>
158 + TryStreamExt<Ok = Bytes, Error = PolarsError, Item = PolarsResult<Bytes>>
159 + use<'a, T> {
160 futures::stream::iter(
161 ranges
162 .map(|range| async { store.get_range(path, range).await.map_err(to_compute_err) }),
163 )
164 .buffered(get_concurrency_limit() as usize)
166 }
167
    /// Fetches `range` of `path`, splitting large ranges into concurrently
    /// downloaded parts (see `split_range`) that are stitched back together.
    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Bytes> {
        self.try_exec_rebuild_on_err(move |store| {
            // The closure may run twice (retry after rebuild), so clone what
            // the async block consumes.
            let range = range.clone();
            let st = store.clone();

            async {
                let store = st;
                let parts = split_range(range.clone());

                if parts.len() == 1 {
                    // Small enough for a single request.
                    tune_with_concurrency_budget(1, || async { store.get_range(path, range).await })
                        .await
                        .map_err(to_compute_err)
                } else {
                    // Download all parts concurrently, budgeted by part count
                    // (capped at MAX_BUDGET_PER_REQUEST).
                    let parts = tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || {
                            Self::get_buffered_ranges_stream(&store, path, parts)
                                .try_collect::<Vec<Bytes>>()
                        },
                    )
                    .await?;

                    // Concatenate the ordered parts into one buffer.
                    let mut combined = Vec::with_capacity(range.len());

                    for part in parts {
                        combined.extend_from_slice(&part)
                    }

                    assert_eq!(combined.len(), range.len());

                    PolarsResult::Ok(Bytes::from(combined))
                }
            }
        })
        .await
    }
205
    /// Fetches all `ranges` of `path` and returns them keyed by their start
    /// offset (hence every range must have a unique start).
    ///
    /// `ranges` is sorted in place; nearby/overlapping ranges are merged into
    /// larger requests (see `merge_ranges`), downloaded, and sliced back into
    /// the originally requested ranges.
    pub async fn get_ranges_sort<
        K: TryFrom<usize, Error = impl std::fmt::Debug> + std::hash::Hash + Eq,
        T: From<Bytes>,
    >(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<K, T>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        // `merged_ends[i]` is the exclusive index into `ranges` covered once
        // merged request `i` completes, or 0 for intermediate parts of a
        // merged range that was re-split for download (see `merge_ranges`).
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                let store = st;
                let mut out = PlHashMap::with_capacity(ranges_len);

                let mut stream =
                    Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned());

                tune_with_concurrency_budget(
                    merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                    || async {
                        let mut len = 0;
                        let mut current_offset = 0;
                        let mut ends_iter = merged_ends.iter();

                        // Parts of a single merged range that was split for
                        // download; buffered until its final part arrives.
                        let mut splitted_parts = vec![];

                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            let end = *ends_iter.next().unwrap();

                            if end == 0 {
                                // Intermediate part; keep accumulating.
                                splitted_parts.push(bytes);
                                continue;
                            }

                            // Bounding range of all original requests served
                            // by this merged download.
                            let full_range = ranges[current_offset..end]
                                .iter()
                                .cloned()
                                .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                                .unwrap();

                            let bytes = if splitted_parts.is_empty() {
                                bytes
                            } else {
                                // Reassemble the split download into one buffer.
                                let mut out = Vec::with_capacity(full_range.len());

                                for x in splitted_parts.drain(..) {
                                    out.extend_from_slice(&x);
                                }

                                out.extend_from_slice(&bytes);
                                Bytes::from(out)
                            };

                            assert_eq!(bytes.len(), full_range.len());

                            // Slice the merged buffer back into the original
                            // requested ranges, keyed by start offset.
                            for range in &ranges[current_offset..end] {
                                let v = out.insert(
                                    K::try_from(range.start).unwrap(),
                                    T::from(bytes.slice(
                                        range.start - full_range.start
                                            ..range.end - full_range.start,
                                    )),
                                );

                                // Start offsets must be unique.
                                assert!(v.is_none());
                            }

                            current_offset = end;
                        }

                        assert!(splitted_parts.is_empty());

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    },
                )
                .await?;

                Ok(out)
            }
        })
        .await
    }
304
    /// Downloads the object at `path` into `file`, writing from the file's
    /// current position. When the object size is known (successful `head`),
    /// large objects are fetched as concurrent ranged requests.
    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        // Known size enables ranged downloads; on `head` failure we fall back
        // to a single streaming GET.
        let opt_size = self.head(path).await.ok().map(|x| x.size);

        let initial_pos = file.stream_position().await?;

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            // SAFETY(review): `transmute_copy` duplicates the `&mut File` so
            // the `FnMut` closure can hand it to the returned async block.
            // This creates aliasing mutable references and relies on
            // `try_exec_rebuild_on_err` never running two invocations of the
            // closure concurrently — confirm that invariant is upheld there.
            let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) };

            async {
                // Truncate bytes written by a previous (failed) attempt
                // before retrying.
                file.set_len(initial_pos).await?;
                let store = st;
                // Only use ranged downloads when splitting yields >1 part.
                let parts = opt_size.map(|x| split_range(0..x)).filter(|x| x.len() > 1);

                if let Some(parts) = parts {
                    tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || async {
                            let mut stream = Self::get_buffered_ranges_stream(&store, path, parts);
                            let mut len = 0;
                            while let Some(bytes) = stream.try_next().await? {
                                len += bytes.len();
                                file.write_all(&bytes).await.map_err(to_compute_err)?;
                            }

                            assert_eq!(len, opt_size.unwrap());

                            PolarsResult::Ok(pl_async::Size::from(len as u64))
                        },
                    )
                    .await?
                } else {
                    // Unknown size (or a single part): stream the whole object.
                    tune_with_concurrency_budget(1, || async {
                        let mut stream =
                            store.get(path).await.map_err(to_compute_err)?.into_stream();

                        let mut len = 0;
                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            file.write_all(&bytes).await.map_err(to_compute_err)?;
                        }

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    })
                    .await?
                };

                // Flush to disk before reporting success.
                file.sync_all().await.map_err(PolarsError::from)?;

                Ok(())
            }
        })
        .await
    }
364
365 pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
367 self.try_exec_rebuild_on_err(|store| {
368 let st = store.clone();
369
370 async {
371 with_concurrency_budget(1, || async {
372 let store = st;
373 let head_result = store.head(path).await;
374
375 if head_result.is_err() {
376 let get_range_0_0_result = store
379 .get_opts(
380 path,
381 object_store::GetOptions {
382 range: Some((0..1).into()),
383 ..Default::default()
384 },
385 )
386 .await;
387
388 if let Ok(v) = get_range_0_0_result {
389 return Ok(v.meta);
390 }
391 }
392
393 head_result
394 })
395 .await
396 .map_err(to_compute_err)
397 }
398 })
399 .await
400 }
401}
402
403fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
406 let chunk_size = get_download_chunk_size();
407
408 let n_parts = [
410 (range.len().div_ceil(chunk_size)).max(1),
411 (range.len() / chunk_size).max(1),
412 ]
413 .into_iter()
414 .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
415 .unwrap();
416
417 let chunk_size = (range.len() / n_parts).max(1);
418
419 assert_eq!(n_parts, (range.len() / chunk_size).max(1));
420 let bytes_rem = range.len() % chunk_size;
421
422 (0..n_parts).map(move |part_no| {
423 let (start, end) = if part_no == 0 {
424 let end = range.start + chunk_size + bytes_rem;
426 let end = if end > range.end { range.end } else { end };
427 (range.start, end)
428 } else {
429 let start = bytes_rem + range.start + part_no * chunk_size;
430 (start, start + chunk_size)
431 };
432
433 start..end
434 })
435}
436
/// Merges overlapping/nearby sorted ranges into larger download requests.
///
/// Yields `(range, end)` tuples, where `end` is the exclusive index into the
/// input slice of the requests covered once this output range completes, or
/// `0` for intermediate parts of a merged range that was re-split by
/// `split_range`.
///
/// `ranges` must be sorted by start offset.
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    // Payload bytes accumulated into the current merged range (gap bytes and
    // overlap double-counting excluded).
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
            // Look ahead by one: each iteration decides whether ranges
            // [current_idx] joins the in-progress merged range.
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
                // Last iteration: flush the in-progress merged range.
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                // `distance` is the gap between the two ranges, or the
                // overlap size when `is_overlapping`.
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

                let should_merge = is_overlapping || {
                    // Merge only if it moves the merged length no further
                    // from the target chunk size...
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    // ...and the gap is small relative to the payload size
                    // (clamped between 1 MiB and 8 MiB).
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    current_merged_range = new_merged;
                    // Don't double-count overlapping bytes.
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    // Emit the finished merged range and start a new one.
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            // Re-split oversized merged ranges for download; only the final
            // part carries the `end` marker, intermediate parts get 0.
            let (range, end) = x;
            let split = split_range(range.clone());
            let len = split.len();

            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}
521
#[cfg(test)]
mod tests {

    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        // Empty ranges come back as a single empty part.
        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

        // Boundary at which one part is still closest to the chunk size.
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        // One byte past the boundary splits into two equal-ish parts.
        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

        // Boundary between two and three parts.
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

        // Adjacent ranges merge, then are re-split into ~chunk-sized parts;
        // only the final part carries the end index (2), the rest get 0.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

        // A gap of exactly 1 MiB is within the minimum gap tolerance.
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

        // One byte past the tolerance: no merge.
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

        // Small gaps merge.
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // Overlapping ranges always merge.
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}