Skip to main content

polars_io/utils/
stream_buf_reader.rs

1use std::io::{BufRead, Cursor};
2
3use polars_buffer::Buffer;
4use polars_error::PolarsResult;
5#[cfg(feature = "async")]
6use polars_utils::async_utils::tokio_handle_ext;
7#[cfg(feature = "async")]
8use tokio::sync::OwnedSemaphorePermit;
9
10#[cfg(feature = "async")]
11use crate::pl_async;
12
/// State of a [`StreamBufReader`] that is still receiving chunks.
#[cfg(feature = "async")]
pub struct OpenReaderState {
    /// Channel of pending chunks. Each message pairs a handle to a task that yields one chunk of
    /// bytes with a semaphore permit; `fill_buf` keeps the permit alive while it waits for that
    /// chunk's task to complete.
    receiver: tokio::sync::mpsc::Receiver<(
        tokio_handle_ext::AbortOnDropHandle<PolarsResult<Buffer<u8>>>,
        OwnedSemaphorePermit,
    )>,
    /// Handle to a background task (presumably the one feeding `receiver`). `finish` blocks on
    /// it after dropping the receiver and returns its `std::io::Result`.
    producer_task_handle: tokio_handle_ext::AbortOnDropHandle<std::io::Result<()>>,
    /// Unconsumed bytes of the chunk currently being read; when empty, `fill_buf` receives the
    /// next chunk from the channel.
    current: Buffer<u8>,
}
22
/// `BufRead` interface for a channel that is receiving `Buffer<u8>` bytes.
///
/// The reader starts out [`Open`](Self::Open) and switches to [`Finished`](Self::Finished) when
/// the channel reports that no more chunks will arrive. From then on it behaves as a reader at
/// end of stream.
#[cfg(feature = "async")]
pub enum StreamBufReader {
    /// Still receiving chunks from the channel.
    Open(OpenReaderState),
    /// The channel is exhausted and the producer task has been waited on; `fill_buf` returns an
    /// empty slice and `read` returns `Ok(0)`.
    Finished,
}
29
#[cfg(feature = "async")]
impl StreamBufReader {
    /// Creates an open reader over `receiver`, starting with no buffered bytes.
    ///
    /// Every message on `receiver` pairs a handle to a task that yields one chunk with a
    /// semaphore permit. `producer_task_handle` is blocked on (and its result returned) once the
    /// channel has been drained; presumably it is the task that sends into the channel.
    pub fn new(
        receiver: tokio::sync::mpsc::Receiver<(
            tokio_handle_ext::AbortOnDropHandle<PolarsResult<Buffer<u8>>>,
            OwnedSemaphorePermit,
        )>,
        producer_task_handle: tokio_handle_ext::AbortOnDropHandle<std::io::Result<()>>,
    ) -> Self {
        let initial_state = OpenReaderState {
            receiver,
            producer_task_handle,
            current: Buffer::default(),
        };
        Self::Open(initial_state)
    }

    /// Borrows the open state mutably, or yields `None` once the reader has finished.
    fn get_open_state(&mut self) -> Option<&mut OpenReaderState> {
        if let Self::Open(open) = self {
            Some(open)
        } else {
            None
        }
    }

    /// Moves the reader into the `Finished` state and waits for the producer task.
    ///
    /// On an already-finished reader this returns `Ok(())` immediately. Otherwise the receiver is
    /// dropped first, then the call blocks on the producer task. The call returns the task's own
    /// `std::io::Result`; an error from waiting on the task itself is propagated with `?`.
    fn finish(&mut self) -> std::io::Result<()> {
        match std::mem::replace(self, Self::Finished) {
            Self::Finished => Ok(()),
            Self::Open(OpenReaderState {
                receiver,
                producer_task_handle,
                ..
            }) => {
                // Close our end of the channel before waiting. A producer blocked on sending then
                // sees the closure instead of waiting for a reader that is gone.
                drop(receiver);
                pl_async::get_runtime().block_in_place_on(producer_task_handle)?
            },
        }
    }
}
63
#[cfg(feature = "async")]
impl std::io::Read for StreamBufReader {
    /// Copies as many bytes of the current chunk as fit into `buf`.
    ///
    /// A new chunk is received first if the current one is exhausted. Returns `Ok(0)` at end of
    /// stream, or when `buf` is empty.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let available = std::io::BufRead::fill_buf(self)?;
        if available.is_empty() {
            return Ok(0);
        }

