// polars_io/cloud/cloud_writer/io_trait_wrap.rs

use std::pin::Pin;
use std::task::{Poll, ready};

use bytes::Bytes;
use futures::FutureExt;

use crate::cloud::cloud_writer::CloudWriter;
use crate::pl_async;
use crate::utils::file::WriteableTrait;

11/// Wrapper on [`CloudWriter`] that implements [`std::io::Write`] and [`tokio::io::AsyncWrite`].
12pub struct CloudWriterIoTraitWrap {
13    state: WriterState,
14}
15
16enum WriterState {
17    Ready(Box<CloudWriter>),
18    Poll(
19        Pin<Box<dyn Future<Output = std::io::Result<WriterState>> + Send + 'static>>,
20        PollOperation,
21    ),
22    Finished,
23}
24
/// Identifies the operation that a pending `WriterState::Poll` future is carrying out.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PollOperation {
    /// A write of a caller-provided slice that is suspended while a chunk is flushed.
    Write {
        /// Address of the first byte of the caller's slice; compared on the next poll to
        /// recognise that the same buffer is being written.
        slice_ptr: usize,
        /// Number of bytes of the caller's slice already copied into the writer.
        written: usize,
    },
    /// A flush of the writer.
    Flush,
    /// A shutdown of the writer.
    Shutdown,
}

