polars_io/parquet/write/batched_writer.rs

use std::io::Write;
use std::sync::Mutex;

use arrow::record_batch::RecordBatch;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_parquet::read::{ParquetError, fallible_streaming_iterator};
use polars_parquet::write::{
    ColumnWriteOptions, CompressedPage, Compressor, DynIter, DynStreamingIterator,
    FallibleStreamingIterator, FileWriter, Page, ParquetType, RowGroupIterColumns,
    SchemaDescriptor, WriteOptions, array_to_columns, schema_to_metadata_key,
};
use rayon::prelude::*;

use super::{KeyValueMetadata, ParquetMetadataContext};

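/// Writes [`DataFrame`] batches to a parquet file incrementally; every
/// non-empty chunk of a batch becomes one row group.
///
/// A minimal usage sketch, assuming this writer is obtained from a
/// `ParquetWriter::batched`-style entry point (the exact construction path
/// lives outside this module):
///
/// ```ignore
/// let mut writer = ParquetWriter::new(file).batched(&df.schema())?; // assumed entry point
/// writer.write_batch(&df)?; // may be called repeatedly
/// writer.finish()?; // writes the footer and returns the total file size
/// ```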
pub struct BatchedWriter<W: Write> {
    // A mutex so that the streaming engine can get concurrent read access to
    // compress pages.
    //
    // @TODO: Remove mutex when old streaming engine is removed
    pub(super) writer: Mutex<FileWriter<W>>,
    // @TODO: Remove when old streaming engine is removed
    pub(super) parquet_schema: SchemaDescriptor,
    pub(super) column_options: Vec<ColumnWriteOptions>,
    pub(super) options: WriteOptions,
    pub(super) parallel: bool,
    pub(super) key_value_metadata: Option<KeyValueMetadata>,
}

impl<W: Write> BatchedWriter<W> {
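    /// Create a new [`BatchedWriter`] around an already-constructed
    /// [`FileWriter`]. The `parquet_schema` field starts out empty and is
    /// only kept for the old streaming engine; use [`Self::parquet_schema`]
    /// to get the schema of the underlying writer.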
    pub fn new(
        writer: Mutex<FileWriter<W>>,
        column_options: Vec<ColumnWriteOptions>,
        options: WriteOptions,
        parallel: bool,
        key_value_metadata: Option<KeyValueMetadata>,
    ) -> Self {
        Self {
            writer,
            parquet_schema: SchemaDescriptor::new(PlSmallStr::EMPTY, vec![]),
            column_options,
            options,
            parallel,
            key_value_metadata,
        }
    }

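    /// Eagerly encode *and* compress every non-empty chunk of `df` into row
    /// groups without writing them, so that compute can be separated from IO
    /// (see [`Self::write_row_groups`] for the writing half).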
    pub fn encode_and_compress<'a>(
        &'a self,
        df: &'a DataFrame,
    ) -> impl Iterator<Item = PolarsResult<RowGroupIterColumns<'static, PolarsError>>> + 'a {
        let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
        rb_iter.filter_map(move |batch| match batch.len() {
            0 => None,
            _ => {
                let row_group = create_eager_serializer(
                    batch,
                    self.parquet_schema.fields(),
                    self.column_options.as_ref(),
                    self.options,
                );

                Some(row_group)
            },
        })
    }

    /// Write a batch to the parquet writer.
    ///
    /// # Panics
    /// Panics if the chunks in the given [`DataFrame`] are not aligned.
    pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
        let row_group_iter = prepare_rg_iter(
            df,
            &self.parquet_schema,
            &self.column_options,
            self.options,
            self.parallel,
        );
        // Lock before looping so that order is maintained under contention.
        let mut writer = self.writer.lock().unwrap();
        for group in row_group_iter {
            writer.write(group?)?;
        }
        Ok(())
    }

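    /// Get the parquet schema of the underlying [`FileWriter`].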
    pub fn parquet_schema(&mut self) -> &SchemaDescriptor {
        let writer = self.writer.get_mut().unwrap();
        writer.parquet_schema()
    }

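    /// Write a single row group of already-compressed pages, one `Vec` per
    /// column.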
    pub fn write_row_group(&mut self, rg: &[Vec<CompressedPage>]) -> PolarsResult<()> {
        let writer = self.writer.get_mut().unwrap();
        // Wrap each column's compressed pages in the iterator types that the
        // `FileWriter` expects.
        let rg = DynIter::new(rg.iter().map(|col_pages| {
            Ok(DynStreamingIterator::new(
                fallible_streaming_iterator::convert(col_pages.iter().map(PolarsResult::Ok)),
            ))
        }));
        writer.write(rg)?;
        Ok(())
    }

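    /// Get a reference to the underlying mutex-guarded [`FileWriter`].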
    pub fn get_writer(&self) -> &Mutex<FileWriter<W>> {
        &self.writer
    }

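    /// Write pre-encoded row groups, e.g. the ones produced by
    /// [`Self::encode_and_compress`].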
    pub fn write_row_groups(
        &self,
        rgs: Vec<RowGroupIterColumns<'static, PolarsError>>,
    ) -> PolarsResult<()> {
        // Lock before looping so that order is maintained.
        let mut writer = self.writer.lock().unwrap();
        for group in rgs {
            writer.write(group)?;
        }
        Ok(())
    }

    /// Writes the footer of the parquet file. Returns the total size of the file.
    pub fn finish(&self) -> PolarsResult<u64> {
        let mut writer = self.writer.lock().unwrap();

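        // Collect the user-provided key-value metadata and make sure the
        // arrow schema is always present under its well-known metadata key.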
        let key_value_metadata = self
            .key_value_metadata
            .as_ref()
            .map(|meta| {
                let arrow_schema = schema_to_metadata_key(writer.schema(), &self.column_options);
                let ctx = ParquetMetadataContext {
                    arrow_schema: arrow_schema.value.as_ref().unwrap(),
                };
                let mut out = meta.collect(ctx)?;
                if !out.iter().any(|kv| kv.key == arrow_schema.key) {
                    out.insert(0, arrow_schema);
                }
                PolarsResult::Ok(out)
            })
            .transpose()?;

        let size = writer.end(key_value_metadata, &self.column_options)?;
        Ok(size)
    }
}

// Note: `df` should be rechunked so that the chunks of all columns are aligned.
fn prepare_rg_iter<'a>(
    df: &'a DataFrame,
    parquet_schema: &'a SchemaDescriptor,
    column_options: &'a [ColumnWriteOptions],
    options: WriteOptions,
    parallel: bool,
) -> impl Iterator<Item = PolarsResult<RowGroupIterColumns<'static, PolarsError>>> + 'a {
    let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
    rb_iter.filter_map(move |batch| match batch.len() {
        0 => None,
        _ => {
            let row_group = create_serializer(
                batch,
                parquet_schema.fields(),
                column_options,
                options,
                parallel,
            );

            Some(row_group)
        },
    })
}

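/// Attach a [`Compressor`] to every column's encoded page iterator; pages
/// are compressed as the streaming iterators are advanced.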
fn pages_iter_to_compressor(
    encoded_columns: Vec<DynIter<'static, PolarsResult<Page>>>,
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    encoded_columns
        .into_iter()
        .map(|encoded_pages| {
            // iterator over pages
            let pages = DynStreamingIterator::new(
                Compressor::new_from_vec(
                    encoded_pages.map(|result| {
                        result.map_err(|e| {
                            ParquetError::FeatureNotSupported(format!("reraised in polars: {e}"))
                        })
                    }),
                    options.compression,
                    vec![],
                )
                .map_err(PolarsError::from),
            );

            Ok(pages)
        })
        .collect::<Vec<_>>()
}

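/// Encode a single (possibly nested) arrow array into pages and wrap them in
/// compressors. Nested arrays can produce multiple parquet leaf columns,
/// hence the `Vec` return type.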
fn array_to_pages_iter(
    array: &ArrayRef,
    type_: &ParquetType,
    column_options: &ColumnWriteOptions,
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    let encoded_columns = array_to_columns(array, type_.clone(), column_options, options).unwrap();
    pages_iter_to_compressor(encoded_columns, options)
}

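/// Create the page-compressor iterators for one row group, encoding the
/// columns of `batch` on the rayon [`POOL`] when `parallel` is set.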
fn create_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    column_options: &[ColumnWriteOptions],
    options: WriteOptions,
    parallel: bool,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), column_options): (
        (&ArrayRef, &ParquetType),
        &ColumnWriteOptions,
    )| { array_to_pages_iter(array, type_, column_options, options) };

    let columns = if parallel {
        POOL.install(|| {
            batch
                .columns()
                .par_iter()
                .zip(fields)
                .zip(column_options)
                .flat_map(func)
                .collect::<Vec<_>>()
        })
    } else {
        batch
            .columns()
            .iter()
            .zip(fields)
            .zip(column_options)
            .flat_map(func)
            .collect::<Vec<_>>()
    };

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}

/// This serializer encodes and compresses everything eagerly in memory.
/// Used for separating compute from IO.
fn create_eager_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    column_options: &[ColumnWriteOptions],
    options: WriteOptions,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), column_options): (
        (&ArrayRef, &ParquetType),
        &ColumnWriteOptions,
    )| { array_to_pages_iter(array, type_, column_options, options) };

    let columns = batch
        .columns()
        .iter()
        .zip(fields)
        .zip(column_options)
        .flat_map(func)
        .collect::<Vec<_>>();

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}