        let count = std::cmp::min(available.len(), buf.len());
        let source = &available[..count];
        buf[..count].copy_from_slice(source);

        std::io::BufRead::consume(self, count);
        Ok(count)
    }
}
77
#[cfg(feature = "async")]
impl std::io::BufRead for StreamBufReader {
    /// Returns the unconsumed bytes of the current chunk. When the current chunk is exhausted,
    /// the next one is received from the channel first.
    ///
    /// Zero-length chunks are skipped. To callers of `fill_buf`, an empty slice means end of
    /// stream, so returning one for an empty chunk would end reading early and lose the data
    /// that follows. An empty slice is returned only in two cases:
    /// - the reader was already finished;
    /// - the channel has just closed, in which case the reader is finished and the producer
    ///   task's result is propagated.
    ///
    /// # Errors
    /// Returns an error if a chunk's task yields an error, or if the producer task reports an
    /// error when the reader finishes.
    ///
    /// # Panics
    /// NOTE(review): the `.unwrap()` on the chunk handle's output panics if waiting on the fetch
    /// task itself fails (presumably a panicked or cancelled task). Confirm this is intended.
    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
        loop {
            let Some(state) = self.get_open_state() else {
                return Ok(&[]);
            };

            if !state.current.is_empty() {
                break;
            }

            match state.receiver.blocking_recv() {
                // `_permit` stays alive until the end of this arm, i.e. until the chunk has been
                // fetched.
                Some((handle, _permit)) => {
                    let fetched_bytes =
                        pl_async::get_runtime().block_in_place_on(handle).unwrap()?;
                    // If this chunk is empty, the loop simply receives the next one.
                    state.current = fetched_bytes;
                },
                None => {
                    // Channel closed: no more chunks will arrive.
                    self.finish()?;
                    return Ok(&[]);
                },
            }
        }

        // The loop only exits via `break`, i.e. while open with a non-empty chunk. The state is
        // re-borrowed here so that the returned slice does not conflict with `self.finish()`.
        let Some(state) = self.get_open_state() else {
            unreachable!();
        };

        Ok(state.current.as_ref())
    }

    /// Marks `amt` bytes of the current chunk as consumed; a no-op once the reader has finished.
    fn consume(&mut self, amt: usize) {
        if let Some(state) = self.get_open_state() {
            state.current.slice_in_place(amt..);
        }
    }
}
112
/// Supported reader sources: an in-memory buffer (used by `from_memory`) or a streaming reader.
pub enum ReaderSource {
    /// Reads from a fully materialized buffer through a [`Cursor`].
    Memory(Cursor<Buffer<u8>>),
    /// Reads chunks on demand through a [`StreamBufReader`]; only available with the `async`
    /// feature.
    #[cfg(feature = "async")]
    Streaming(StreamBufReader),
}
119
120impl std::io::Read for ReaderSource {
121    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
122        match self {
123            Self::Memory(r) => r.read(buf),
124            #[cfg(feature = "async")]
125            Self::Streaming(r) => r.read(buf),
126        }
127    }
128}
129
130impl std::io::BufRead for ReaderSource {
131    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
132        match self {
133            Self::Memory(r) => r.fill_buf(),
134            #[cfg(feature = "async")]
135            Self::Streaming(r) => r.fill_buf(),
136        }
137    }
138
139    fn consume(&mut self, amt: usize) {
140        match self {
141            Self::Memory(r) => r.consume(amt),
142            #[cfg(feature = "async")]
143            Self::Streaming(r) => r.consume(amt),
144        }
145    }
146}