polars_io/cloud/cloud_writer/
io_trait_wrap.rs

1use std::pin::Pin;
2use std::task::{Poll, ready};
3
4use futures::FutureExt;
5
6use crate::cloud::cloud_writer::CloudWriter;
7use crate::pl_async;
8use crate::utils::file::WriteableTrait;
9
10/// Wrapper on [`CloudWriter`] that implements [`std::io::Write`] and [`tokio::io::AsyncWrite`].
11pub struct CloudWriterIoTraitWrap {
12    state: WriterState,
13}
14
15enum WriterState {
16    Ready(Box<CloudWriter>),
17    Poll(
18        Pin<Box<dyn Future<Output = std::io::Result<WriterState>> + Send + 'static>>,
19        PollOperation,
20    ),
21    Finished,
22}
23
/// Identifies which `AsyncWrite` entry point started the operation that is
/// currently in flight, so that a later poll can tell whether it is resuming its
/// own work.
#[derive(Clone, Debug, Eq, PartialEq)]
enum PollOperation {
    /// A `poll_write` is in progress.
    ///
    /// `slice_ptr` is the address of the caller's buffer and `written` is the number
    /// of leading bytes of that buffer that have already been consumed.
    Write { slice_ptr: usize, written: usize },
    /// A `poll_flush` is in progress.
    Flush,
    /// A `poll_shutdown` is in progress.
    Shutdown,
}
31
32struct FinishActivePoll<'a>(Pin<&'a mut WriterState>);
33
34impl<'a> Future for FinishActivePoll<'a> {
35    type Output = std::io::Result<Option<PollOperation>>;
36
37    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
38        match &mut *self.0 {
39            WriterState::Poll(fut, _) => match fut.poll_unpin(cx) {
40                Poll::Ready(Ok(new_state)) => {
41                    debug_assert!(!matches!(&new_state, WriterState::Poll(..)));
42
43                    let WriterState::Poll(_, operation) =
44                        std::mem::replace(&mut *self.0, new_state)
45                    else {
46                        unreachable!()
47                    };
48
49                    Poll::Ready(Ok(Some(operation)))
50                },
51                Poll::Ready(Err(e)) => {
52                    *self.0 = WriterState::Finished;
53                    Poll::Ready(Err(e))
54                },
55                Poll::Pending => Poll::Pending,
56            },
57
58            WriterState::Ready(_) | WriterState::Finished => Poll::Ready(Ok(None)),
59        }
60    }
61}
62
63impl CloudWriterIoTraitWrap {
64    fn finish_active_poll(&mut self) -> FinishActivePoll<'_> {
65        FinishActivePoll(Pin::new(&mut self.state))
66    }
67
68    fn take_writer_from_ready_state(&mut self) -> Option<Box<CloudWriter>> {
69        if !matches!(&self.state, WriterState::Ready(_)) {
70            return None;
71        }
72
73        let WriterState::Ready(writer) = std::mem::replace(&mut self.state, WriterState::Finished)
74        else {
75            unreachable!()
76        };
77
78        Some(writer)
79    }
80
81    fn get_writer_mut_from_ready_state(&mut self) -> Option<&mut CloudWriter> {
82        if let WriterState::Ready(writer) = &mut self.state {
83            Some(writer.as_mut())
84        } else {
85            None
86        }
87    }
88
89    pub async fn into_cloud_writer(mut self) -> std::io::Result<CloudWriter> {
90        self.finish_active_poll().await?;
91
92        match self.state {
93            WriterState::Ready(writer) => Ok(*writer),
94            WriterState::Poll(..) => unreachable!(),
95            WriterState::Finished => panic!(),
96        }
97    }
98
99    pub fn as_cloud_writer(&mut self) -> std::io::Result<&mut CloudWriter> {
100        if !matches!(self.state, WriterState::Ready(_)) {
101            match &mut self.state {
102                WriterState::Ready(_) => unreachable!(),
103                WriterState::Poll(..) => {
104                    pl_async::get_runtime().block_in_place_on(self.finish_active_poll())?
105                },
106                WriterState::Finished => panic!(),
107            };
108        }
109
110        let WriterState::Ready(writer) = &mut self.state else {
111            panic!()
112        };
113
114        Ok(writer)
115    }
116}
117
118impl From<CloudWriter> for CloudWriterIoTraitWrap {
119    fn from(writer: CloudWriter) -> Self {
120        Self {
121            state: WriterState::Ready(Box::new(writer)),
122        }
123    }
124}
125
126impl std::io::Write for CloudWriterIoTraitWrap {
127    fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
128        let total_buf_len = buf.len();
129        let buf: &mut &[u8] = &mut buf;
130
131        if let Some(writer) = self.get_writer_mut_from_ready_state() {
132            let should_poll = writer.fill_buffer_from_slice(buf);
133
134            if !should_poll {
135                assert!(buf.is_empty());
136                return Ok(total_buf_len);
137            }
138        }
139
140        pl_async::get_runtime().block_in_place_on(async {
141            self.finish_active_poll().await?;
142
143            let writer = self.get_writer_mut_from_ready_state().unwrap();
144
145            loop {
146                writer.flush_complete_chunk().await?;
147
148                if !writer.fill_buffer_from_slice(buf) {
149                    break;
150                }
151            }
152
153            assert!(buf.is_empty());
154
155            Ok(total_buf_len)
156        })
157    }
158
159    fn flush(&mut self) -> std::io::Result<()> {
160        if self
161            .get_writer_mut_from_ready_state()
162            .is_some_and(|w| !w.has_buffered_bytes())
163        {
164            return Ok(());
165        }
166
167        pl_async::get_runtime().block_in_place_on(async {
168            self.finish_active_poll().await?;
169
170            self.get_writer_mut_from_ready_state()
171                .unwrap()
172                .flush()
173                .await?;
174
175            Ok(())
176        })
177    }
178}
179
180impl WriteableTrait for CloudWriterIoTraitWrap {
181    fn close(&mut self) -> std::io::Result<()> {
182        pl_async::get_runtime().block_in_place_on(async {
183            self.finish_active_poll().await?;
184
185            if let Some(mut writer) = self.take_writer_from_ready_state() {
186                writer.flush().await?;
187                writer.finish().await?;
188            }
189
190            Ok(())
191        })
192    }
193
194    fn sync_all(&self) -> std::io::Result<()> {
195        Ok(())
196    }
197
198    fn sync_data(&self) -> std::io::Result<()> {
199        Ok(())
200    }
201}
202
203impl tokio::io::AsyncWrite for CloudWriterIoTraitWrap {
204    fn poll_write(
205        mut self: Pin<&mut Self>,
206        cx: &mut std::task::Context<'_>,
207        buf: &[u8],
208    ) -> std::task::Poll<std::io::Result<usize>> {
209        loop {
210            let offset = match ready!(self.finish_active_poll().poll_unpin(cx))? {
211                Some(PollOperation::Write { slice_ptr, written })
212                    if slice_ptr == buf.as_ptr() as usize =>
213                {
214                    written
215                },
216                Some(_) => panic!(),
217                None => 0,
218            };
219
220            let writer = self.get_writer_mut_from_ready_state().unwrap();
221
222            let offset_buf: &mut &[u8] = &mut &buf[offset..];
223
224            let should_poll_flush = writer.fill_buffer_from_slice(offset_buf);
225
226            if !should_poll_flush {
227                assert!(offset_buf.is_empty());
228                return Poll::Ready(Ok(buf.len()));
229            };
230
231            let new_offset = buf.len() - offset_buf.len();
232
233            let mut writer = self.take_writer_from_ready_state().unwrap();
234
235            let fut = async move {
236                writer.flush_complete_chunk().await?;
237                Ok(WriterState::Ready(writer))
238            };
239
240            self.state = WriterState::Poll(
241                Box::pin(fut),
242                PollOperation::Write {
243                    slice_ptr: buf.as_ptr() as usize,
244                    written: new_offset,
245                },
246            );
247        }
248    }
249
250    fn poll_flush(
251        mut self: Pin<&mut Self>,
252        cx: &mut std::task::Context<'_>,
253    ) -> std::task::Poll<std::io::Result<()>> {
254        loop {
255            if let Some(operation) = ready!(self.finish_active_poll().poll_unpin(cx))? {
256                debug_assert_eq!(operation, PollOperation::Flush);
257
258                if operation == PollOperation::Flush {
259                    return Poll::Ready(Ok(()));
260                }
261            }
262
263            let mut writer = self.take_writer_from_ready_state().unwrap();
264
265            let fut = async move {
266                writer.flush().await?;
267                Ok(WriterState::Ready(writer))
268            };
269
270            self.state = WriterState::Poll(Box::pin(fut), PollOperation::Flush);
271        }
272    }
273
274    fn poll_shutdown(
275        mut self: Pin<&mut Self>,
276        cx: &mut std::task::Context<'_>,
277    ) -> std::task::Poll<std::io::Result<()>> {
278        loop {
279            if let Some(operation) = ready!(self.finish_active_poll().poll_unpin(cx))? {
280                debug_assert_eq!(operation, PollOperation::Shutdown);
281
282                if operation == PollOperation::Shutdown {
283                    return Poll::Ready(Ok(()));
284                }
285            }
286
287            let Some(mut writer) = self.take_writer_from_ready_state() else {
288                return Poll::Ready(Ok(()));
289            };
290
291            let fut = async move {
292                writer.finish().await?;
293                Ok(WriterState::Finished)
294            };
295
296            self.state = WriterState::Poll(Box::pin(fut), PollOperation::Shutdown);
297        }
298    }
299}