polars_io/shared.rs

use std::io::{Read, Write};
use std::sync::Arc;

use arrow::array::new_empty_array;
use arrow::record_batch::RecordBatch;
use polars_core::prelude::*;
use polars_utils::plpath::PlPathRef;

use crate::cloud::CloudOptions;
use crate::options::RowIndex;
#[cfg(any(feature = "ipc", feature = "avro", feature = "ipc_streaming"))]
use crate::predicates::PhysicalIoExpr;
pub trait SerReader<R>
where
    R: Read,
{
    /// Create a new instance of the [`SerReader`].
    fn new(reader: R) -> Self;

    /// Make sure that all columns are contiguous in memory by
    /// aggregating the chunks into a single array.
    #[must_use]
    fn set_rechunk(self, _rechunk: bool) -> Self
    where
        Self: Sized,
    {
        self
    }

    /// Take the [`SerReader`] and return a parsed [`DataFrame`].
    fn finish(self) -> PolarsResult<DataFrame>;
}
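
// Illustrative sketch (not part of this module): a minimal `SerReader`
// implementation for a hypothetical format that parses one UTF-8 line per
// row into a single string column. `LinesReader` is an assumption for
// demonstration; the real implementors (CSV, IPC, JSON, ...) live in their
// own modules.
//
//     struct LinesReader<R> {
//         reader: R,
//     }
//
//     impl<R: Read> SerReader<R> for LinesReader<R> {
//         fn new(reader: R) -> Self {
//             LinesReader { reader }
//         }
//
//         fn finish(mut self) -> PolarsResult<DataFrame> {
//             let mut buf = String::new();
//             self.reader.read_to_string(&mut buf)?;
//             let lines: Vec<&str> = buf.lines().collect();
//             DataFrame::new(vec![Column::new("line".into(), lines)])
//         }
//     }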

pub trait SerWriter<W>
where
    W: Write,
{
    fn new(writer: W) -> Self
    where
        Self: Sized;
    fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()>;
}
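
// Together the two traits give every format the same surface (sketch;
// `CsvReader` and `CsvWriter` stand in for any concrete implementors and
// need the matching feature flags):
//
//     let mut df = CsvReader::new(file).set_rechunk(true).finish()?;
//     CsvWriter::new(&mut buffer).finish(&mut df)?;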

pub trait WriteDataFrameToFile {
    fn write_df_to_file(
        &self,
        df: &mut DataFrame,
        addr: PlPathRef<'_>,
        cloud_options: Option<&CloudOptions>,
    ) -> PolarsResult<()>;
}

pub trait ArrowReader {
    fn next_record_batch(&mut self) -> PolarsResult<Option<RecordBatch>>;
}
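
// Sketch of a trivial `ArrowReader` that drains a queue of pre-built record
// batches; `finish_reader` below relies only on this pull-based interface.
// `BatchQueue` is hypothetical:
//
//     struct BatchQueue {
//         batches: std::collections::VecDeque<RecordBatch>,
//     }
//
//     impl ArrowReader for BatchQueue {
//         fn next_record_batch(&mut self) -> PolarsResult<Option<RecordBatch>> {
//             Ok(self.batches.pop_front())
//         }
//     }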
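
/// Accumulate the record batches produced by `reader` into a single
/// [`DataFrame`], optionally prepending a row index, filtering each batch
/// with an IO predicate, and truncating the result to `n_rows`.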
#[cfg(any(feature = "ipc", feature = "avro", feature = "ipc_streaming"))]
pub(crate) fn finish_reader<R: ArrowReader>(
    mut reader: R,
    rechunk: bool,
    n_rows: Option<usize>,
    predicate: Option<Arc<dyn PhysicalIoExpr>>,
    arrow_schema: &ArrowSchema,
    row_index: Option<RowIndex>,
) -> PolarsResult<DataFrame> {
    use polars_core::utils::accumulate_dataframes_vertical_unchecked;

    let mut num_rows = 0;
    let mut parsed_dfs = Vec::with_capacity(1024);

    while let Some(batch) = reader.next_record_batch()? {
        let current_num_rows = num_rows as IdxSize;
        num_rows += batch.len();
        let mut df = DataFrame::from(batch);

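        // Prepend a row index column, starting at the user offset plus the
        // rows seen so far.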
        if let Some(rc) = &row_index {
            unsafe { df.with_row_index_mut(rc.name.clone(), Some(current_num_rows + rc.offset)) };
        }

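        // Filter this batch eagerly with the pushed-down IO predicate.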
        if let Some(predicate) = &predicate {
            let s = predicate.evaluate_io(&df)?;
            let mask = s.bool().expect("filter predicate was not of type boolean");
            df = df.filter(mask)?;
        }

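        // Enforce the row limit: keep just enough of this batch to reach a
        // total of `n` rows, then stop reading.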
        if let Some(n) = n_rows {
            if num_rows >= n {
                let len = n - parsed_dfs
                    .iter()
                    .map(|df: &DataFrame| df.height())
                    .sum::<usize>();
                if polars_core::config::verbose() {
                    eprintln!(
                        "sliced off {} rows of the 'DataFrame'. These lines were read because they were in a single chunk.",
                        df.height().saturating_sub(len)
                    )
                }
                parsed_dfs.push(df.slice(0, len));
                break;
            }
        }
        parsed_dfs.push(df);
    }

    let mut df = {
        if parsed_dfs.is_empty() {
            // Create an empty dataframe with the correct data types
            let empty_cols = arrow_schema
                .iter_values()
                .map(|fld| {
                    Series::try_from((fld.name.clone(), new_empty_array(fld.dtype.clone())))
                        .map(Column::from)
                })
                .collect::<PolarsResult<_>>()?;
            DataFrame::new(empty_cols)?
        } else {
            // If there are any rows, accumulate them into a df
            accumulate_dataframes_vertical_unchecked(parsed_dfs)
        }
    };

    if rechunk {
        df.as_single_chunk_par();
    }
    Ok(df)
}
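
// Usage sketch (internal): the IPC/Avro readers behind the same feature gates
// drive `finish_reader` roughly like this; the names are illustrative:
//
//     let df = finish_reader(
//         batch_reader,   // impl ArrowReader
//         true,           // rechunk each column into a single chunk
//         Some(10_000),   // stop once 10k rows have been accumulated
//         None,           // no pushdown predicate
//         &arrow_schema,  // used for an empty frame when no batches arrive
//         None,           // no row index column
//     )?;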
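
/// Convert a Polars [`Schema`] to an [`ArrowSchema`], erroring on `Object`
/// columns, which cannot be serialized; `_file_name` only feeds the error
/// message.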
pub fn schema_to_arrow_checked(
    schema: &Schema,
    compat_level: CompatLevel,
    _file_name: &str,
) -> PolarsResult<ArrowSchema> {
    schema
        .iter_fields()
        .map(|field| {
            #[cfg(feature = "object")]
            {
                polars_ensure!(
                    !matches!(field.dtype(), DataType::Object(_)),
                    ComputeError: "cannot write 'Object' datatype to {}",
                    _file_name
                );
            }

            let field = field
                .dtype()
                .to_arrow_field(field.name().clone(), compat_level);
            Ok((field.name.clone(), field))
        })
        .collect::<PolarsResult<ArrowSchema>>()
}
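
// Usage sketch: a writer validates its schema up front, e.g. (hypothetical
// call site; "ipc" is just the label interpolated into the error message):
//
//     let arrow_schema = schema_to_arrow_checked(&schema, CompatLevel::newest(), "ipc")?;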