Skip to main content

polars_io/utils/
stream_buf_reader.rs

1use std::io::{BufRead, Cursor};
2
3use polars_buffer::Buffer;
4#[cfg(feature = "async")]
5use polars_core::runtime::ASYNC;
6use polars_error::PolarsResult;
7#[cfg(feature = "async")]
8use polars_utils::async_utils::tokio_handle_ext;
9#[cfg(feature = "async")]
10use tokio::sync::OwnedSemaphorePermit;
11
/// Item carried by the reader's channel: a handle to the task fetching one chunk, paired with an
/// `OwnedSemaphorePermit` that `StreamBufReader::fill_buf` keeps alive until that chunk has been
/// fetched and stored.
#[cfg(feature = "async")]
type PendingChunk = (
    tokio_handle_ext::AbortOnDropHandle<PolarsResult<Buffer<u8>>>,
    OwnedSemaphorePermit,
);

/// State of a `StreamBufReader` that is still receiving data.
#[cfg(feature = "async")]
pub struct OpenReaderState {
    /// Pending chunk fetches, consumed in the order they are received.
    receiver: tokio::sync::mpsc::Receiver<PendingChunk>,
    /// Producer task; awaited by `StreamBufReader::finish` once the channel is exhausted, and
    /// its result is returned from there.
    producer_task_handle: tokio_handle_ext::AbortOnDropHandle<std::io::Result<()>>,
    /// Not-yet-consumed remainder of the most recently fetched chunk.
    current: Buffer<u8>,
}
21
/// Blocking [`BufRead`] adapter over a channel that delivers `Buffer<u8>` chunks.
///
/// While `Open`, chunks are taken from the channel one at a time as the previous one is consumed.
/// Once the channel is exhausted the reader switches to `Finished` and reports end-of-stream from
/// then on.
#[cfg(feature = "async")]
pub enum StreamBufReader {
    /// Still receiving chunks.
    Open(OpenReaderState),
    /// Channel exhausted; every subsequent read reports end-of-stream.
    Finished,
}
28
#[cfg(feature = "async")]
impl StreamBufReader {
    /// Creates an open reader that pulls chunk fetches from `receiver`.
    ///
    /// `producer_task_handle` is awaited once `receiver` yields no more items. Any error it
    /// returns is reported by the read that reaches end of stream.
    pub fn new(
        receiver: tokio::sync::mpsc::Receiver<(
            tokio_handle_ext::AbortOnDropHandle<PolarsResult<Buffer<u8>>>,
            OwnedSemaphorePermit,
        )>,
        producer_task_handle: tokio_handle_ext::AbortOnDropHandle<std::io::Result<()>>,
    ) -> Self {
        let state = OpenReaderState {
            receiver,
            producer_task_handle,
            current: Buffer::default(),
        };
        Self::Open(state)
    }

    /// Mutable access to the open state, or `None` once the reader has finished.
    fn get_open_state(&mut self) -> Option<&mut OpenReaderState> {
        if let Self::Open(state) = self {
            Some(state)
        } else {
            None
        }
    }

    /// Switches the reader to `Finished`.
    ///
    /// If it was still open, the channel receiver is dropped and this blocks until the producer
    /// task completes, returning the producer's result. Calling it on an already finished reader
    /// returns `Ok(())`.
    fn finish(&mut self) -> std::io::Result<()> {
        match std::mem::replace(self, Self::Finished) {
            Self::Finished => Ok(()),
            Self::Open(OpenReaderState {
                receiver,
                producer_task_handle,
                ..
            }) => {
                // Close our end of the channel before waiting on the producer.
                drop(receiver);
                ASYNC.block_in_place_on(producer_task_handle)?
            },
        }
    }
}
62
#[cfg(feature = "async")]
impl std::io::Read for StreamBufReader {
    /// Copies as many buffered bytes as fit into `buf`, fetching the next chunk first if the
    /// current one is used up. Returns `Ok(0)` at end of stream.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut chunk = self.fill_buf()?;
        if chunk.is_empty() {
            // Nothing buffered and nothing more to fetch.
            return Ok(0);
        }
        // Reading from a byte slice copies `min(buf.len(), chunk.len())` bytes.
        let copied = std::io::Read::read(&mut chunk, buf)?;
        self.consume(copied);
        Ok(copied)
    }
}
76
#[cfg(feature = "async")]
impl std::io::BufRead for StreamBufReader {
    /// Returns the unread bytes of the current chunk. When the current chunk is fully consumed,
    /// this blocks to receive and fetch the next chunk from the channel.
    ///
    /// An empty slice is returned only at end of stream, i.e. once the channel is closed and
    /// drained and the producer task has been joined via `finish`. Empty chunks received from the
    /// channel are skipped. Returning an empty slice while the stream is still open would make
    /// callers treat it as EOF and stop reading early, and `finish` would never run.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - joining a chunk's fetch task fails;
    /// - the fetch itself fails;
    /// - the producer task reports an error when the stream ends.
    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
        loop {
            let Some(state) = self.get_open_state() else {
                // Already finished: report end of stream.
                return Ok(&[]);
            };

            if !state.current.is_empty() {
                break;
            }

            match state.receiver.blocking_recv() {
                // `_permit` stays alive until the chunk has been fetched and stored.
                Some((handle, _permit)) => {
                    // Outer `?`: joining the fetch task failed. It is propagated as an I/O
                    // error, the same way `finish` treats the producer task's join result.
                    // Inner `?`: the fetch itself returned an error.
                    state.current = ASYNC.block_in_place_on(handle)??;
                },
                None => {
                    // Channel closed and drained: join the producer to surface its result.
                    self.finish()?;
                    return Ok(&[]);
                },
            }
        }

        // Re-borrow after the loop. Returning `state.current` from inside the loop would keep
        // `self` mutably borrowed across the `self.finish()` call above.
        match self {
            Self::Open(state) => Ok(state.current.as_ref()),
            Self::Finished => unreachable!("fill_buf loop only breaks while the reader is open"),
        }
    }

    /// Marks `amt` bytes of the current chunk as read. Has no effect once the reader is finished.
    fn consume(&mut self, amt: usize) {
        if let Some(state) = self.get_open_state() {
            state.current.slice_in_place(amt..);
        }
    }
}
110
111// Supported reader sources for respectively from_memory and streaming.
112pub enum ReaderSource {
113    Memory(Cursor<Buffer<u8>>),
114    #[cfg(feature = "async")]
115    Streaming(StreamBufReader),
116}
117
118impl std::io::Read for ReaderSource {
119    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
120        match self {
121            Self::Memory(r) => r.read(buf),
122            #[cfg(feature = "async")]
123            Self::Streaming(r) => r.read(buf),
124        }
125    }
126}
127
128impl std::io::BufRead for ReaderSource {
129    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
130        match self {
131            Self::Memory(r) => r.fill_buf(),
132            #[cfg(feature = "async")]
133            Self::Streaming(r) => r.fill_buf(),
134        }
135    }
136
137    fn consume(&mut self, amt: usize) {
138        match self {
139            Self::Memory(r) => r.consume(amt),
140            #[cfg(feature = "async")]
141            Self::Streaming(r) => r.consume(amt),
142        }
143    }
144}