//! `polars_io/shared.rs` — shared reader/writer traits and helpers used by the
//! concrete I/O formats (ipc, avro, ipc_streaming, …).

1use std::io::{Read, Write};
2use std::sync::Arc;
3
4use arrow::record_batch::RecordBatch;
5use polars_core::prelude::*;
6
7use crate::options::RowIndex;
8#[cfg(any(feature = "ipc", feature = "avro", feature = "ipc_streaming",))]
9use crate::predicates::PhysicalIoExpr;
10
/// Deserialize a [`DataFrame`] from any [`Read`] source.
///
/// Implemented by the format-specific readers (csv, ipc, avro, …); callers
/// construct with [`SerReader::new`] and obtain the frame via
/// [`SerReader::finish`].
pub trait SerReader<R>
where
    R: Read,
{
    /// Create a new instance of the [`SerReader`]
    fn new(reader: R) -> Self;

    /// Make sure that all columns are contiguous in memory by
    /// aggregating the chunks into a single array.
    ///
    /// The default implementation is a no-op that ignores `_rechunk`;
    /// readers that support rechunking override this.
    #[must_use]
    fn set_rechunk(self, _rechunk: bool) -> Self
    where
        Self: Sized,
    {
        self
    }

    /// Take the SerReader and return a parsed DataFrame.
    fn finish(self) -> PolarsResult<DataFrame>;
}
31
/// Serialize a [`DataFrame`] to any [`Write`] sink.
pub trait SerWriter<W>
where
    W: Write,
{
    /// Create a new writer wrapping `writer`.
    fn new(writer: W) -> Self
    where
        Self: Sized;

    /// Write `df` to the underlying sink.
    ///
    /// NOTE(review): `df` is taken by `&mut` — presumably so implementors can
    /// rechunk/reorder the frame before serializing; confirm against the
    /// concrete writers.
    fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()>;
}
41
/// Pull-based source of Arrow [`RecordBatch`]es.
pub trait ArrowReader {
    /// Return the next batch, or `Ok(None)` once the source is exhausted.
    fn next_record_batch(&mut self) -> PolarsResult<Option<RecordBatch>>;
}
45
#[cfg(any(feature = "ipc", feature = "avro", feature = "ipc_streaming",))]
/// Drain `reader` batch-by-batch and accumulate the batches into a single
/// [`DataFrame`].
///
/// * `rechunk` — make all columns contiguous in the final frame.
/// * `n_rows` — stop after this many rows have been *read* (pre-filter);
///   the final chunk is sliced down so the result holds at most `n` rows.
/// * `predicate` — pushed-down filter applied per batch to keep memory bounded.
/// * `arrow_schema` — used to build an empty frame when the source yields no batches.
/// * `row_index` — optional row-index column; offsets continue across batches.
pub(crate) fn finish_reader<R: ArrowReader>(
    mut reader: R,
    rechunk: bool,
    n_rows: Option<usize>,
    predicate: Option<Arc<dyn PhysicalIoExpr>>,
    arrow_schema: &ArrowSchema,
    row_index: Option<RowIndex>,
) -> PolarsResult<DataFrame> {
    use polars_core::utils::accumulate_dataframes_vertical_unchecked;

    // Total rows read from the source so far (counted BEFORE predicate filtering).
    let mut num_rows = 0;
    let mut parsed_dfs = Vec::with_capacity(1024);

    while let Some(batch) = reader.next_record_batch()? {
        let current_num_rows = num_rows as IdxSize;
        num_rows += batch.len();
        let mut df = DataFrame::from(batch);

        if let Some(rc) = &row_index {
            // Row-index values continue from the rows read in previous batches.
            unsafe { df.with_row_index_mut(rc.name.clone(), Some(current_num_rows + rc.offset)) };
        }

        if let Some(predicate) = &predicate {
            let s = predicate.evaluate_io(&df)?;
            let mask = s.bool().expect("filter predicates was not of type boolean");
            df = df.filter(mask)?;
        }

        if let Some(n) = n_rows {
            // `n_rows` limits rows *read*, so compare against the pre-filter count.
            if num_rows >= n {
                // Rows still allowed in the output: `n` minus what we kept so far.
                let len = n - parsed_dfs
                    .iter()
                    .map(|df: &DataFrame| df.height())
                    .sum::<usize>();
                if polars_core::config::verbose() {
                    eprintln!(
                        "sliced off {} rows of the 'DataFrame'. These lines were read because they were in a single chunk.",
                        // BUGFIX: the rows discarded from this chunk are
                        // `height - len` (what we keep), not `height - n`;
                        // the old expression under-reported whenever earlier
                        // chunks had already contributed rows.
                        df.height().saturating_sub(len)
                    )
                }
                parsed_dfs.push(df.slice(0, len));
                break;
            }
        }
        parsed_dfs.push(df);
    }

    let mut df = {
        if parsed_dfs.is_empty() {
            // No batches read: return an empty frame with the expected schema.
            DataFrame::empty_with_schema(&Schema::from_arrow_schema(arrow_schema))
        } else {
            // If there are any rows, accumulate them into a df
            accumulate_dataframes_vertical_unchecked(parsed_dfs)
        }
    };

    if rechunk {
        df.rechunk_mut_par();
    }
    Ok(df)
}
108
109pub fn schema_to_arrow_checked(
110    schema: &Schema,
111    compat_level: CompatLevel,
112    _file_name: &str,
113) -> PolarsResult<ArrowSchema> {
114    schema
115        .iter_fields()
116        .map(|field| {
117            #[cfg(feature = "object")]
118            {
119                polars_ensure!(
120                    !matches!(field.dtype(), DataType::Object(_)),
121                    ComputeError: "cannot write 'Object' datatype to {}",
122                    _file_name
123                );
124            }
125
126            let field = field
127                .dtype()
128                .to_arrow_field(field.name().clone(), compat_level);
129            Ok((field.name.clone(), field))
130        })
131        .collect::<PolarsResult<ArrowSchema>>()
132}