polars_io/parquet/write/
batched_writer.rs

1use std::io::Write;
2use std::sync::Mutex;
3
4use arrow::record_batch::RecordBatch;
5use polars_buffer::Buffer;
6use polars_core::POOL;
7use polars_core::prelude::*;
8use polars_parquet::read::{ParquetError, fallible_streaming_iterator};
9use polars_parquet::write::{
10    CompressedPage, Compressor, DynIter, DynStreamingIterator, Encoding, FallibleStreamingIterator,
11    FileWriter, Page, ParquetType, RowGroupIterColumns, SchemaDescriptor, WriteOptions,
12    array_to_columns, schema_to_metadata_key,
13};
14use rayon::prelude::*;
15
16use super::{KeyValueMetadata, ParquetMetadataContext};
17
/// Writes [`DataFrame`] batches to a parquet file incrementally, one row
/// group per non-empty chunk.
pub struct BatchedWriter<W: Write> {
    // A mutex so that streaming engine can get concurrent read access to
    // compress pages.
    //
    // @TODO: Remove mutex when old streaming engine is removed
    pub(super) writer: Mutex<FileWriter<W>>,
    // Parquet schema used when encoding row groups. Initialized empty by
    // `new`; presumably populated by the enclosing module — TODO confirm.
    // @TODO: Remove when old streaming engine is removed
    pub(super) parquet_schema: SchemaDescriptor,
    // Page encodings, one `Vec<Encoding>` per column (zipped against the
    // schema fields during serialization).
    pub(super) encodings: Buffer<Vec<Encoding>>,
    // Compression/write settings applied to every row group.
    pub(super) options: WriteOptions,
    // When `true`, columns are encoded+compressed on the rayon thread pool.
    pub(super) parallel: bool,
    // Optional user-supplied key-value metadata written into the footer.
    pub(super) key_value_metadata: Option<KeyValueMetadata>,
}
31
impl<W: Write> BatchedWriter<W> {
    /// Construct a new batched writer around an existing [`FileWriter`].
    ///
    /// `parquet_schema` starts out empty here; it is a `pub(super)` field, so
    /// it is presumably set by the enclosing module before methods that read
    /// it are called — TODO confirm against the callers.
    pub fn new(
        writer: Mutex<FileWriter<W>>,
        encodings: Buffer<Vec<Encoding>>,
        options: WriteOptions,
        parallel: bool,
        key_value_metadata: Option<KeyValueMetadata>,
    ) -> Self {
        Self {
            writer,
            parquet_schema: SchemaDescriptor::new(PlSmallStr::EMPTY, vec![]),
            encodings,
            options,
            parallel,
            key_value_metadata,
        }
    }

