// polars_io/cloud/cloud_writer/io_trait_wrap.rs
use std::pin::Pin;
use std::task::{Poll, ready};

use bytes::Bytes;
use futures::FutureExt;

use crate::cloud::cloud_writer::CloudWriter;
use crate::pl_async;
use crate::utils::file::WriteableTrait;

11pub struct CloudWriterIoTraitWrap {
13 state: WriterState,
14}
15
16enum WriterState {
17 Ready(Box<CloudWriter>),
18 Poll(
19 Pin<Box<dyn Future<Output = std::io::Result<WriterState>> + Send + 'static>>,
20 PollOperation,
21 ),
22 Finished,
23}
24
/// Identifies which `AsyncWrite` method created an in-flight [`WriterState::Poll`] state.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PollOperation {
    /// Started by `poll_write` after the buffer filled up.
    Write {
        /// Address of the caller's buffer; a retry must pass the same buffer to resume.
        slice_ptr: usize,
        /// Number of bytes of that buffer already copied into the writer's buffer.
        written: usize,
    },
    /// Started by `poll_flush`.
    Flush,
    /// Started by `poll_shutdown`.
    Shutdown,
}

33struct FinishActivePoll<'a>(Pin<&'a mut WriterState>);
34
35impl<'a> Future for FinishActivePoll<'a> {
36 type Output = std::io::Result<Option<PollOperation>>;
37
38 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
39 match &mut *self.0 {
40 WriterState::Poll(fut, _) => match fut.poll_unpin(cx) {
41 Poll::Ready(Ok(new_state)) => {
42 debug_assert!(!matches!(&new_state, WriterState::Poll(..)));
43
44 let WriterState::Poll(_, operation) =
45 std::mem::replace(&mut *self.0, new_state)
46 else {
47 unreachable!()
48 };
49
50 Poll::Ready(Ok(Some(operation)))
51 },
52 Poll::Ready(Err(e)) => {
53 *self.0 = WriterState::Finished;
54 Poll::Ready(Err(e))
55 },
56 Poll::Pending => Poll::Pending,
57 },
58
59 WriterState::Ready(_) | WriterState::Finished => Poll::Ready(Ok(None)),
60 }
61 }
62}
63
64impl CloudWriterIoTraitWrap {
65 fn finish_active_poll(&mut self) -> FinishActivePoll<'_> {
66 FinishActivePoll(Pin::new(&mut self.state))
67 }
68
69 fn take_writer_from_ready_state(&mut self) -> Option<Box<CloudWriter>> {
70 if !matches!(&self.state, WriterState::Ready(_)) {
71 return None;
72 }
73
74 let WriterState::Ready(writer) = std::mem::replace(&mut self.state, WriterState::Finished)
75 else {
76 unreachable!()
77 };
78
79 Some(writer)
80 }
81
82 fn get_writer_mut_from_ready_state(&mut self) -> Option<&mut CloudWriter> {
83 if let WriterState::Ready(writer) = &mut self.state {
84 Some(writer.as_mut())
85 } else {
86 None
87 }
88 }
89
90 pub async fn write_all_owned(&mut self, bytes: Bytes) -> std::io::Result<()> {
91 self.finish_active_poll().await?;
92
93 self.get_writer_mut_from_ready_state()
94 .unwrap()
95 .write_all_owned(bytes)
96 .await?;
97
98 Ok(())
99 }
100
101 pub async fn into_cloud_writer(mut self) -> std::io::Result<CloudWriter> {
102 self.finish_active_poll().await?;
103
104 match self.state {
105 WriterState::Ready(writer) => Ok(*writer),
106 WriterState::Poll(..) => unreachable!(),
107 WriterState::Finished => panic!(),
108 }
109 }
110
111 pub fn as_cloud_writer(&mut self) -> std::io::Result<&mut CloudWriter> {
112 if !matches!(self.state, WriterState::Ready(_)) {
113 match &mut self.state {
114 WriterState::Ready(_) => unreachable!(),
115 WriterState::Poll(..) => {
116 pl_async::get_runtime().block_in_place_on(self.finish_active_poll())?
117 },
118 WriterState::Finished => panic!(),
119 };
120 }
121
122 let WriterState::Ready(writer) = &mut self.state else {
123 panic!()
124 };
125
126 Ok(writer)
127 }
128}
129
130impl From<CloudWriter> for CloudWriterIoTraitWrap {
131 fn from(writer: CloudWriter) -> Self {
132 Self {
133 state: WriterState::Ready(Box::new(writer)),
134 }
135 }
136}
137
138impl std::io::Write for CloudWriterIoTraitWrap {
139 fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
140 let total_buf_len = buf.len();
141 let buf: &mut &[u8] = &mut buf;
142
143 if let Some(writer) = self.get_writer_mut_from_ready_state() {
144 let full = writer.fill_buffer_from_slice(buf);
145
146 if !full {
147 assert!(buf.is_empty());
148 return Ok(total_buf_len);
149 }
150 }
151
152 pl_async::get_runtime().block_in_place_on(async {
153 self.finish_active_poll().await?;
154
155 let writer = self.get_writer_mut_from_ready_state().unwrap();
156
157 loop {
158 writer.flush_full_chunk().await?;
159
160 if !writer.fill_buffer_from_slice(buf) {
161 break;
162 }
163 }
164
165 assert!(buf.is_empty());
166
167 Ok(total_buf_len)
168 })
169 }
170
171 fn flush(&mut self) -> std::io::Result<()> {
172 if self
173 .get_writer_mut_from_ready_state()
174 .is_some_and(|w| !w.has_buffered_bytes())
175 {
176 return Ok(());
177 }
178
179 pl_async::get_runtime().block_in_place_on(async {
180 self.finish_active_poll().await?;
181
182 self.get_writer_mut_from_ready_state()
183 .unwrap()
184 .flush()
185 .await?;
186
187 Ok(())
188 })
189 }
190}
191
192impl WriteableTrait for CloudWriterIoTraitWrap {
193 fn close(&mut self) -> std::io::Result<()> {
194 pl_async::get_runtime().block_in_place_on(async {
195 self.finish_active_poll().await?;
196
197 let mut writer = self.take_writer_from_ready_state().unwrap();
198 writer.finish().await?;
199
200 Ok(())
201 })
202 }
203
204 fn sync_all(&self) -> std::io::Result<()> {
205 Ok(())
206 }
207
208 fn sync_data(&self) -> std::io::Result<()> {
209 Ok(())
210 }
211}
212
213impl tokio::io::AsyncWrite for CloudWriterIoTraitWrap {
214 fn poll_write(
215 mut self: Pin<&mut Self>,
216 cx: &mut std::task::Context<'_>,
217 buf: &[u8],
218 ) -> std::task::Poll<std::io::Result<usize>> {
219 loop {
220 let offset = match ready!(self.finish_active_poll().poll_unpin(cx))? {
221 Some(PollOperation::Write { slice_ptr, written })
222 if slice_ptr == buf.as_ptr() as usize =>
223 {
224 written
225 },
226 Some(_) => panic!(),
227 None => 0,
228 };
229
230 let writer = self.get_writer_mut_from_ready_state().unwrap();
231
232 let offset_buf: &mut &[u8] = &mut &buf[offset..];
233
234 let full = writer.fill_buffer_from_slice(offset_buf);
235
236 if !full {
237 assert!(offset_buf.is_empty());
238 return Poll::Ready(Ok(buf.len()));
239 };
240
241 let new_offset = buf.len() - offset_buf.len();
242
243 let mut writer = self.take_writer_from_ready_state().unwrap();
244
245 self.state = WriterState::Poll(
246 Box::pin(async move {
247 writer.flush_full_chunk().await?;
248 Ok(WriterState::Ready(writer))
249 }),
250 PollOperation::Write {
251 slice_ptr: buf.as_ptr() as usize,
252 written: new_offset,
253 },
254 );
255 }
256 }
257
258 fn poll_flush(
259 mut self: Pin<&mut Self>,
260 cx: &mut std::task::Context<'_>,
261 ) -> std::task::Poll<std::io::Result<()>> {
262 loop {
263 match ready!(self.finish_active_poll().poll_unpin(cx))? {
264 Some(PollOperation::Flush) => return Poll::Ready(Ok(())),
265 Some(_) => panic!(),
266 None => {
267 let mut writer = self.take_writer_from_ready_state().unwrap();
268
269 self.state = WriterState::Poll(
270 Box::pin(async move {
271 writer.flush().await?;
272 Ok(WriterState::Ready(writer))
273 }),
274 PollOperation::Flush,
275 )
276 },
277 }
278 }
279 }
280
281 fn poll_shutdown(
282 mut self: Pin<&mut Self>,
283 cx: &mut std::task::Context<'_>,
284 ) -> std::task::Poll<std::io::Result<()>> {
285 loop {
286 match ready!(self.finish_active_poll().poll_unpin(cx))? {
287 Some(PollOperation::Shutdown) => return Poll::Ready(Ok(())),
288 Some(_) => panic!(),
289 None => {
290 let mut writer = self.take_writer_from_ready_state().unwrap();
291
292 self.state = WriterState::Poll(
293 Box::pin(async move {
294 writer.finish().await?;
295 Ok(WriterState::Finished)
296 }),
297 PollOperation::Shutdown,
298 );
299 },
300 }
301 }
302 }
303}