polars_io/cloud/cloud_writer/
internal_writer.rs1use std::num::NonZeroUsize;
2
3use futures::StreamExt as _;
4use futures::stream::FuturesUnordered;
5use object_store::PutPayload;
6use polars_error::{PolarsError, PolarsResult};
7use polars_utils::async_utils::error_capture::{ErrorCapture, ErrorHandle};
8use polars_utils::async_utils::tokio_handle_ext;
9
10use crate::cloud::PolarsObjectStore;
11use crate::cloud::cloud_writer::multipart_upload::PlMultipartUpload;
12use crate::metrics::OptIOMetrics;
13
/// Writes a single object to cloud storage as a multipart upload.
///
/// The multipart upload is only created when the writer is first used (see `state`), and parts
/// are uploaded on spawned tasks, with at most `max_concurrency` uploads in flight at once.
pub(super) struct InternalCloudWriter {
    /// Object store the upload is written to.
    pub(super) store: PolarsObjectStore,
    /// Destination path of the object within `store`.
    pub(super) path: object_store::path::Path,
    /// Maximum number of part uploads allowed in flight at the same time.
    pub(super) max_concurrency: NonZeroUsize,
    /// Metrics handle used to record the bytes transmitted by each part upload.
    pub(super) io_metrics: OptIOMetrics,
    /// Current lifecycle state of the upload.
    pub(super) state: InternalCloudWriterState,
}
22
/// Lifecycle of an `InternalCloudWriter`'s multipart upload.
pub(super) enum InternalCloudWriterState {
    /// No multipart upload has been created yet.
    NotStarted,
    /// A multipart upload has been created and parts may be uploaded to it.
    Started(StartedState),
    /// The upload was finished, or was abandoned after a part upload reported an error.
    /// Uploading further parts in this state panics.
    Finished,
}
28
/// Short local alias for `InternalCloudWriterState`.
type WriterState = InternalCloudWriterState;
30
/// State held while a multipart upload is in progress.
pub(super) struct StartedState {
    /// The multipart upload that parts are written to.
    multipart: PlMultipartUpload,
    /// Handles of spawned part-upload tasks that have not been drained yet.
    /// NOTE(review): the handle type's name suggests the tasks are aborted when the handles are
    /// dropped — confirm against `tokio_handle_ext`.
    tasks: FuturesUnordered<tokio_handle_ext::AbortOnDropHandle<()>>,
    /// Used to check for, and retrieve, errors reported by the upload tasks.
    error_handle: ErrorHandle<PolarsError>,
    /// Cloned into every spawned upload task so that its error is reported to `error_handle`.
    error_capture: ErrorCapture<PolarsError>,
}
37
38impl InternalCloudWriter {
39 pub(super) async fn start(&mut self) -> PolarsResult<()> {
40 if let WriterState::NotStarted = &self.state {
41 let path_ref = &self.path;
42 let multipart = PlMultipartUpload::new(
43 self.store
44 .exec_with_rebuild_retry_on_err(|s| async move {
45 s.put_multipart_opts(path_ref, object_store::PutMultipartOptions::default())
46 .await
47 })
48 .await?,
49 self.store.error_context(),
50 );
51
52 let (error_capture, error_handle) = ErrorCapture::new();
53
54 self.state = WriterState::Started(StartedState {
55 multipart,
56 tasks: FuturesUnordered::new(),
57 error_handle,
58 error_capture,
59 });
60 }
61
62 Ok(())
63 }
64
65 async fn get_or_init_started_state(&mut self) -> PolarsResult<&mut StartedState> {
66 loop {
67 match &self.state {
68 WriterState::Started(_) => {
69 let WriterState::Started(state) = &mut self.state else {
70 unreachable!()
71 };
72 return Ok(state);
73 },
74 WriterState::NotStarted => self.start().await?,
75 WriterState::Finished => panic!(),
76 }
77 }
78 }
79
80 fn take_started_state(&mut self) -> Option<StartedState> {
83 if !matches!(&self.state, WriterState::Started(_)) {
84 return None;
85 }
86
87 let WriterState::Started(state) = std::mem::replace(&mut self.state, WriterState::Finished)
88 else {
89 unreachable!()
90 };
91
92 Some(state)
93 }
94
95 pub(super) async fn put(&mut self, payload: PutPayload) -> PolarsResult<()> {
96 let io_metrics = self.io_metrics.clone();
97 let max_concurrency = self.max_concurrency.get();
98
99 let state = self.get_or_init_started_state().await?;
100
101 if state.error_handle.has_errored() {
102 let state = self.take_started_state().unwrap();
103 return Err(state.error_handle.join().await.unwrap_err());
104 }
105
106 while state.tasks.len() >= max_concurrency {
107 state.tasks.next().await;
108 }
109
110 let num_bytes = payload.content_length() as u64;
111 let upload_fut = state.multipart.put(payload);
112
113 let fut = async move { io_metrics.record_bytes_tx(num_bytes, upload_fut).await };
114
115 let handle = tokio_handle_ext::AbortOnDropHandle(tokio::spawn(
116 state.error_capture.clone().wrap_future(fut),
117 ));
118
119 state.tasks.push(handle);
120
121 Ok(())
122 }
123
124 pub(super) async fn finish(&mut self) -> PolarsResult<()> {
125 let Some(StartedState {
126 mut multipart,
127 tasks,
128 error_handle,
129 error_capture,
130 }) = self.take_started_state()
131 else {
132 return Ok(());
133 };
134
135 drop(error_capture);
136 error_handle.join().await?;
137
138 for handle in tasks {
139 handle.await.unwrap();
140 }
141
142 multipart.finish().await?;
143
144 Ok(())
145 }
146}