Skip to main content

polars_io/cloud/cloud_writer/
io_trait_wrap.rs

1use std::pin::Pin;
2use std::task::{Poll, ready};
3
4use bytes::Bytes;
5use futures::FutureExt;
6use polars_core::runtime::ASYNC;
7
8use crate::cloud::cloud_writer::CloudWriter;
9use crate::utils::file::WriteableTrait;
10
11/// Wrapper on [`CloudWriter`] that implements [`std::io::Write`] and [`tokio::io::AsyncWrite`].
12pub struct CloudWriterIoTraitWrap {
13    state: WriterState,
14}
15
16enum WriterState {
17    Ready(Box<CloudWriter>),
18    Poll(
19        Pin<Box<dyn Future<Output = std::io::Result<WriterState>> + Send + 'static>>,
20        PollOperation,
21    ),
22    Finished,
23}
24
/// Identifies which trait method started the operation held in
/// [`WriterState::Poll`], so that a later poll of the same method can resume it
/// and a poll of a different method can detect the mismatch.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PollOperation {
    /// Started by `poll_write`.
    Write {
        /// Address of the caller's buffer (`buf.as_ptr() as usize`), used to
        /// check that a resumed `poll_write` is for the same buffer.
        slice_ptr: usize,
        /// Number of bytes from that buffer already copied into the writer.
        written: usize,
    },
    /// Started by `poll_flush`.
    Flush,
    /// Started by `poll_shutdown`.
    Shutdown,
}
32
33struct FinishActivePoll<'a>(Pin<&'a mut WriterState>);
34
35impl<'a> Future for FinishActivePoll<'a> {
36    type Output = std::io::Result<Option<PollOperation>>;
37
38    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
39        match &mut *self.0 {
40            WriterState::Poll(fut, _) => match fut.poll_unpin(cx) {
41                Poll::Ready(Ok(new_state)) => {
42                    debug_assert!(!matches!(&new_state, WriterState::Poll(..)));
43
44                    let WriterState::Poll(_, operation) =
45                        std::mem::replace(&mut *self.0, new_state)
46                    else {
47                        unreachable!()
48                    };
49
50                    Poll::Ready(Ok(Some(operation)))
51                },
52                Poll::Ready(Err(e)) => {
53                    *self.0 = WriterState::Finished;
54                    Poll::Ready(Err(e))
55                },
56                Poll::Pending => Poll::Pending,
57            },
58
59            WriterState::Ready(_) | WriterState::Finished => Poll::Ready(Ok(None)),
60        }
61    }
62}
63
64impl CloudWriterIoTraitWrap {
65    fn finish_active_poll(&mut self) -> FinishActivePoll<'_> {
66        FinishActivePoll(Pin::new(&mut self.state))
67    }
68
69    fn take_writer_from_ready_state(&mut self) -> Option<Box<CloudWriter>> {
70        if !matches!(&self.state, WriterState::Ready(_)) {
71            return None;
72        }
73
74        let WriterState::Ready(writer) = std::mem::replace(&mut self.state, WriterState::Finished)
75        else {
76            unreachable!()
77        };
78
79        Some(writer)
80    }
81
82    fn get_writer_mut_from_ready_state(&mut self) -> Option<&mut CloudWriter> {
83        if let WriterState::Ready(writer) = &mut self.state {
84            Some(writer.as_mut())
85        } else {
86            None
87        }
88    }
89
90    pub async fn write_all_owned(&mut self, bytes: Bytes) -> std::io::Result<()> {
91        self.finish_active_poll().await?;
92
93        self.get_writer_mut_from_ready_state()
94            .unwrap()
95            .write_all_owned(bytes)
96            .await?;
97
98        Ok(())
99    }
100
101    pub async fn into_cloud_writer(mut self) -> std::io::Result<CloudWriter> {
102        self.finish_active_poll().await?;
103
104        match self.state {
105            WriterState::Ready(writer) => Ok(*writer),
106            WriterState::Poll(..) => unreachable!(),
107            WriterState::Finished => panic!(),
108        }
109    }
110
111    pub fn as_cloud_writer(&mut self) -> std::io::Result<&mut CloudWriter> {
112        if !matches!(self.state, WriterState::Ready(_)) {
113            match &mut self.state {
114                WriterState::Ready(_) => unreachable!(),
115                WriterState::Poll(..) => ASYNC.block_in_place_on(self.finish_active_poll())?,
116                WriterState::Finished => panic!(),
117            };
118        }
119
120        let WriterState::Ready(writer) = &mut self.state else {
121            panic!()
122        };
123
124        Ok(writer)
125    }
126}
127
128impl From<CloudWriter> for CloudWriterIoTraitWrap {
129    fn from(writer: CloudWriter) -> Self {
130        Self {
131            state: WriterState::Ready(Box::new(writer)),
132        }
133    }
134}
135
136impl std::io::Write for CloudWriterIoTraitWrap {
137    fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
138        let total_buf_len = buf.len();
139        let buf: &mut &[u8] = &mut buf;
140
141        if let Some(writer) = self.get_writer_mut_from_ready_state() {
142            let full = writer.fill_buffer_from_slice(buf);
143
144            if !full {
145                assert!(buf.is_empty());
146                return Ok(total_buf_len);
147            }
148        }
149
150        ASYNC.block_in_place_on(async {
151            self.finish_active_poll().await?;
152
153            let writer = self.get_writer_mut_from_ready_state().unwrap();
154
155            loop {
156                writer.flush_full_chunk().await?;
157
158                if !writer.fill_buffer_from_slice(buf) {
159                    break;
160                }
161            }
162
163            assert!(buf.is_empty());
164
165            Ok(total_buf_len)
166        })
167    }
168
169    fn flush(&mut self) -> std::io::Result<()> {
170        if self
171            .get_writer_mut_from_ready_state()
172            .is_some_and(|w| !w.has_buffered_bytes())
173        {
174            return Ok(());
175        }
176
177        ASYNC.block_in_place_on(async {
178            self.finish_active_poll().await?;
179
180            self.get_writer_mut_from_ready_state()
181                .unwrap()
182                .flush()
183                .await?;
184
185            Ok(())
186        })
187    }
188}
189
190impl WriteableTrait for CloudWriterIoTraitWrap {
191    fn close(&mut self) -> std::io::Result<()> {
192        ASYNC.block_in_place_on(async {
193            self.finish_active_poll().await?;
194
195            let mut writer = self.take_writer_from_ready_state().unwrap();
196            writer.finish().await?;
197
198            Ok(())
199        })
200    }
201
202    fn sync_all(&self) -> std::io::Result<()> {
203        Ok(())
204    }
205
206    fn sync_data(&self) -> std::io::Result<()> {
207        Ok(())
208    }
209}
210
211impl tokio::io::AsyncWrite for CloudWriterIoTraitWrap {
212    fn poll_write(
213        mut self: Pin<&mut Self>,
214        cx: &mut std::task::Context<'_>,
215        buf: &[u8],
216    ) -> std::task::Poll<std::io::Result<usize>> {
217        loop {
218            let offset = match ready!(self.finish_active_poll().poll_unpin(cx))? {
219                Some(PollOperation::Write { slice_ptr, written })
220                    if slice_ptr == buf.as_ptr() as usize =>
221                {
222                    written
223                },
224                Some(_) => panic!(),
225                None => 0,
226            };
227
228            let writer = self.get_writer_mut_from_ready_state().unwrap();
229
230            let offset_buf: &mut &[u8] = &mut &buf[offset..];
231
232            let full = writer.fill_buffer_from_slice(offset_buf);
233
234            if !full {
235                assert!(offset_buf.is_empty());
236                return Poll::Ready(Ok(buf.len()));
237            };
238
239            let new_offset = buf.len() - offset_buf.len();
240
241            let mut writer = self.take_writer_from_ready_state().unwrap();
242
243            self.state = WriterState::Poll(
244                Box::pin(async move {
245                    writer.flush_full_chunk().await?;
246                    Ok(WriterState::Ready(writer))
247                }),
248                PollOperation::Write {
249                    slice_ptr: buf.as_ptr() as usize,
250                    written: new_offset,
251                },
252            );
253        }
254    }
255
256    fn poll_flush(
257        mut self: Pin<&mut Self>,
258        cx: &mut std::task::Context<'_>,
259    ) -> std::task::Poll<std::io::Result<()>> {
260        loop {
261            match ready!(self.finish_active_poll().poll_unpin(cx))? {
262                Some(PollOperation::Flush) => return Poll::Ready(Ok(())),
263                Some(_) => panic!(),
264                None => {
265                    let mut writer = self.take_writer_from_ready_state().unwrap();
266
267                    self.state = WriterState::Poll(
268                        Box::pin(async move {
269                            writer.flush().await?;
270                            Ok(WriterState::Ready(writer))
271                        }),
272                        PollOperation::Flush,
273                    )
274                },
275            }
276        }
277    }
278
279    fn poll_shutdown(
280        mut self: Pin<&mut Self>,
281        cx: &mut std::task::Context<'_>,
282    ) -> std::task::Poll<std::io::Result<()>> {
283        loop {
284            match ready!(self.finish_active_poll().poll_unpin(cx))? {
285                Some(PollOperation::Shutdown) => return Poll::Ready(Ok(())),
286                Some(_) => panic!(),
287                None => {
288                    let mut writer = self.take_writer_from_ready_state().unwrap();
289
290                    self.state = WriterState::Poll(
291                        Box::pin(async move {
292                            writer.finish().await?;
293                            Ok(WriterState::Finished)
294                        }),
295                        PollOperation::Shutdown,
296                    );
297                },
298            }
299        }
300    }
301}