33struct FinishActivePoll<'a>(Pin<&'a mut WriterState>);
34
35impl<'a> Future for FinishActivePoll<'a> {
36    type Output = std::io::Result<Option<PollOperation>>;
37
38    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
39        match &mut *self.0 {
40            WriterState::Poll(fut, _) => match fut.poll_unpin(cx) {
41                Poll::Ready(Ok(new_state)) => {
42                    debug_assert!(!matches!(&new_state, WriterState::Poll(..)));
43
44                    let WriterState::Poll(_, operation) =
45                        std::mem::replace(&mut *self.0, new_state)
46                    else {
47                        unreachable!()
48                    };
49
50                    Poll::Ready(Ok(Some(operation)))
51                },
52                Poll::Ready(Err(e)) => {
53                    *self.0 = WriterState::Finished;
54                    Poll::Ready(Err(e))
55                },
56                Poll::Pending => Poll::Pending,
57            },
58
59            WriterState::Ready(_) | WriterState::Finished => Poll::Ready(Ok(None)),
60        }
61    }
62}
63
64impl CloudWriterIoTraitWrap {
65    fn finish_active_poll(&mut self) -> FinishActivePoll<'_> {
66        FinishActivePoll(Pin::new(&mut self.state))
67    }
68
69    fn take_writer_from_ready_state(&mut self) -> Option<Box<CloudWriter>> {
70        if !matches!(&self.state, WriterState::Ready(_)) {
71            return None;
72        }
73
74        let WriterState::Ready(writer) = std::mem::replace(&mut self.state, WriterState::Finished)
75        else {
76            unreachable!()
77        };
78
79        Some(writer)
80    }
81
82    fn get_writer_mut_from_ready_state(&mut self) -> Option<&mut CloudWriter> {
83        if let WriterState::Ready(writer) = &mut self.state {
84            Some(writer.as_mut())
85        } else {
86            None
87        }
88    }
89
90    pub async fn write_all_owned(&mut self, bytes: Bytes) -> std::io::Result<()> {
91        self.finish_active_poll().await?;
92
93        self.get_writer_mut_from_ready_state()
94            .unwrap()
95            .write_all_owned(bytes)
96            .await?;
97
98        Ok(())
99    }
100
101    pub async fn into_cloud_writer(mut self) -> std::io::Result<CloudWriter> {
102        self.finish_active_poll().await?;
103
104        match self.state {
105            WriterState::Ready(writer) => Ok(*writer),
106            WriterState::Poll(..) => unreachable!(),
107            WriterState::Finished => panic!(),
108        }
109    }
110
111    pub fn as_cloud_writer(&mut self) -> std::io::Result<&mut CloudWriter> {
112        if !matches!(self.state, WriterState::Ready(_)) {
113            match &mut self.state {
114                WriterState::Ready(_) => unreachable!(),
115                WriterState::Poll(..) => {
116                    pl_async::get_runtime().block_in_place_on(self.finish_active_poll())?
117                },
118                WriterState::Finished => panic!(),
119            };
120        }
121
122        let WriterState::Ready(writer) = &mut self.state else {
123            panic!()
124        };
125
126        Ok(writer)
127    }
128}
129
130impl From<CloudWriter> for CloudWriterIoTraitWrap {
131    fn from(writer: CloudWriter) -> Self {
132        Self {
133            state: WriterState::Ready(Box::new(writer)),
134        }
135    }
136}
137
138impl std::io::Write for CloudWriterIoTraitWrap {
139    fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
140        let total_buf_len = buf.len();
141        let buf: &mut &[u8] = &mut buf;
142
143        if let Some(writer) = self.get_writer_mut_from_ready_state() {
144            let full = writer.fill_buffer_from_slice(buf);
145
146            if !full {
147                assert!(buf.is_empty());
148                return Ok(total_buf_len);
149            }
150        }
151
152        pl_async::get_runtime().block_in_place_on(async {
153            self.finish_active_poll().await?;
154
155            let writer = self.get_writer_mut_from_ready_state().unwrap();
156
157            loop {
158                writer.flush_full_chunk().await?;
159
160                if !writer.fill_buffer_from_slice(buf) {
161                    break;
162                }
163            }
164
165            assert!(buf.is_empty());
166
167            Ok(total_buf_len)
168        })
169    }
170
171    fn flush(&mut self) -> std::io::Result<()> {
172        if self
173            .get_writer_mut_from_ready_state()
174            .is_some_and(|w| !w.has_buffered_bytes())
175        {
176            return Ok(());
177        }
178
179        pl_async::get_runtime().block_in_place_on(async {
180            self.finish_active_poll().await?;
181
182            self.get_writer_mut_from_ready_state()
183                .unwrap()
184                .flush()
185                .await?;
186
187            Ok(())
188        })
189    }
190}
191
192impl WriteableTrait for CloudWriterIoTraitWrap {
193    fn close(&mut self) -> std::io::Result<()> {
194        pl_async::get_runtime().block_in_place_on(async {
195            self.finish_active_poll().await?;
196
197            let mut writer = self.take_writer_from_ready_state().unwrap();
198            writer.finish().await?;
199
200            Ok(())
201        })
202    }
203
204    fn sync_all(&self) -> std::io::Result<()> {
205        Ok(())
206    }
207
208    fn sync_data(&self) -> std::io::Result<()> {
209        Ok(())
210    }
211}
212
213impl tokio::io::AsyncWrite for CloudWriterIoTraitWrap {
214    fn poll_write(
215        mut self: Pin<&mut Self>,
216        cx: &mut std::task::Context<'_>,
217        buf: &[u8],
218    ) -> std::task::Poll<std::io::Result<usize>> {
219        loop {
220            let offset = match ready!(self.finish_active_poll().poll_unpin(cx))? {
221                Some(PollOperation::Write { slice_ptr, written })
222                    if slice_ptr == buf.as_ptr() as usize =>
223                {
224                    written
225                },
226                Some(_) => panic!(),
227                None => 0,
228            };
229
230            let writer = self.get_writer_mut_from_ready_state().unwrap();
231
232            let offset_buf: &mut &[u8] = &mut &buf[offset..];
233
234            let full = writer.fill_buffer_from_slice(offset_buf);
235
236            if !full {
237                assert!(offset_buf.is_empty());
238                return Poll::Ready(Ok(buf.len()));
239            };
240
241            let new_offset = buf.len() - offset_buf.len();
242
243            let mut writer = self.take_writer_from_ready_state().unwrap();
244
245            self.state = WriterState::Poll(
246                Box::pin(async move {
247                    writer.flush_full_chunk().await?;
248                    Ok(WriterState::Ready(writer))
249                }),
250                PollOperation::Write {
251                    slice_ptr: buf.as_ptr() as usize,
252                    written: new_offset,
253                },
254            );
255        }
256    }
257
258    fn poll_flush(
259        mut self: Pin<&mut Self>,
260        cx: &mut std::task::Context<'_>,
261    ) -> std::task::Poll<std::io::Result<()>> {
262        loop {
263            match ready!(self.finish_active_poll().poll_unpin(cx))? {
264                Some(PollOperation::Flush) => return Poll::Ready(Ok(())),
265                Some(_) => panic!(),
266                None => {
267                    let mut writer = self.take_writer_from_ready_state().unwrap();
268
269                    self.state = WriterState::Poll(
270                        Box::pin(async move {
271                            writer.flush().await?;
272                            Ok(WriterState::Ready(writer))
273                        }),
274                        PollOperation::Flush,
275                    )
276                },
277            }
278        }
279    }
280
281    fn poll_shutdown(
282        mut self: Pin<&mut Self>,
283        cx: &mut std::task::Context<'_>,
284    ) -> std::task::Poll<std::io::Result<()>> {
285        loop {
286            match ready!(self.finish_active_poll().poll_unpin(cx))? {
287                Some(PollOperation::Shutdown) => return Poll::Ready(Ok(())),
288                Some(_) => panic!(),
289                None => {
290                    let mut writer = self.take_writer_from_ready_state().unwrap();
291
292                    self.state = WriterState::Poll(
293                        Box::pin(async move {
294                            writer.finish().await?;
295                            Ok(WriterState::Finished)
296                        }),
297                        PollOperation::Shutdown,
298                    );
299                },
300            }
301        }
302    }
303}