polars_io/cloud/cloud_writer/
io_trait_wrap.rs1use std::pin::Pin;
2use std::task::{Poll, ready};
3
4use futures::FutureExt;
5
6use crate::cloud::cloud_writer::CloudWriter;
7use crate::pl_async;
8use crate::utils::file::WriteableTrait;
9
10pub struct CloudWriterIoTraitWrap {
12 state: WriterState,
13}
14
15enum WriterState {
16 Ready(Box<CloudWriter>),
17 Poll(
18 Pin<Box<dyn Future<Output = std::io::Result<WriterState>> + Send + 'static>>,
19 PollOperation,
20 ),
21 Finished,
22}
23
/// Identifies which operation started the future stored in [`WriterState::Poll`].
#[derive(Debug, Clone, PartialEq, Eq)]
enum PollOperation {
    /// Chunk upload started by `poll_write` after the writer's buffer filled up.
    Write {
        /// Address of the `buf` that was passed to `poll_write`, used to recognise a retry.
        slice_ptr: usize,
        /// How many bytes of that `buf` were already copied into the writer's buffer.
        written: usize,
    },
    /// Flush started by `poll_flush`.
    Flush,
    /// Finish started by `poll_shutdown`.
    Shutdown,
}
31
32struct FinishActivePoll<'a>(Pin<&'a mut WriterState>);
33
34impl<'a> Future for FinishActivePoll<'a> {
35 type Output = std::io::Result<Option<PollOperation>>;
36
37 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
38 match &mut *self.0 {
39 WriterState::Poll(fut, _) => match fut.poll_unpin(cx) {
40 Poll::Ready(Ok(new_state)) => {
41 debug_assert!(!matches!(&new_state, WriterState::Poll(..)));
42
43 let WriterState::Poll(_, operation) =
44 std::mem::replace(&mut *self.0, new_state)
45 else {
46 unreachable!()
47 };
48
49 Poll::Ready(Ok(Some(operation)))
50 },
51 Poll::Ready(Err(e)) => {
52 *self.0 = WriterState::Finished;
53 Poll::Ready(Err(e))
54 },
55 Poll::Pending => Poll::Pending,
56 },
57
58 WriterState::Ready(_) | WriterState::Finished => Poll::Ready(Ok(None)),
59 }
60 }
61}
62
63impl CloudWriterIoTraitWrap {
64 fn finish_active_poll(&mut self) -> FinishActivePoll<'_> {
65 FinishActivePoll(Pin::new(&mut self.state))
66 }
67
68 fn take_writer_from_ready_state(&mut self) -> Option<Box<CloudWriter>> {
69 if !matches!(&self.state, WriterState::Ready(_)) {
70 return None;
71 }
72
73 let WriterState::Ready(writer) = std::mem::replace(&mut self.state, WriterState::Finished)
74 else {
75 unreachable!()
76 };
77
78 Some(writer)
79 }
80
81 fn get_writer_mut_from_ready_state(&mut self) -> Option<&mut CloudWriter> {
82 if let WriterState::Ready(writer) = &mut self.state {
83 Some(writer.as_mut())
84 } else {
85 None
86 }
87 }
88
89 pub async fn into_cloud_writer(mut self) -> std::io::Result<CloudWriter> {
90 self.finish_active_poll().await?;
91
92 match self.state {
93 WriterState::Ready(writer) => Ok(*writer),
94 WriterState::Poll(..) => unreachable!(),
95 WriterState::Finished => panic!(),
96 }
97 }
98
99 pub fn as_cloud_writer(&mut self) -> std::io::Result<&mut CloudWriter> {
100 if !matches!(self.state, WriterState::Ready(_)) {
101 match &mut self.state {
102 WriterState::Ready(_) => unreachable!(),
103 WriterState::Poll(..) => {
104 pl_async::get_runtime().block_in_place_on(self.finish_active_poll())?
105 },
106 WriterState::Finished => panic!(),
107 };
108 }
109
110 let WriterState::Ready(writer) = &mut self.state else {
111 panic!()
112 };
113
114 Ok(writer)
115 }
116}
117
118impl From<CloudWriter> for CloudWriterIoTraitWrap {
119 fn from(writer: CloudWriter) -> Self {
120 Self {
121 state: WriterState::Ready(Box::new(writer)),
122 }
123 }
124}
125
126impl std::io::Write for CloudWriterIoTraitWrap {
127 fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
128 let total_buf_len = buf.len();
129 let buf: &mut &[u8] = &mut buf;
130
131 if let Some(writer) = self.get_writer_mut_from_ready_state() {
132 let should_poll = writer.fill_buffer_from_slice(buf);
133
134 if !should_poll {
135 assert!(buf.is_empty());
136 return Ok(total_buf_len);
137 }
138 }
139
140 pl_async::get_runtime().block_in_place_on(async {
141 self.finish_active_poll().await?;
142
143 let writer = self.get_writer_mut_from_ready_state().unwrap();
144
145 loop {
146 writer.flush_complete_chunk().await?;
147
148 if !writer.fill_buffer_from_slice(buf) {
149 break;
150 }
151 }
152
153 assert!(buf.is_empty());
154
155 Ok(total_buf_len)
156 })
157 }
158
159 fn flush(&mut self) -> std::io::Result<()> {
160 if self
161 .get_writer_mut_from_ready_state()
162 .is_some_and(|w| !w.has_buffered_bytes())
163 {
164 return Ok(());
165 }
166
167 pl_async::get_runtime().block_in_place_on(async {
168 self.finish_active_poll().await?;
169
170 self.get_writer_mut_from_ready_state()
171 .unwrap()
172 .flush()
173 .await?;
174
175 Ok(())
176 })
177 }
178}
179
180impl WriteableTrait for CloudWriterIoTraitWrap {
181 fn close(&mut self) -> std::io::Result<()> {
182 pl_async::get_runtime().block_in_place_on(async {
183 self.finish_active_poll().await?;
184
185 if let Some(mut writer) = self.take_writer_from_ready_state() {
186 writer.flush().await?;
187 writer.finish().await?;
188 }
189
190 Ok(())
191 })
192 }
193
194 fn sync_all(&self) -> std::io::Result<()> {
195 Ok(())
196 }
197
198 fn sync_data(&self) -> std::io::Result<()> {
199 Ok(())
200 }
201}
202
203impl tokio::io::AsyncWrite for CloudWriterIoTraitWrap {
204 fn poll_write(
205 mut self: Pin<&mut Self>,
206 cx: &mut std::task::Context<'_>,
207 buf: &[u8],
208 ) -> std::task::Poll<std::io::Result<usize>> {
209 loop {
210 let offset = match ready!(self.finish_active_poll().poll_unpin(cx))? {
211 Some(PollOperation::Write { slice_ptr, written })
212 if slice_ptr == buf.as_ptr() as usize =>
213 {
214 written
215 },
216 Some(_) => panic!(),
217 None => 0,
218 };
219
220 let writer = self.get_writer_mut_from_ready_state().unwrap();
221
222 let offset_buf: &mut &[u8] = &mut &buf[offset..];
223
224 let should_poll_flush = writer.fill_buffer_from_slice(offset_buf);
225
226 if !should_poll_flush {
227 assert!(offset_buf.is_empty());
228 return Poll::Ready(Ok(buf.len()));
229 };
230
231 let new_offset = buf.len() - offset_buf.len();
232
233 let mut writer = self.take_writer_from_ready_state().unwrap();
234
235 let fut = async move {
236 writer.flush_complete_chunk().await?;
237 Ok(WriterState::Ready(writer))
238 };
239
240 self.state = WriterState::Poll(
241 Box::pin(fut),
242 PollOperation::Write {
243 slice_ptr: buf.as_ptr() as usize,
244 written: new_offset,
245 },
246 );
247 }
248 }
249
250 fn poll_flush(
251 mut self: Pin<&mut Self>,
252 cx: &mut std::task::Context<'_>,
253 ) -> std::task::Poll<std::io::Result<()>> {
254 loop {
255 if let Some(operation) = ready!(self.finish_active_poll().poll_unpin(cx))? {
256 debug_assert_eq!(operation, PollOperation::Flush);
257
258 if operation == PollOperation::Flush {
259 return Poll::Ready(Ok(()));
260 }
261 }
262
263 let mut writer = self.take_writer_from_ready_state().unwrap();
264
265 let fut = async move {
266 writer.flush().await?;
267 Ok(WriterState::Ready(writer))
268 };
269
270 self.state = WriterState::Poll(Box::pin(fut), PollOperation::Flush);
271 }
272 }
273
274 fn poll_shutdown(
275 mut self: Pin<&mut Self>,
276 cx: &mut std::task::Context<'_>,
277 ) -> std::task::Poll<std::io::Result<()>> {
278 loop {
279 if let Some(operation) = ready!(self.finish_active_poll().poll_unpin(cx))? {
280 debug_assert_eq!(operation, PollOperation::Shutdown);
281
282 if operation == PollOperation::Shutdown {
283 return Poll::Ready(Ok(()));
284 }
285 }
286
287 let Some(mut writer) = self.take_writer_from_ready_state() else {
288 return Poll::Ready(Ok(()));
289 };
290
291 let fut = async move {
292 writer.finish().await?;
293 Ok(WriterState::Finished)
294 };
295
296 self.state = WriterState::Poll(Box::pin(fut), PollOperation::Shutdown);
297 }
298 }
299}