polars_io/shared.rs

use std::io::{Read, Write};
use std::sync::Arc;

use arrow::array::new_empty_array;
use arrow::record_batch::RecordBatch;
use polars_core::prelude::*;

use crate::cloud::CloudOptions;
use crate::options::RowIndex;
#[cfg(any(feature = "ipc", feature = "avro", feature = "ipc_streaming",))]
use crate::predicates::PhysicalIoExpr;

pub trait SerReader<R>
where
    R: Read,
{
    /// Create a new instance of the [`SerReader`]
    fn new(reader: R) -> Self;

    /// Make sure that all columns are contiguous in memory by
    /// aggregating the chunks into a single array.
    #[must_use]
    fn set_rechunk(self, _rechunk: bool) -> Self
    where
        Self: Sized,
    {
        self
    }

    /// Take the SerReader and return a parsed DataFrame.
    fn finish(self) -> PolarsResult<DataFrame>;
}
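
// Illustrative sketch, not part of the original file: one way a format reader
// could implement `SerReader`. `EchoReader` is a hypothetical type used only
// for illustration; the real readers (CSV, JSON, IPC, ...) live in their own
// modules.
//
//     struct EchoReader<R: Read> {
//         reader: R,
//         rechunk: bool,
//     }
//
//     impl<R: Read> SerReader<R> for EchoReader<R> {
//         fn new(reader: R) -> Self {
//             EchoReader { reader, rechunk: true }
//         }
//
//         fn set_rechunk(mut self, rechunk: bool) -> Self {
//             self.rechunk = rechunk;
//             self
//         }
//
//         fn finish(mut self) -> PolarsResult<DataFrame> {
//             // Parse the bytes in `self.reader` into a DataFrame here.
//             let mut df = DataFrame::empty();
//             if self.rechunk {
//                 df.as_single_chunk_par();
//             }
//             Ok(df)
//         }
//     }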

pub trait SerWriter<W>
where
    W: Write,
{
    /// Create a new instance of the [`SerWriter`].
    fn new(writer: W) -> Self
    where
        Self: Sized;

    /// Write the given [`DataFrame`] to the underlying writer.
    fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()>;
}
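
// Illustrative sketch, not part of the original file: a hypothetical
// `EchoWriter` implementing `SerWriter`. Concrete writers wrap a `Write`
// handle in `new` and serialize the frame in `finish`.
//
//     struct EchoWriter<W: Write> {
//         writer: W,
//     }
//
//     impl<W: Write> SerWriter<W> for EchoWriter<W> {
//         fn new(writer: W) -> Self {
//             EchoWriter { writer }
//         }
//
//         fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> {
//             // Serialize `df` into `self.writer`; here just its Display form.
//             writeln!(self.writer, "{df}")
//                 .map_err(|e| PolarsError::ComputeError(e.to_string().into()))
//         }
//     }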

pub trait WriteDataFrameToFile {
    /// Write the [`DataFrame`] to the given path, locally or to cloud storage.
    fn write_df_to_file(
        &self,
        df: &mut DataFrame,
        path: &str,
        cloud_options: Option<&CloudOptions>,
    ) -> PolarsResult<()>;
}

pub trait ArrowReader {
    /// Return the next [`RecordBatch`], or `None` once the source is exhausted.
    fn next_record_batch(&mut self) -> PolarsResult<Option<RecordBatch>>;
}
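
// Illustrative sketch, not part of the original file: `ArrowReader` is the
// minimal batch-pulling interface consumed by `finish_reader` below. A
// hypothetical in-memory source could implement it like this:
//
//     struct BatchVec {
//         batches: std::vec::IntoIter<RecordBatch>,
//     }
//
//     impl ArrowReader for BatchVec {
//         fn next_record_batch(&mut self) -> PolarsResult<Option<RecordBatch>> {
//             Ok(self.batches.next())
//         }
//     }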

/// Pull all record batches from `reader` and assemble them into a single
/// [`DataFrame`], applying the optional row index, predicate filter and row
/// limit along the way.
#[cfg(any(feature = "ipc", feature = "avro", feature = "ipc_streaming",))]
pub(crate) fn finish_reader<R: ArrowReader>(
    mut reader: R,
    rechunk: bool,
    n_rows: Option<usize>,
    predicate: Option<Arc<dyn PhysicalIoExpr>>,
    arrow_schema: &ArrowSchema,
    row_index: Option<RowIndex>,
) -> PolarsResult<DataFrame> {
    use polars_core::utils::accumulate_dataframes_vertical_unchecked;

    let mut num_rows = 0;
    let mut parsed_dfs = Vec::with_capacity(1024);

    while let Some(batch) = reader.next_record_batch()? {
        let current_num_rows = num_rows as IdxSize;
        num_rows += batch.len();
        let mut df = DataFrame::from(batch);

        if let Some(rc) = &row_index {
            unsafe { df.with_row_index_mut(rc.name.clone(), Some(current_num_rows + rc.offset)) };
        }

        if let Some(predicate) = &predicate {
            let s = predicate.evaluate_io(&df)?;
            let mask = s.bool().expect("filter predicate was not of type boolean");
            df = df.filter(mask)?;
        }

        if let Some(n) = n_rows {
            if num_rows >= n {
                // The last batch may overshoot the requested row count; slice it
                // so that at most `n` rows are kept in total.
                let len = n - parsed_dfs
                    .iter()
                    .map(|df: &DataFrame| df.height())
                    .sum::<usize>();
                if polars_core::config::verbose() {
                    eprintln!(
                        "sliced off {} rows of the 'DataFrame'. These lines were read because they were in a single chunk.",
                        df.height().saturating_sub(n)
                    )
                }
                parsed_dfs.push(df.slice(0, len));
                break;
            }
        }
        parsed_dfs.push(df);
    }

    let mut df = {
        if parsed_dfs.is_empty() {
            // Create an empty dataframe with the correct data types
            let empty_cols = arrow_schema
                .iter_values()
                .map(|fld| {
                    Series::try_from((fld.name.clone(), new_empty_array(fld.dtype.clone())))
                        .map(Column::from)
                })
                .collect::<PolarsResult<_>>()?;
            DataFrame::new(empty_cols)?
        } else {
            // If there are any rows, accumulate them into a df
            accumulate_dataframes_vertical_unchecked(parsed_dfs)
        }
    };

    if rechunk {
        df.as_single_chunk_par();
    }
    Ok(df)
}
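
// Illustrative sketch, not part of the original file: a feature-gated format
// reader built on top of an Arrow batch reader can delegate its `finish` step
// to `finish_reader`. All variable names below are hypothetical.
//
//     let df = finish_reader(
//         batch_reader,   // anything implementing `ArrowReader`
//         true,           // rechunk into a single chunk per column
//         Some(10_000),   // keep at most this many rows
//         None,           // no pushed-down predicate
//         &arrow_schema,  // used to build an empty frame if no batches arrive
//         None,           // no row index column
//     )?;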

/// Convert a Polars [`Schema`] into an [`ArrowSchema`], erroring if a column
/// has the `Object` dtype, which cannot be written to the target file format.
pub fn schema_to_arrow_checked(
    schema: &Schema,
    compat_level: CompatLevel,
    _file_name: &str,
) -> PolarsResult<ArrowSchema> {
    schema
        .iter_fields()
        .map(|field| {
            #[cfg(feature = "object")]
            {
                polars_ensure!(
                    !matches!(field.dtype(), DataType::Object(_)),
                    ComputeError: "cannot write 'Object' datatype to {}",
                    _file_name
                );
            }

            let field = field
                .dtype()
                .to_arrow_field(field.name().clone(), compat_level);
            Ok((field.name.clone(), field))
        })
        .collect::<PolarsResult<ArrowSchema>>()
}
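
// Illustrative sketch, not part of the original file: writers typically run
// this check before serializing, so that unsupported dtypes fail early. The
// `schema` variable below is hypothetical.
//
//     let arrow_schema = schema_to_arrow_checked(&schema, CompatLevel::newest(), "ipc")?;
//     // `arrow_schema` can now be handed to the Arrow-based writer.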