polars_io/parquet/write/batched_writer.rs

use std::io::Write;
use std::sync::Mutex;

use arrow::record_batch::RecordBatch;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_parquet::read::{ParquetError, fallible_streaming_iterator};
use polars_parquet::write::{
    CompressedPage, Compressor, DynIter, DynStreamingIterator, Encoding, FallibleStreamingIterator,
    FileWriter, Page, ParquetType, RowGroupIterColumns, SchemaDescriptor, WriteOptions,
    array_to_columns,
};
use rayon::prelude::*;

pub struct BatchedWriter<W: Write> {
    // A mutex so that the streaming engine can get concurrent read access to
    // compress pages.
    //
    // @TODO: Remove mutex when old streaming engine is removed
    pub(super) writer: Mutex<FileWriter<W>>,
    // @TODO: Remove when old streaming engine is removed
    pub(super) parquet_schema: SchemaDescriptor,
    pub(super) encodings: Vec<Vec<Encoding>>,
    pub(super) options: WriteOptions,
    pub(super) parallel: bool,
}

impl<W: Write> BatchedWriter<W> {
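    /// Creates a batched writer around a mutex-wrapped [`FileWriter`].
    ///
    /// The cached `parquet_schema` starts out empty; the schema of the inner writer is
    /// available through [`BatchedWriter::parquet_schema`].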
    pub fn new(
        writer: Mutex<FileWriter<W>>,
        encodings: Vec<Vec<Encoding>>,
        options: WriteOptions,
        parallel: bool,
    ) -> Self {
        Self {
            writer,
            parquet_schema: SchemaDescriptor::new(PlSmallStr::EMPTY, vec![]),
            encodings,
            options,
            parallel,
        }
    }

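    /// Returns an iterator that eagerly encodes and compresses each non-empty chunk of `df`
    /// into a row group without writing it, separating the compute-heavy work from the IO
    /// performed later by [`BatchedWriter::write_row_groups`].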
    pub fn encode_and_compress<'a>(
        &'a self,
        df: &'a DataFrame,
    ) -> impl Iterator<Item = PolarsResult<RowGroupIterColumns<'static, PolarsError>>> + 'a {
        let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
        rb_iter.filter_map(move |batch| match batch.len() {
            0 => None,
            _ => {
                let row_group = create_eager_serializer(
                    batch,
                    self.parquet_schema.fields(),
                    self.encodings.as_ref(),
                    self.options,
                );

                Some(row_group)
            },
        })
    }

    /// Write a batch to the parquet writer.
    ///
    /// # Panics
    /// The caller must ensure the chunks in the given [`DataFrame`] are aligned;
    /// this function may panic otherwise.
    pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
        let row_group_iter = prepare_rg_iter(
            df,
            &self.parquet_schema,
            &self.encodings,
            self.options,
            self.parallel,
        );
        // Lock before looping so that order is maintained under contention.
        let mut writer = self.writer.lock().unwrap();
        for group in row_group_iter {
            writer.write(group?)?;
        }
        Ok(())
    }

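    /// Returns the parquet schema of the underlying [`FileWriter`].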
    pub fn parquet_schema(&mut self) -> &SchemaDescriptor {
        let writer = self.writer.get_mut().unwrap();
        writer.parquet_schema()
    }

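    /// Writes a single row group from already compressed pages, one `Vec<CompressedPage>`
    /// per column.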
    pub fn write_row_group(&mut self, rg: &[Vec<CompressedPage>]) -> PolarsResult<()> {
        let writer = self.writer.get_mut().unwrap();
        let rg = DynIter::new(rg.iter().map(|col_pages| {
            Ok(DynStreamingIterator::new(
                fallible_streaming_iterator::convert(col_pages.iter().map(PolarsResult::Ok)),
            ))
        }));
        writer.write(rg)?;
        Ok(())
    }

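    /// Returns a reference to the mutex guarding the underlying [`FileWriter`].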
    pub fn get_writer(&self) -> &Mutex<FileWriter<W>> {
        &self.writer
    }

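    /// Writes row groups that were already encoded and compressed, e.g. by
    /// [`BatchedWriter::encode_and_compress`].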
    pub fn write_row_groups(
        &self,
        rgs: Vec<RowGroupIterColumns<'static, PolarsError>>,
    ) -> PolarsResult<()> {
        // Lock before looping so that order is maintained.
        let mut writer = self.writer.lock().unwrap();
        for group in rgs {
            writer.write(group)?;
        }
        Ok(())
    }

    /// Writes the footer of the parquet file. Returns the total size of the file.
    pub fn finish(&self) -> PolarsResult<u64> {
        let mut writer = self.writer.lock().unwrap();
        let size = writer.end(None)?;
        Ok(size)
    }
}

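// Lazily maps every non-empty chunk of `df` to a row group serializer.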
// Note that the df should be rechunked
fn prepare_rg_iter<'a>(
    df: &'a DataFrame,
    parquet_schema: &'a SchemaDescriptor,
    encodings: &'a [Vec<Encoding>],
    options: WriteOptions,
    parallel: bool,
) -> impl Iterator<Item = PolarsResult<RowGroupIterColumns<'static, PolarsError>>> + 'a {
    let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
    rb_iter.filter_map(move |batch| match batch.len() {
        0 => None,
        _ => {
            let row_group =
                create_serializer(batch, parquet_schema.fields(), encodings, options, parallel);

            Some(row_group)
        },
    })
}

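// Wraps each column's iterator of encoded pages in a `Compressor`, mapping any encoding
// error into a `ParquetError` along the way.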
fn pages_iter_to_compressor(
    encoded_columns: Vec<DynIter<'static, PolarsResult<Page>>>,
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    encoded_columns
        .into_iter()
        .map(|encoded_pages| {
            // iterator over pages
            let pages = DynStreamingIterator::new(
                Compressor::new_from_vec(
                    encoded_pages.map(|result| {
                        result.map_err(|e| {
                            ParquetError::FeatureNotSupported(format!("reraised in polars: {e}",))
                        })
                    }),
                    options.compression,
                    vec![],
                )
                .map_err(PolarsError::from),
            );

            Ok(pages)
        })
        .collect::<Vec<_>>()
}

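// Encodes a single arrow array into parquet pages via `array_to_columns` and feeds them to
// the compressor. A nested array (e.g. a struct) can expand into multiple parquet leaf columns.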
fn array_to_pages_iter(
    array: &ArrayRef,
    type_: &ParquetType,
    encoding: &[Encoding],
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    let encoded_columns = array_to_columns(array, type_.clone(), options, encoding).unwrap();
    pages_iter_to_compressor(encoded_columns, options)
}

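// Builds the row group serializer for one record batch, encoding the columns either on the
// rayon thread pool (`parallel == true`) or sequentially on the current thread.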
fn create_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    encodings: &[Vec<Encoding>],
    options: WriteOptions,
    parallel: bool,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec<Encoding>)| {
        array_to_pages_iter(array, type_, encoding, options)
    };

    let columns = if parallel {
        POOL.install(|| {
            batch
                .columns()
                .par_iter()
                .zip(fields)
                .zip(encodings)
                .flat_map(func)
                .collect::<Vec<_>>()
        })
    } else {
        batch
            .columns()
            .iter()
            .zip(fields)
            .zip(encodings)
            .flat_map(func)
            .collect::<Vec<_>>()
    };

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}

/// This serializer encodes and compresses everything eagerly in memory.
/// It is used to separate compute from IO.
fn create_eager_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    encodings: &[Vec<Encoding>],
    options: WriteOptions,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec<Encoding>)| {
        array_to_pages_iter(array, type_, encoding, options)
    };

    let columns = batch
        .columns()
        .iter()
        .zip(fields)
        .zip(encodings)
        .flat_map(func)
        .collect::<Vec<_>>();

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}