    /// Eagerly encode and compress every non-empty chunk of `df` into a row
    /// group without writing anything yet.
    ///
    /// Separates (CPU-bound) compute from IO: the returned iterator can be
    /// drained off-thread and the results handed to [`Self::write_row_groups`].
    pub fn encode_and_compress<'a>(
        &'a self,
        df: &'a DataFrame,
    ) -> impl Iterator<Item = PolarsResult<RowGroupIterColumns<'static, PolarsError>>> + 'a {
        let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
        rb_iter.filter_map(move |batch| match batch.len() {
            // Empty chunks produce no row group.
            0 => None,
            _ => {
                let row_group = create_eager_serializer(
                    batch,
                    self.parquet_schema.fields(),
                    self.encodings.as_ref(),
                    self.options,
                );

                Some(row_group)
            },
        })
    }

    /// Write a batch to the parquet writer.
    ///
    /// # Panics
    /// The caller must ensure the chunks in the given [`DataFrame`] are aligned.
    pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
        let row_group_iter = prepare_rg_iter(
            df,
            &self.parquet_schema,
            &self.encodings,
            self.options,
            self.parallel,
        );
        // Lock before looping so that order is maintained under contention.
        let mut writer = self.writer.lock().unwrap();
        for (num_rows, group) in row_group_iter {
            writer.write(num_rows as u64, group?)?;
        }
        Ok(())
    }

    /// The parquet schema of the underlying [`FileWriter`].
    pub fn parquet_schema(&mut self) -> &SchemaDescriptor {
        let writer = self.writer.get_mut().unwrap();
        writer.parquet_schema()
    }

    /// Write a single row group of already-compressed pages.
    ///
    /// Note: `num_rows` can be passed as `u64::MAX` to infer `num_rows` from the encoded data.
    pub fn write_row_group(
        &mut self,
        num_rows: u64,
        rg: &[Vec<CompressedPage>],
    ) -> PolarsResult<()> {
        let writer = self.writer.get_mut().unwrap();
        // Adapt the slice-of-pages layout to the streaming-iterator interface
        // the underlying writer expects.
        let rg = DynIter::new(rg.iter().map(|col_pages| {
            Ok(DynStreamingIterator::new(
                fallible_streaming_iterator::convert(col_pages.iter().map(PolarsResult::Ok)),
            ))
        }));
        writer.write(num_rows, rg)?;
        Ok(())
    }

    /// Access the mutex-guarded underlying [`FileWriter`].
    pub fn get_writer(&self) -> &Mutex<FileWriter<W>> {
        &self.writer
    }

    /// Write pre-encoded row groups (e.g. produced by
    /// [`Self::encode_and_compress`]); row counts are inferred from the
    /// encoded data via the `u64::MAX` sentinel.
    pub fn write_row_groups(
        &self,
        rgs: Vec<RowGroupIterColumns<'static, PolarsError>>,
    ) -> PolarsResult<()> {
        // Lock before looping so that order is maintained.
        let mut writer = self.writer.lock().unwrap();
        for group in rgs {
            writer.write(u64::MAX, group)?;
        }
        Ok(())
    }

    /// Writes the footer of the parquet file. Returns the total size of the file.
    pub fn finish(&self) -> PolarsResult<u64> {
        let mut writer = self.writer.lock().unwrap();

        let key_value_metadata = self
            .key_value_metadata
            .as_ref()
            .map(|meta| {
                // Render user metadata, then prepend the Arrow schema
                // key-value pair unless the user already supplied a value for
                // that key.
                let arrow_schema = schema_to_metadata_key(writer.schema());
                let ctx = ParquetMetadataContext {
                    arrow_schema: arrow_schema.value.as_ref().unwrap(),
                };
                let mut out = meta.collect(ctx)?;
                if !out.iter().any(|kv| kv.key == arrow_schema.key) {
                    out.insert(0, arrow_schema);
                }
                PolarsResult::Ok(out)
            })
            .transpose()?;

        let size = writer.end(key_value_metadata)?;
        Ok(size)
    }
}
151
152// Note that the df should be rechunked
153fn prepare_rg_iter<'a>(
154    df: &'a DataFrame,
155    parquet_schema: &'a SchemaDescriptor,
156    encodings: &'a [Vec<Encoding>],
157    options: WriteOptions,
158    parallel: bool,
159) -> impl Iterator<
160    Item = (
161        usize,
162        PolarsResult<RowGroupIterColumns<'static, PolarsError>>,
163    ),
164> + 'a {
165    let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
166    rb_iter.filter_map(move |batch| match batch.len() {
167        0 => None,
168        num_rows => {
169            let row_group =
170                create_serializer(batch, parquet_schema.fields(), encodings, options, parallel);
171
172            Some((num_rows, row_group))
173        },
174    })
175}
176
177fn pages_iter_to_compressor(
178    encoded_columns: Vec<DynIter<'static, PolarsResult<Page>>>,
179    options: WriteOptions,
180) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
181    encoded_columns
182        .into_iter()
183        .map(|encoded_pages| {
184            // iterator over pages
185            let pages = DynStreamingIterator::new(
186                Compressor::new_from_vec(
187                    encoded_pages.map(|result| {
188                        result.map_err(|e| {
189                            ParquetError::FeatureNotSupported(format!("reraised in polars: {e}",))
190                        })
191                    }),
192                    options.compression,
193                    vec![],
194                )
195                .map_err(PolarsError::from),
196            );
197
198            Ok(pages)
199        })
200        .collect::<Vec<_>>()
201}
202
203fn array_to_pages_iter(
204    array: &ArrayRef,
205    type_: &ParquetType,
206    encoding: &[Encoding],
207    options: WriteOptions,
208) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
209    let encoded_columns = array_to_columns(array, type_.clone(), options, encoding).unwrap();
210    pages_iter_to_compressor(encoded_columns, options)
211}
212
213fn create_serializer(
214    batch: RecordBatch,
215    fields: &[ParquetType],
216    encodings: &[Vec<Encoding>],
217    options: WriteOptions,
218    parallel: bool,
219) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
220    let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec<Encoding>)| {
221        array_to_pages_iter(array, type_, encoding, options)
222    };
223
224    let columns = if parallel {
225        POOL.install(|| {
226            batch
227                .columns()
228                .par_iter()
229                .zip(fields)
230                .zip(encodings)
231                .flat_map(func)
232                .collect::<Vec<_>>()
233        })
234    } else {
235        batch
236            .columns()
237            .iter()
238            .zip(fields)
239            .zip(encodings)
240            .flat_map(func)
241            .collect::<Vec<_>>()
242    };
243
244    let row_group = DynIter::new(columns.into_iter());
245
246    Ok(row_group)
247}
248
249/// This serializer encodes and compresses all eagerly in memory.
250/// Used for separating compute from IO.
251fn create_eager_serializer(
252    batch: RecordBatch,
253    fields: &[ParquetType],
254    encodings: &[Vec<Encoding>],
255    options: WriteOptions,
256) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
257    let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec<Encoding>)| {
258        array_to_pages_iter(array, type_, encoding, options)
259    };
260
261    let columns = batch
262        .columns()
263        .iter()
264        .zip(fields)
265        .zip(encodings)
266        .flat_map(func)
267        .collect::<Vec<_>>();
268
269    let row_group = DynIter::new(columns.into_iter());
270
271    Ok(row_group)
272}