Skip to main content

polars_io/cloud/cloud_writer/
internal_writer.rs

1use std::num::NonZeroUsize;
2
3use futures::StreamExt as _;
4use futures::stream::FuturesUnordered;
5use object_store::PutPayload;
6use polars_error::{PolarsError, PolarsResult};
7use polars_utils::async_utils::error_capture::{ErrorCapture, ErrorHandle};
8use polars_utils::async_utils::tokio_handle_ext;
9
10use crate::cloud::PolarsObjectStore;
11use crate::cloud::cloud_writer::multipart_upload::PlMultipartUpload;
12use crate::metrics::OptIOMetrics;
13
/// Cloud writer that provides the `put()` function, does not perform any buffering.
///
/// Every payload passed to `put()` is submitted to a multipart upload from its own spawned
/// task. `finish()` waits for those tasks and then finishes the multipart upload.
pub(super) struct InternalCloudWriter {
    /// Object store the multipart upload is created on.
    pub(super) store: PolarsObjectStore,
    /// Destination path of the object inside `store`.
    pub(super) path: object_store::path::Path,
    /// Maximum number of upload tasks tracked at once: `put()` waits for a tracked task to
    /// complete before spawning a new one when this many are pending.
    pub(super) max_concurrency: NonZeroUsize,
    /// Metrics recorder; each upload future is run through its `record_bytes_tx` together with
    /// the payload's byte count.
    pub(super) io_metrics: OptIOMetrics,
    /// Current lifecycle state of the writer (`NotStarted`, `Started` or `Finished`).
    pub(super) state: InternalCloudWriterState,
}
22
/// Lifecycle state of an `InternalCloudWriter`.
pub(super) enum InternalCloudWriterState {
    /// No multipart upload has been created yet. `start()` (or the first `put()`) creates one
    /// and moves the writer to `Started`.
    NotStarted,
    /// A multipart upload exists and payloads can be submitted.
    Started(StartedState),
    /// The started state has been taken out of the writer, either by `finish()` or by a `put()`
    /// that observed an upload error. In this state `start()` and `finish()` do nothing and
    /// `put()` panics.
    Finished,
}
28
// Short file-local alias so that match arms can be written as `WriterState::...`.
type WriterState = InternalCloudWriterState;
30
/// Data that only exists while the writer is in the `Started` state.
pub(super) struct StartedState {
    /// The in-progress multipart upload that payloads are put into and that `finish()` finishes.
    multipart: PlMultipartUpload,
    /// Handles of spawned upload tasks that have not been awaited yet.
    // NOTE(review): judging by the wrapper's name, dropping a handle presumably aborts its task
    // -- confirm against `tokio_handle_ext::AbortOnDropHandle`.
    tasks: FuturesUnordered<tokio_handle_ext::AbortOnDropHandle<()>>,
    /// Queried with `has_errored()` before each upload and `join()`ed to obtain the upload
    /// result/error.
    error_handle: ErrorHandle<PolarsError>,
    /// Cloned for every spawned task to wrap its upload future.
    // NOTE(review): presumably errors returned by wrapped futures are what `error_handle`
    // reports -- confirm against `ErrorCapture`.
    error_capture: ErrorCapture<PolarsError>,
}
37
38impl InternalCloudWriter {
39    pub(super) async fn start(&mut self) -> PolarsResult<()> {
40        if let WriterState::NotStarted = &self.state {
41            let path_ref = &self.path;
42            let multipart = PlMultipartUpload::new(
43                self.store
44                    .exec_with_rebuild_retry_on_err(|s| async move {
45                        s.put_multipart_opts(path_ref, object_store::PutMultipartOptions::default())
46                            .await
47                    })
48                    .await?,
49                self.store.error_context(),
50            );
51
52            let (error_capture, error_handle) = ErrorCapture::new();
53
54            self.state = WriterState::Started(StartedState {
55                multipart,
56                tasks: FuturesUnordered::new(),
57                error_handle,
58                error_capture,
59            });
60        }
61
62        Ok(())
63    }
64
65    async fn get_or_init_started_state(&mut self) -> PolarsResult<&mut StartedState> {
66        loop {
67            match &self.state {
68                WriterState::Started(_) => {
69                    let WriterState::Started(state) = &mut self.state else {
70                        unreachable!()
71                    };
72                    return Ok(state);
73                },
74                WriterState::NotStarted => self.start().await?,
75                WriterState::Finished => panic!(),
76            }
77        }
78    }
79
80    /// Takes `self.state`, replacing with it `Finished`. Returns `None` if `self.state` is not
81    /// `Started`.
82    fn take_started_state(&mut self) -> Option<StartedState> {
83        if !matches!(&self.state, WriterState::Started(_)) {
84            return None;
85        }
86
87        let WriterState::Started(state) = std::mem::replace(&mut self.state, WriterState::Finished)
88        else {
89            unreachable!()
90        };
91
92        Some(state)
93    }
94
95    pub(super) async fn put(&mut self, payload: PutPayload) -> PolarsResult<()> {
96        let io_metrics = self.io_metrics.clone();
97        let max_concurrency = self.max_concurrency.get();
98
99        let state = self.get_or_init_started_state().await?;
100
101        if state.error_handle.has_errored() {
102            let state = self.take_started_state().unwrap();
103            return Err(state.error_handle.join().await.unwrap_err());
104        }
105
106        while state.tasks.len() >= max_concurrency {
107            state.tasks.next().await;
108        }
109
110        let num_bytes = payload.content_length() as u64;
111        let upload_fut = state.multipart.put(payload);
112
113        let fut = async move { io_metrics.record_bytes_tx(num_bytes, upload_fut).await };
114
115        let handle = tokio_handle_ext::AbortOnDropHandle(tokio::spawn(
116            state.error_capture.clone().wrap_future(fut),
117        ));
118
119        state.tasks.push(handle);
120
121        Ok(())
122    }
123
124    pub(super) async fn finish(&mut self) -> PolarsResult<()> {
125        let Some(StartedState {
126            mut multipart,
127            tasks,
128            error_handle,
129            error_capture,
130        }) = self.take_started_state()
131        else {
132            return Ok(());
133        };
134
135        drop(error_capture);
136        error_handle.join().await?;
137
138        for handle in tasks {
139            handle.await.unwrap();
140        }
141
142        multipart.finish().await?;
143
144        Ok(())
145    }
146}