use std::ops::Range;

use bytes::Bytes;
use futures::{StreamExt, TryStreamExt};
use hashbrown::hash_map::RawEntryMut;
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};
use polars_core::prelude::{InitHashMaps, PlHashMap};
use polars_error::{PolarsError, PolarsResult};
use polars_utils::mmap::MemSlice;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};

use crate::pl_async::{
    self, MAX_BUDGET_PER_REQUEST, get_concurrency_limit, get_download_chunk_size,
    tune_with_concurrency_budget, with_concurrency_budget,
};

mod inner {
    use std::future::Future;
    use std::sync::Arc;

    use object_store::ObjectStore;
    use polars_core::config;
    use polars_error::PolarsResult;
    use polars_utils::relaxed_cell::RelaxedCell;

    use crate::cloud::PolarsObjectStoreBuilder;

    #[derive(Debug)]
    struct Inner {
        store: tokio::sync::Mutex<Arc<dyn ObjectStore>>,
        builder: PolarsObjectStoreBuilder,
    }

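    /// Wrapper around a dynamic [`ObjectStore`] that can transparently rebuild
    /// the underlying store from its builder when a request fails (for example,
    /// after credentials have expired).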
    #[derive(Clone, Debug)]
    pub struct PolarsObjectStore {
        inner: Arc<Inner>,
        /// The store created at construction time; returned without locking the
        /// mutex for as long as no rebuild has happened.
        initial_store: std::sync::Arc<dyn ObjectStore>,
        /// Set to `true` once `rebuild_inner` has run at least once.
        rebuilt: RelaxedCell<bool>,
    }

    impl PolarsObjectStore {
        pub(crate) fn new_from_inner(
            store: Arc<dyn ObjectStore>,
            builder: PolarsObjectStoreBuilder,
        ) -> Self {
            let initial_store = store.clone();
            Self {
                inner: Arc::new(Inner {
                    store: tokio::sync::Mutex::new(store),
                    builder,
                }),
                initial_store,
                rebuilt: RelaxedCell::from(false),
            }
        }

        /// Returns the current underlying [`ObjectStore`].
        pub async fn to_dyn_object_store(&self) -> Arc<dyn ObjectStore> {
            // Fast path: avoid locking the mutex until a rebuild has happened.
            if !self.rebuilt.load() {
                self.initial_store.clone()
            } else {
                self.inner.store.lock().await.clone()
            }
        }

        /// Rebuilds the underlying object store from the builder and returns it.
        /// `from_version` is the store observed by the caller; the rebuild is
        /// skipped if another task has already swapped in a newer store.
        pub async fn rebuild_inner(
            &self,
            from_version: &Arc<dyn ObjectStore>,
        ) -> PolarsResult<Arc<dyn ObjectStore>> {
            let mut current_store = self.inner.store.lock().await;

            // Only rebuild if no other task replaced the store since
            // `from_version` was taken.
            if Arc::ptr_eq(&*current_store, from_version) {
                *current_store =
                    self.inner
                        .builder
                        .clone()
                        .build_impl(true)
                        .await
                        .map_err(|e| {
                            e.wrap_msg(|e| format!("attempt to rebuild object store failed: {e}"))
                        })?;
            }

            self.rebuilt.store(true);

            Ok((*current_store).clone())
        }

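        /// Runs `func` with the current object store. If the call fails, the
        /// store is rebuilt via [`Self::rebuild_inner`] and `func` is retried
        /// once against the fresh store.
        ///
        /// A minimal usage sketch (the `store` handle and `path` here are
        /// assumed for illustration only):
        ///
        /// ```ignore
        /// let meta = store
        ///     .try_exec_rebuild_on_err(|store| {
        ///         let store = store.clone();
        ///         async move { Ok(store.head(&path).await?) }
        ///     })
        ///     .await?;
        /// ```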
        pub async fn try_exec_rebuild_on_err<Fn, Fut, O>(&self, mut func: Fn) -> PolarsResult<O>
        where
            Fn: FnMut(&Arc<dyn ObjectStore>) -> Fut,
            Fut: Future<Output = PolarsResult<O>>,
        {
            let store = self.to_dyn_object_store().await;

            let out = func(&store).await;

            let orig_err = match out {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };

            if config::verbose() {
                eprintln!(
                    "[PolarsObjectStore]: got error: {}, will attempt re-build",
                    &orig_err
                );
            }

            let store = self
                .rebuild_inner(&store)
                .await
                .map_err(|e| e.wrap_msg(|e| format!("{e}; original error: {orig_err}")))?;

            func(&store).await.map_err(|e| {
                if self.inner.builder.is_azure()
                    && std::env::var("POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY").as_deref()
                        != Ok("1")
                {
                    e.wrap_msg(|e| {
                        format!(
                            "{e}; note: if you are using Python, consider setting \
POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY=1 if you would like polars to try to retrieve \
and use the storage account keys from Azure CLI to authenticate"
                        )
                    })
                } else {
                    e
                }
            })
        }
    }
}

pub use inner::PolarsObjectStore;

pub type ObjectStorePath = object_store::path::Path;

impl PolarsObjectStore {
    /// Returns a stream of the downloaded `ranges`, buffered up to the
    /// configured concurrency limit.
    fn get_buffered_ranges_stream<'a, T: Iterator<Item = Range<usize>>>(
        store: &'a dyn ObjectStore,
        path: &'a Path,
        ranges: T,
    ) -> impl StreamExt<Item = PolarsResult<Bytes>>
    + TryStreamExt<Ok = Bytes, Error = PolarsError, Item = PolarsResult<Bytes>>
    + use<'a, T> {
        futures::stream::iter(ranges.map(move |range| async move {
            if range.is_empty() {
                return Ok(Bytes::new());
            }

            let out = store
                .get_range(path, range.start as u64..range.end as u64)
                .await?;
            Ok(out)
        }))
        // Download concurrently, but yield the results in order.
        .buffered(get_concurrency_limit() as usize)
    }

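    /// Fetches `range` of `path` as a single contiguous buffer. Ranges larger
    /// than the configured download chunk size are split with `split_range`,
    /// fetched concurrently, and concatenated in order.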
    pub async fn get_range(&self, path: &Path, range: Range<usize>) -> PolarsResult<Bytes> {
        if range.is_empty() {
            return Ok(Bytes::new());
        }

        self.try_exec_rebuild_on_err(move |store| {
            let range = range.clone();
            let st = store.clone();

            async move {
                let store = st;
                let parts = split_range(range.clone());

                if parts.len() == 1 {
                    // Small enough to fetch in a single request.
                    let out = tune_with_concurrency_budget(1, move || async move {
                        store
                            .get_range(path, range.start as u64..range.end as u64)
                            .await
                    })
                    .await?;

                    Ok(out)
                } else {
                    // Fetch the parts concurrently and stitch them back together.
                    let parts = tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || {
                            Self::get_buffered_ranges_stream(&store, path, parts)
                                .try_collect::<Vec<Bytes>>()
                        },
                    )
                    .await?;

                    let mut combined = Vec::with_capacity(range.len());

                    for part in parts {
                        combined.extend_from_slice(&part)
                    }

                    assert_eq!(combined.len(), range.len());

                    PolarsResult::Ok(Bytes::from(combined))
                }
            }
        })
        .await
    }

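    /// Fetches multiple byte ranges of `path` and returns them keyed by each
    /// range's start offset. `ranges` is sorted in place by start; overlapping
    /// or nearby ranges may be coalesced into fewer, larger requests (see
    /// `merge_ranges`).
    ///
    /// A minimal usage sketch (the `store` handle, `path` and offsets are
    /// assumed for illustration only):
    ///
    /// ```ignore
    /// let mut ranges = vec![0..1024, 4096..8192];
    /// let slices = store.get_ranges_sort(&path, &mut ranges).await?;
    /// let first = &slices[&0]; // bytes 0..1024 of the object
    /// ```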
    pub async fn get_ranges_sort(
        &self,
        path: &Path,
        ranges: &mut [Range<usize>],
    ) -> PolarsResult<PlHashMap<usize, MemSlice>> {
        if ranges.is_empty() {
            return Ok(Default::default());
        }

        ranges.sort_unstable_by_key(|x| x.start);

        let ranges_len = ranges.len();
        let (merged_ranges, merged_ends): (Vec<_>, Vec<_>) = merge_ranges(ranges).unzip();

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                let store = st;
                let mut out = PlHashMap::with_capacity(ranges_len);

                let mut stream =
                    Self::get_buffered_ranges_stream(&store, path, merged_ranges.iter().cloned());

                tune_with_concurrency_budget(
                    merged_ranges.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                    || async {
                        let mut len = 0;
                        let mut current_offset = 0;
                        let mut ends_iter = merged_ends.iter();

                        // Parts of a merged range that was split into multiple requests,
                        // accumulated until the final part (`end != 0`) arrives.
                        let mut splitted_parts = vec![];

                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            let end = *ends_iter.next().unwrap();

                            if end == 0 {
                                splitted_parts.push(bytes);
                                continue;
                            }

                            // The full span covered by this merged request.
                            let full_range = ranges[current_offset..end]
                                .iter()
                                .cloned()
                                .reduce(|l, r| l.start.min(r.start)..l.end.max(r.end))
                                .unwrap();

                            let bytes = if splitted_parts.is_empty() {
                                bytes
                            } else {
                                // Re-assemble the split parts into a single buffer.
                                let mut out = Vec::with_capacity(full_range.len());

                                for x in splitted_parts.drain(..) {
                                    out.extend_from_slice(&x);
                                }

                                out.extend_from_slice(&bytes);
                                Bytes::from(out)
                            };

                            assert_eq!(bytes.len(), full_range.len());

                            let bytes = MemSlice::from_bytes(bytes);

                            for range in &ranges[current_offset..end] {
                                let mem_slice = bytes.slice(
                                    range.start - full_range.start..range.end - full_range.start,
                                );

                                // Ranges can share a start offset; keep the longest slice.
                                match out.raw_entry_mut().from_key(&range.start) {
                                    RawEntryMut::Vacant(slot) => {
                                        slot.insert(range.start, mem_slice);
                                    },
                                    RawEntryMut::Occupied(mut slot) => {
                                        if slot.get_mut().len() < mem_slice.len() {
                                            *slot.get_mut() = mem_slice;
                                        }
                                    },
                                }
                            }

                            current_offset = end;
                        }

                        assert!(splitted_parts.is_empty());

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    },
                )
                .await?;

                Ok(out)
            }
        })
        .await
    }

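    /// Downloads the object at `path` into `file`, writing from the file's
    /// current position. If the object size is known and spans multiple
    /// download chunks, the parts are fetched concurrently and written in
    /// order; otherwise the object is streamed with a single `get`.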
    pub async fn download(&self, path: &Path, file: &mut tokio::fs::File) -> PolarsResult<()> {
        let opt_size = self.head(path).await.ok().map(|x| x.size);

        let initial_pos = file.stream_position().await?;

        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            // Unsafely copy the `&mut File` so it can be reused if the closure
            // is called again on retry; the attempts run sequentially, never
            // concurrently.
            let file: &mut tokio::fs::File = unsafe { std::mem::transmute_copy(&file) };

            async {
                // Reset the file length to the starting position in case a
                // previous attempt wrote partial data.
                file.set_len(initial_pos).await?;

                let store = st;
                let parts = opt_size
                    .map(|x| split_range(0..x as usize))
                    .filter(|x| x.len() > 1);

                if let Some(parts) = parts {
                    tune_with_concurrency_budget(
                        parts.len().clamp(0, MAX_BUDGET_PER_REQUEST) as u32,
                        || async {
                            let mut stream = Self::get_buffered_ranges_stream(&store, path, parts);
                            let mut len = 0;
                            while let Some(bytes) = stream.try_next().await? {
                                len += bytes.len();
                                file.write_all(&bytes).await?;
                            }

                            assert_eq!(len, opt_size.unwrap() as usize);

                            PolarsResult::Ok(pl_async::Size::from(len as u64))
                        },
                    )
                    .await?
                } else {
                    tune_with_concurrency_budget(1, || async {
                        let mut stream = store.get(path).await?.into_stream();

                        let mut len = 0;
                        while let Some(bytes) = stream.try_next().await? {
                            len += bytes.len();
                            file.write_all(&bytes).await?;
                        }

                        PolarsResult::Ok(pl_async::Size::from(len as u64))
                    })
                    .await?
                };

                file.sync_all().await.map_err(PolarsError::from)?;

                Ok(())
            }
        })
        .await
    }

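    /// Fetches the [`ObjectMeta`] for `path`. If the `HEAD` request fails
    /// (e.g. some pre-signed URL backends reject it), the metadata is
    /// recovered from a ranged `GET` of the first byte instead.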
    pub async fn head(&self, path: &Path) -> PolarsResult<ObjectMeta> {
        self.try_exec_rebuild_on_err(|store| {
            let st = store.clone();

            async {
                with_concurrency_budget(1, || async {
                    let store = st;
                    let head_result = store.head(path).await;

                    if head_result.is_err() {
                        // Fall back to a ranged GET of the first byte; its
                        // response carries the same object metadata as HEAD.
                        let get_range_0_0_result = store
                            .get_opts(
                                path,
                                object_store::GetOptions {
                                    range: Some((0..1).into()),
                                    ..Default::default()
                                },
                            )
                            .await;

                        if let Ok(v) = get_range_0_0_result {
                            return Ok(v.meta);
                        }
                    }

                    let out = head_result?;

                    Ok(out)
                })
                .await
            }
        })
        .await
    }
}

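/// Splits `range` into roughly equally sized parts, targeting the configured
/// download chunk size per part. Any remainder is absorbed by the first part,
/// so at most one part is larger than the rest.
///
/// A small illustration using the 64 MiB chunk size assumed by the tests below:
///
/// ```ignore
/// let chunk_size = get_download_chunk_size(); // 64 MiB
/// // ~85 MiB: a single part is (just) closer to the target chunk size than two
/// // halves would be; one extra byte tips it over into two parts.
/// assert_eq!(split_range(0..4 * chunk_size / 3).count(), 1);
/// assert_eq!(split_range(0..4 * chunk_size / 3 + 1).count(), 2);
/// ```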
fn split_range(range: Range<usize>) -> impl ExactSizeIterator<Item = Range<usize>> {
    let chunk_size = get_download_chunk_size();

    // Pick the number of parts whose per-part size lands closest to `chunk_size`.
    let n_parts = [
        (range.len().div_ceil(chunk_size)).max(1),
        (range.len() / chunk_size).max(1),
    ]
    .into_iter()
    .min_by_key(|x| (range.len() / *x).abs_diff(chunk_size))
    .unwrap();

    let chunk_size = (range.len() / n_parts).max(1);

    assert_eq!(n_parts, (range.len() / chunk_size).max(1));
    let bytes_rem = range.len() % chunk_size;

    (0..n_parts).map(move |part_no| {
        let (start, end) = if part_no == 0 {
            // The first part absorbs the remainder bytes.
            let end = range.start + chunk_size + bytes_rem;
            let end = if end > range.end { range.end } else { end };
            (range.start, end)
        } else {
            let start = bytes_rem + range.start + part_no * chunk_size;
            (start, start + chunk_size)
        };

        start..end
    })
}

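/// Merges the sorted `ranges` into larger download requests where profitable
/// (overlapping ranges, or ranges separated by a small enough gap), then splits
/// every merged request again with `split_range`.
///
/// Yields `(request_range, end)` tuples, where `end` is the exclusive index into
/// `ranges` that is fully covered once this request completes, or `0` for the
/// intermediate parts of a request that was split into several pieces. Expects
/// `ranges` to be sorted by start offset, as done by `get_ranges_sort`.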
fn merge_ranges(ranges: &[Range<usize>]) -> impl Iterator<Item = (Range<usize>, usize)> + '_ {
    let chunk_size = get_download_chunk_size();

    let mut current_merged_range = ranges.first().map_or(0..0, Clone::clone);
    // Number of requested bytes in the current merged range, excluding gaps.
    let mut current_n_bytes = current_merged_range.len();

    (0..ranges.len())
        .filter_map(move |current_idx| {
            let current_idx = 1 + current_idx;

            if current_idx == ranges.len() {
                // No more ranges to look at: flush the last merged range.
                Some((current_merged_range.clone(), current_idx))
            } else {
                let range = ranges[current_idx].clone();

                let new_merged = current_merged_range.start.min(range.start)
                    ..current_merged_range.end.max(range.end);

                // `distance` is the gap size if the ranges are disjoint,
                // otherwise the size of the overlap.
                let (distance, is_overlapping) = {
                    let l = current_merged_range.end.min(range.end);
                    let r = current_merged_range.start.max(range.start);

                    (r.abs_diff(l), r < l)
                };

                let should_merge = is_overlapping || {
                    // Only merge if it doesn't move us further away from the
                    // target chunk size and the gap is small enough.
                    let leq_current_len_dist_to_chunk_size = new_merged.len().abs_diff(chunk_size)
                        <= current_merged_range.len().abs_diff(chunk_size);
                    let gap_tolerance =
                        (current_n_bytes.max(range.len()) / 8).clamp(1024 * 1024, 8 * 1024 * 1024);

                    leq_current_len_dist_to_chunk_size && distance <= gap_tolerance
                };

                if should_merge {
                    // Extend the existing merged range.
                    current_merged_range = new_merged;
                    current_n_bytes += if is_overlapping {
                        range.len() - distance
                    } else {
                        range.len()
                    };
                    None
                } else {
                    // Flush the current merged range and start a new one.
                    let out = (current_merged_range.clone(), current_idx);
                    current_merged_range = range;
                    current_n_bytes = current_merged_range.len();
                    Some(out)
                }
            }
        })
        .flat_map(|x| {
            // Split oversized merged ranges; only the last part of a split
            // carries the `end` index, intermediate parts are tagged with 0.
            let (range, end) = x;
            let split = split_range(range);
            let len = split.len();

            split
                .enumerate()
                .map(move |(i, range)| (range, if 1 + i == len { end } else { 0 }))
        })
}

#[cfg(test)]
mod tests {

    #[test]
    fn test_split_range() {
        use super::{get_download_chunk_size, split_range};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        // Zero-length ranges are returned as-is.
        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..0).collect::<Vec<_>>(), [0..0]);
            assert_eq!(split_range(3..3).collect::<Vec<_>>(), [3..3]);
        }

        // 4/3 * chunk_size stays a single part; one extra byte splits it in two.
        let n = 4 * chunk_size / 3;

        #[allow(clippy::single_range_in_vec_init)]
        {
            assert_eq!(split_range(0..n).collect::<Vec<_>>(), [0..89478485]);
        }

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..44739243, 44739243..89478486]
        );

        // 12/5 * chunk_size gives two parts; one extra byte makes it three.
        let n = 12 * chunk_size / 5;

        assert_eq!(
            split_range(0..n).collect::<Vec<_>>(),
            [0..80530637, 80530637..161061273]
        );

        assert_eq!(
            split_range(0..n + 1).collect::<Vec<_>>(),
            [0..53687092, 53687092..107374183, 107374183..161061274]
        );
    }

    #[test]
    fn test_merge_ranges() {
        use super::{get_download_chunk_size, merge_ranges};

        let chunk_size = get_download_chunk_size();

        assert_eq!(chunk_size, 64 * 1024 * 1024);

        assert_eq!(merge_ranges(&[]).collect::<Vec<_>>(), []);

        // Adjacent ranges merge into one request, which is then split into two
        // parts; only the last part carries the end index.
        assert_eq!(
            merge_ranges(&[0..1, 1..127 * 1024 * 1024]).collect::<Vec<_>>(),
            [(0..66584576, 0), (66584576..133169152, 2)]
        );

        // A gap equal to the 1 MiB tolerance is merged.
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 1..1024 * 1024 + 2]).collect::<Vec<_>>(),
            [(0..1048578, 2)]
        );

        // One byte above the tolerance is not merged.
        assert_eq!(
            merge_ranges(&[0..1, 1024 * 1024 + 2..1024 * 1024 + 3]).collect::<Vec<_>>(),
            [(0..1, 1), (1048578..1048579, 2)]
        );

        // Small gaps within tolerance are merged.
        assert_eq!(
            merge_ranges(&[0..8, 10..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        assert_eq!(
            merge_ranges(&[0..1, 3..11]).collect::<Vec<_>>(),
            [(0..11, 2)]
        );

        // Overlapping ranges collapse into a single request.
        assert_eq!(
            merge_ranges(&[0..80 * 1024 * 1024, 10 * 1024 * 1024..70 * 1024 * 1024])
                .collect::<Vec<_>>(),
            [(0..80 * 1024 * 1024, 2)]
        );
    }
}