polars_io/utils/
stream_buf_reader.rs1use std::io::{BufRead, Cursor};
2
3use polars_buffer::Buffer;
4use polars_error::PolarsResult;
5#[cfg(feature = "async")]
6use polars_utils::async_utils::tokio_handle_ext;
7#[cfg(feature = "async")]
8use tokio::sync::OwnedSemaphorePermit;
9
10#[cfg(feature = "async")]
11use crate::pl_async;
12
/// State held by `StreamBufReader` while it has not yet observed the end of the chunk stream.
// NOTE(review): struct fields are dropped in declaration order, so `receiver` is dropped
// before `producer_task_handle` when this state is dropped as a whole; keep that in mind
// before reordering the fields.
#[cfg(feature = "async")]
pub struct OpenReaderState {
    /// Queue of chunk fetches. Each item pairs the handle of a task yielding the next
    /// chunk of bytes with a semaphore permit that is held while that item is processed.
    receiver: tokio::sync::mpsc::Receiver<(
        tokio_handle_ext::AbortOnDropHandle<PolarsResult<Buffer<u8>>>,
        OwnedSemaphorePermit,
    )>,
    /// Handle of the producer task (presumably the one feeding `receiver`'s channel —
    /// confirm at the construction site); it is joined by `StreamBufReader::finish`.
    producer_task_handle: tokio_handle_ext::AbortOnDropHandle<std::io::Result<()>>,
    /// Not-yet-consumed remainder of the most recently fetched chunk.
    current: Buffer<u8>,
}
21}
22
/// Blocking, buffered reader over a stream of byte chunks that are fetched by async tasks.
#[cfg(feature = "async")]
pub enum StreamBufReader {
    /// The chunk channel has not yet been observed as closed.
    Open(OpenReaderState),
    /// Set by `finish`; every later `fill_buf` returns an empty slice (end of stream).
    Finished,
}
29
#[cfg(feature = "async")]
impl StreamBufReader {
    /// Creates an open reader.
    ///
    /// The reader pulls chunk-fetch handles from `receiver`. Once that channel is
    /// exhausted, it joins `producer_task_handle`.
    pub fn new(
        receiver: tokio::sync::mpsc::Receiver<(
            tokio_handle_ext::AbortOnDropHandle<PolarsResult<Buffer<u8>>>,
            OwnedSemaphorePermit,
        )>,
        producer_task_handle: tokio_handle_ext::AbortOnDropHandle<std::io::Result<()>>,
    ) -> Self {
        let state = OpenReaderState {
            current: Buffer::default(),
            receiver,
            producer_task_handle,
        };
        Self::Open(state)
    }

    /// Borrows the open state, or yields `None` once the reader has finished.
    fn get_open_state(&mut self) -> Option<&mut OpenReaderState> {
        match self {
            Self::Finished => None,
            Self::Open(state) => Some(state),
        }
    }

    /// Moves the reader into `Finished` and then blocks until the producer task completes.
    ///
    /// Calling this on a reader that is already finished is a no-op returning `Ok(())`.
    ///
    /// # Errors
    /// Returns either of the following:
    /// - the producer task's own I/O error;
    /// - the outer error from awaiting its handle, converted via `?`.
    fn finish(&mut self) -> std::io::Result<()> {
        match std::mem::replace(self, Self::Finished) {
            Self::Finished => Ok(()),
            Self::Open(OpenReaderState {
                receiver,
                producer_task_handle,
                ..
            }) => {
                // Close our end of the channel before blocking on the producer.
                drop(receiver);
                pl_async::get_runtime().block_in_place_on(producer_task_handle)?
            },
        }
    }
}
63
#[cfg(feature = "async")]
impl std::io::Read for StreamBufReader {
    /// Copies as many buffered bytes as fit into `buf` and consumes them.
    ///
    /// When the current chunk is exhausted, the next chunk is fetched first (via `fill_buf`).
    /// Returns `Ok(0)` at end of stream, or when `buf` is empty.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let copied = match self.fill_buf()? {
            // An empty buffer from `fill_buf` means end of stream.
            [] => return Ok(0),
            chunk => {
                let len = buf.len().min(chunk.len());
                let (head, _) = buf.split_at_mut(len);
                head.copy_from_slice(&chunk[..len]);
                len
            },
        };
        self.consume(copied);
        Ok(copied)
    }
}
77
#[cfg(feature = "async")]
impl std::io::BufRead for StreamBufReader {
    /// Returns the unconsumed bytes of the current chunk.
    ///
    /// Blocks to fetch the next chunk whenever the current one is exhausted.
    ///
    /// Zero-length chunks are skipped. `BufRead` callers treat an empty slice as end of
    /// stream, so an empty slice is returned only after the chunk channel has closed.
    /// At that point the reader has been moved to `Finished` and the producer task joined.
    ///
    /// # Errors
    /// - Propagates the error of a failed chunk fetch. In that case the reader stays open.
    /// - Propagates any error reported by `finish` when joining the producer task.
    ///
    /// # Panics
    /// Panics if the outer result of awaiting a chunk-fetch handle is an error
    /// (presumably the fetch task panicked or was cancelled — confirm against
    /// `AbortOnDropHandle`'s output type).
    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
        loop {
            let Some(state) = self.get_open_state() else {
                return Ok(&[]);
            };
            if !state.current.is_empty() {
                break;
            }
            match state.receiver.blocking_recv() {
                // `_permit` is released at the end of this arm, i.e. only after the
                // fetched chunk has been stored in `current`.
                Some((handle, _permit)) => {
                    let fetched_bytes =
                        pl_async::get_runtime().block_in_place_on(handle).unwrap()?;
                    // A zero-length chunk leaves `current` empty; the loop then fetches
                    // again instead of reporting a premature end of stream.
                    state.current = fetched_bytes;
                },
                None => {
                    self.finish()?;
                    return Ok(&[]);
                },
            }
        }

        // Re-borrow after the loop. Returning a borrow taken inside the loop would keep
        // `self` mutably borrowed across the `self.finish()` call above.
        let Some(state) = self.get_open_state() else {
            unreachable!("the loop only breaks while the reader is open");
        };
        Ok(state.current.as_ref())
    }

    /// Marks `amt` bytes of the current chunk as consumed; a no-op once finished.
    fn consume(&mut self, amt: usize) {
        if let Some(state) = self.get_open_state() {
            state.current.slice_in_place(amt..);
        }
    }
}
112
/// A byte source that is either already fully in memory or streamed in chunks.
pub enum ReaderSource {
    /// All bytes held in a single in-memory buffer, read through a cursor.
    Memory(Cursor<Buffer<u8>>),
    /// Bytes pulled incrementally through a `StreamBufReader` (requires the `async` feature).
    #[cfg(feature = "async")]
    Streaming(StreamBufReader),
}
120impl std::io::Read for ReaderSource {
121 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
122 match self {
123 Self::Memory(r) => r.read(buf),
124 #[cfg(feature = "async")]
125 Self::Streaming(r) => r.read(buf),
126 }
127 }
128}
129
130impl std::io::BufRead for ReaderSource {
131 fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
132 match self {
133 Self::Memory(r) => r.fill_buf(),
134 #[cfg(feature = "async")]
135 Self::Streaming(r) => r.fill_buf(),
136 }
137 }
138
139 fn consume(&mut self, amt: usize) {
140 match self {
141 Self::Memory(r) => r.consume(amt),
142 #[cfg(feature = "async")]
143 Self::Streaming(r) => r.consume(amt),
144 }
145 }
146}