polars_io/cloud/cloud_writer/
io_trait_wrap.rs1use std::pin::Pin;
2use std::task::{Poll, ready};
3
4use bytes::Bytes;
5use futures::FutureExt;
6use polars_core::runtime::ASYNC;
7
8use crate::cloud::cloud_writer::CloudWriter;
9use crate::utils::file::WriteableTrait;
10
11pub struct CloudWriterIoTraitWrap {
13 state: WriterState,
14}
15
16enum WriterState {
17 Ready(Box<CloudWriter>),
18 Poll(
19 Pin<Box<dyn Future<Output = std::io::Result<WriterState>> + Send + 'static>>,
20 PollOperation,
21 ),
22 Finished,
23}
24
/// Identifies which `AsyncWrite` method started the operation stored in [`WriterState::Poll`].
#[derive(Debug, Clone, PartialEq, Eq)]
enum PollOperation {
    /// Chunk upload started by `poll_write`.
    Write {
        /// Address of the `buf` passed to `poll_write`, used to recognise a retry.
        slice_ptr: usize,
        /// Number of bytes of that `buf` already moved into the writer's buffer.
        written: usize,
    },
    /// Flush started by `poll_flush`.
    Flush,
    /// Finish started by `poll_shutdown`.
    Shutdown,
}
32
33struct FinishActivePoll<'a>(Pin<&'a mut WriterState>);
34
35impl<'a> Future for FinishActivePoll<'a> {
36 type Output = std::io::Result<Option<PollOperation>>;
37
38 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
39 match &mut *self.0 {
40 WriterState::Poll(fut, _) => match fut.poll_unpin(cx) {
41 Poll::Ready(Ok(new_state)) => {
42 debug_assert!(!matches!(&new_state, WriterState::Poll(..)));
43
44 let WriterState::Poll(_, operation) =
45 std::mem::replace(&mut *self.0, new_state)
46 else {
47 unreachable!()
48 };
49
50 Poll::Ready(Ok(Some(operation)))
51 },
52 Poll::Ready(Err(e)) => {
53 *self.0 = WriterState::Finished;
54 Poll::Ready(Err(e))
55 },
56 Poll::Pending => Poll::Pending,
57 },
58
59 WriterState::Ready(_) | WriterState::Finished => Poll::Ready(Ok(None)),
60 }
61 }
62}
63
64impl CloudWriterIoTraitWrap {
65 fn finish_active_poll(&mut self) -> FinishActivePoll<'_> {
66 FinishActivePoll(Pin::new(&mut self.state))
67 }
68
69 fn take_writer_from_ready_state(&mut self) -> Option<Box<CloudWriter>> {
70 if !matches!(&self.state, WriterState::Ready(_)) {
71 return None;
72 }
73
74 let WriterState::Ready(writer) = std::mem::replace(&mut self.state, WriterState::Finished)
75 else {
76 unreachable!()
77 };
78
79 Some(writer)
80 }
81
82 fn get_writer_mut_from_ready_state(&mut self) -> Option<&mut CloudWriter> {
83 if let WriterState::Ready(writer) = &mut self.state {
84 Some(writer.as_mut())
85 } else {
86 None
87 }
88 }
89
90 pub async fn write_all_owned(&mut self, bytes: Bytes) -> std::io::Result<()> {
91 self.finish_active_poll().await?;
92
93 self.get_writer_mut_from_ready_state()
94 .unwrap()
95 .write_all_owned(bytes)
96 .await?;
97
98 Ok(())
99 }
100
101 pub async fn into_cloud_writer(mut self) -> std::io::Result<CloudWriter> {
102 self.finish_active_poll().await?;
103
104 match self.state {
105 WriterState::Ready(writer) => Ok(*writer),
106 WriterState::Poll(..) => unreachable!(),
107 WriterState::Finished => panic!(),
108 }
109 }
110
111 pub fn as_cloud_writer(&mut self) -> std::io::Result<&mut CloudWriter> {
112 if !matches!(self.state, WriterState::Ready(_)) {
113 match &mut self.state {
114 WriterState::Ready(_) => unreachable!(),
115 WriterState::Poll(..) => ASYNC.block_in_place_on(self.finish_active_poll())?,
116 WriterState::Finished => panic!(),
117 };
118 }
119
120 let WriterState::Ready(writer) = &mut self.state else {
121 panic!()
122 };
123
124 Ok(writer)
125 }
126}
127
128impl From<CloudWriter> for CloudWriterIoTraitWrap {
129 fn from(writer: CloudWriter) -> Self {
130 Self {
131 state: WriterState::Ready(Box::new(writer)),
132 }
133 }
134}
135
136impl std::io::Write for CloudWriterIoTraitWrap {
137 fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
138 let total_buf_len = buf.len();
139 let buf: &mut &[u8] = &mut buf;
140
141 if let Some(writer) = self.get_writer_mut_from_ready_state() {
142 let full = writer.fill_buffer_from_slice(buf);
143
144 if !full {
145 assert!(buf.is_empty());
146 return Ok(total_buf_len);
147 }
148 }
149
150 ASYNC.block_in_place_on(async {
151 self.finish_active_poll().await?;
152
153 let writer = self.get_writer_mut_from_ready_state().unwrap();
154
155 loop {
156 writer.flush_full_chunk().await?;
157
158 if !writer.fill_buffer_from_slice(buf) {
159 break;
160 }
161 }
162
163 assert!(buf.is_empty());
164
165 Ok(total_buf_len)
166 })
167 }
168
169 fn flush(&mut self) -> std::io::Result<()> {
170 if self
171 .get_writer_mut_from_ready_state()
172 .is_some_and(|w| !w.has_buffered_bytes())
173 {
174 return Ok(());
175 }
176
177 ASYNC.block_in_place_on(async {
178 self.finish_active_poll().await?;
179
180 self.get_writer_mut_from_ready_state()
181 .unwrap()
182 .flush()
183 .await?;
184
185 Ok(())
186 })
187 }
188}
189
190impl WriteableTrait for CloudWriterIoTraitWrap {
191 fn close(&mut self) -> std::io::Result<()> {
192 ASYNC.block_in_place_on(async {
193 self.finish_active_poll().await?;
194
195 let mut writer = self.take_writer_from_ready_state().unwrap();
196 writer.finish().await?;
197
198 Ok(())
199 })
200 }
201
202 fn sync_all(&self) -> std::io::Result<()> {
203 Ok(())
204 }
205
206 fn sync_data(&self) -> std::io::Result<()> {
207 Ok(())
208 }
209}
210
211impl tokio::io::AsyncWrite for CloudWriterIoTraitWrap {
212 fn poll_write(
213 mut self: Pin<&mut Self>,
214 cx: &mut std::task::Context<'_>,
215 buf: &[u8],
216 ) -> std::task::Poll<std::io::Result<usize>> {
217 loop {
218 let offset = match ready!(self.finish_active_poll().poll_unpin(cx))? {
219 Some(PollOperation::Write { slice_ptr, written })
220 if slice_ptr == buf.as_ptr() as usize =>
221 {
222 written
223 },
224 Some(_) => panic!(),
225 None => 0,
226 };
227
228 let writer = self.get_writer_mut_from_ready_state().unwrap();
229
230 let offset_buf: &mut &[u8] = &mut &buf[offset..];
231
232 let full = writer.fill_buffer_from_slice(offset_buf);
233
234 if !full {
235 assert!(offset_buf.is_empty());
236 return Poll::Ready(Ok(buf.len()));
237 };
238
239 let new_offset = buf.len() - offset_buf.len();
240
241 let mut writer = self.take_writer_from_ready_state().unwrap();
242
243 self.state = WriterState::Poll(
244 Box::pin(async move {
245 writer.flush_full_chunk().await?;
246 Ok(WriterState::Ready(writer))
247 }),
248 PollOperation::Write {
249 slice_ptr: buf.as_ptr() as usize,
250 written: new_offset,
251 },
252 );
253 }
254 }
255
256 fn poll_flush(
257 mut self: Pin<&mut Self>,
258 cx: &mut std::task::Context<'_>,
259 ) -> std::task::Poll<std::io::Result<()>> {
260 loop {
261 match ready!(self.finish_active_poll().poll_unpin(cx))? {
262 Some(PollOperation::Flush) => return Poll::Ready(Ok(())),
263 Some(_) => panic!(),
264 None => {
265 let mut writer = self.take_writer_from_ready_state().unwrap();
266
267 self.state = WriterState::Poll(
268 Box::pin(async move {
269 writer.flush().await?;
270 Ok(WriterState::Ready(writer))
271 }),
272 PollOperation::Flush,
273 )
274 },
275 }
276 }
277 }
278
279 fn poll_shutdown(
280 mut self: Pin<&mut Self>,
281 cx: &mut std::task::Context<'_>,
282 ) -> std::task::Poll<std::io::Result<()>> {
283 loop {
284 match ready!(self.finish_active_poll().poll_unpin(cx))? {
285 Some(PollOperation::Shutdown) => return Poll::Ready(Ok(())),
286 Some(_) => panic!(),
287 None => {
288 let mut writer = self.take_writer_from_ready_state().unwrap();
289
290 self.state = WriterState::Poll(
291 Box::pin(async move {
292 writer.finish().await?;
293 Ok(WriterState::Finished)
294 }),
295 PollOperation::Shutdown,
296 );
297 },
298 }
299 }
300 }
301}