polars_io/parquet/write/writer.rs

use std::io::Write;
use std::sync::Mutex;

use arrow::datatypes::PhysicalType;
use polars_core::frame::chunk_df_for_writing;
use polars_core::prelude::*;
use polars_parquet::write::{
    ChildWriteOptions, ColumnWriteOptions, CompressionOptions, Encoding, FieldWriteOptions,
    FileWriter, KeyValue, ListLikeFieldWriteOptions, StatisticsOptions, StructFieldWriteOptions,
    Version, WriteOptions, to_parquet_schema,
};

use super::batched_writer::BatchedWriter;
use super::options::ParquetCompression;
use super::{KeyValueMetadata, MetadataKeyValue, ParquetFieldOverwrites, ParquetWriteOptions};
use crate::prelude::ChildFieldOverwrites;
use crate::shared::schema_to_arrow_checked;
impl ParquetWriteOptions {
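    /// Create a [`ParquetWriter`] over `f`, configured from these options.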
    pub fn to_writer<F>(&self, f: F) -> ParquetWriter<F>
    where
        F: Write,
    {
        ParquetWriter::new(f)
            .with_compression(self.compression)
            .with_statistics(self.statistics)
            .with_row_group_size(self.row_group_size)
            .with_data_page_size(self.data_page_size)
            .with_key_value_metadata(self.key_value_metadata.clone())
    }
}

/// Write a DataFrame to Parquet format.
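///
/// A minimal usage sketch; the file name and data are illustrative:
///
/// ```ignore
/// use std::fs::File;
///
/// let mut df = df!("a" => [1i32, 2, 3])?;
/// let file = File::create("example.parquet")?;
/// let bytes_written = ParquetWriter::new(file)
///     .with_compression(ParquetCompression::Zstd(None))
///     .finish(&mut df)?;
/// ```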
#[must_use]
pub struct ParquetWriter<W> {
    writer: W,
    /// Data page compression
    compression: CompressionOptions,
    /// Compute and write column statistics.
    statistics: StatisticsOptions,
    /// If `None`, defaults to 512^2 rows.
    row_group_size: Option<usize>,
    /// If `None`, defaults to 1024^2 bytes.
    data_page_size: Option<usize>,
    /// Serialize columns in parallel
    parallel: bool,
    field_overwrites: Vec<ParquetFieldOverwrites>,
    /// Custom file-level key-value metadata
    key_value_metadata: Option<KeyValueMetadata>,
    /// Context info for the Parquet file being written.
    context_info: Option<PlHashMap<String, String>>,
}

impl<W> ParquetWriter<W>
where
    W: Write,
{
    /// Create a new writer with default settings.
    pub fn new(writer: W) -> Self {
        ParquetWriter {
            writer,
            compression: ParquetCompression::default().into(),
            statistics: StatisticsOptions::default(),
            row_group_size: None,
            data_page_size: None,
            parallel: true,
            field_overwrites: Vec::new(),
            key_value_metadata: None,
            context_info: None,
        }
    }

    /// Set the compression used. Defaults to `Zstd`.
    ///
    /// The default compression `Zstd` has very good performance, but may not yet be supported
    /// by older readers. If you want more compatibility guarantees, consider using `Snappy`.
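    ///
    /// For example, to favor reader compatibility over compression ratio (a sketch):
    ///
    /// ```ignore
    /// let writer = ParquetWriter::new(file).with_compression(ParquetCompression::Snappy);
    /// ```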
    pub fn with_compression(mut self, compression: ParquetCompression) -> Self {
        self.compression = compression.into();
        self
    }

    /// Compute and write column statistics.
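    ///
    /// A sketch, assuming `StatisticsOptions` exposes the four boolean fields shown
    /// below (check the `polars_parquet` docs for the exact shape):
    ///
    /// ```ignore
    /// let stats = StatisticsOptions {
    ///     min_value: true,
    ///     max_value: true,
    ///     null_count: true,
    ///     distinct_count: false,
    /// };
    /// let writer = ParquetWriter::new(file).with_statistics(stats);
    /// ```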
    pub fn with_statistics(mut self, statistics: StatisticsOptions) -> Self {
        self.statistics = statistics;
        self
    }

    /// Set the row group size (in number of rows) during writing. This can reduce memory
    /// pressure and improve writing performance.
    pub fn with_row_group_size(mut self, size: Option<usize>) -> Self {
        self.row_group_size = size;
        self
    }

    /// Set the maximum byte size of a data page. If `None`, defaults to 1024^2 bytes.
    pub fn with_data_page_size(mut self, limit: Option<usize>) -> Self {
        self.data_page_size = limit;
        self
    }

    /// Serialize columns in parallel.
    pub fn set_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Set custom file-level key-value metadata for the Parquet file.
    pub fn with_key_value_metadata(mut self, key_value_metadata: Option<KeyValueMetadata>) -> Self {
        self.key_value_metadata = key_value_metadata;
        self
    }

    /// Set context information for the writer.
    pub fn with_context_info(mut self, context_info: Option<PlHashMap<String, String>>) -> Self {
        self.context_info = context_info;
        self
    }

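    /// Create a [`BatchedWriter`] so the file can be written in multiple
    /// `DataFrame` batches that share one schema.
    ///
    /// A sketch of the intended flow (names are illustrative):
    ///
    /// ```ignore
    /// let mut batched = ParquetWriter::new(file).batched(df.schema())?;
    /// batched.write_batch(&df)?;
    /// let file_size = batched.finish()?;
    /// ```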
    pub fn batched(self, schema: &Schema) -> PolarsResult<BatchedWriter<W>> {
        let schema = schema_to_arrow_checked(schema, CompatLevel::newest(), "parquet")?;
        let column_options = get_column_write_options(&schema, &self.field_overwrites);
        let parquet_schema = to_parquet_schema(&schema, &column_options)?;
        let options = self.materialize_options();
        let writer = Mutex::new(FileWriter::try_new(
            self.writer,
            schema,
            options,
            &column_options,
        )?);

        Ok(BatchedWriter {
            writer,
            parquet_schema,
            column_options,
            options,
            parallel: self.parallel,
            key_value_metadata: self.key_value_metadata,
        })
    }

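    /// Collect the builder settings into the `WriteOptions` handed to the underlying writer.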
    fn materialize_options(&self) -> WriteOptions {
        WriteOptions {
            statistics: self.statistics,
            compression: self.compression,
            version: Version::V1,
            data_page_size: self.data_page_size,
        }
    }

    /// Write the given DataFrame to the writer `W`. Returns the total size of the
    /// file in bytes.
    ///
    /// The DataFrame is first chunked into row groups of `row_group_size` rows
    /// (512^2 if unset).
    pub fn finish(self, df: &mut DataFrame) -> PolarsResult<u64> {
        let chunked_df = chunk_df_for_writing(df, self.row_group_size.unwrap_or(512 * 512))?;
        let mut batched = self.batched(chunked_df.schema())?;
        batched.write_batch(&chunked_df)?;
        batched.finish()
    }
}

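/// Convert the user-facing metadata key-value pairs into the `KeyValue` type
/// that `polars_parquet` writes into the file footer.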
fn convert_metadata(md: &Option<Vec<MetadataKeyValue>>) -> Vec<KeyValue> {
    md.as_ref()
        .map(|metadata| {
            metadata
                .iter()
                .map(|kv| KeyValue {
                    key: kv.key.to_string(),
                    value: kv.value.as_ref().map(|v| v.to_string()),
                })
                .collect()
        })
        .unwrap_or_default()
}

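/// Build the `ColumnWriteOptions` tree for a single field, recursing through
/// nested list and struct types and applying any user overwrites along the way.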
fn to_column_write_options_rec(
    field: &ArrowField,
    overwrites: Option<&ParquetFieldOverwrites>,
) -> ColumnWriteOptions {
    let mut column_options = ColumnWriteOptions {
        field_id: None,
        metadata: Vec::new(),

        // Dummy value; replaced below according to the field's physical type.
        children: ChildWriteOptions::Leaf(FieldWriteOptions {
            encoding: Encoding::Plain,
        }),
    };

    if let Some(overwrites) = overwrites {
        column_options.field_id = overwrites.field_id;
        column_options.metadata = convert_metadata(&overwrites.metadata);
    }

    use arrow::datatypes::PhysicalType::*;
    match field.dtype().to_physical_type() {
        Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8
        | Dictionary(_) | LargeUtf8 | BinaryView | Utf8View => {
            column_options.children = ChildWriteOptions::Leaf(FieldWriteOptions {
                encoding: encoding_map(field.dtype()),
            });
        },
        List | FixedSizeList | LargeList => {
            let child_overwrites = overwrites.map(|o| match &o.children {
                ChildFieldOverwrites::ListLike(child_overwrites) => child_overwrites.as_ref(),
                _ => unreachable!(),
            });

            let logical_type = field.dtype().to_logical_type();
            let child = if let ArrowDataType::List(inner) = logical_type {
                to_column_write_options_rec(inner, child_overwrites)
            } else if let ArrowDataType::LargeList(inner) = logical_type {
                to_column_write_options_rec(inner, child_overwrites)
            } else if let ArrowDataType::FixedSizeList(inner, _) = logical_type {
                to_column_write_options_rec(inner, child_overwrites)
            } else {
                unreachable!()
            };

            column_options.children =
                ChildWriteOptions::ListLike(Box::new(ListLikeFieldWriteOptions { child }));
        },
        Struct => {
            if let ArrowDataType::Struct(fields) = field.dtype().to_logical_type() {
                let children_overwrites = overwrites.map(|o| match &o.children {
                    ChildFieldOverwrites::Struct(child_overwrites) => PlHashMap::from_iter(
                        child_overwrites
                            .iter()
                            .map(|f| (f.name.as_ref().unwrap(), f)),
                    ),
                    _ => unreachable!(),
                });

                let children = fields
                    .iter()
                    .map(|f| {
                        let overwrites = children_overwrites
                            .as_ref()
                            .and_then(|o| o.get(&f.name).copied());
                        to_column_write_options_rec(f, overwrites)
                    })
                    .collect();

                column_options.children =
                    ChildWriteOptions::Struct(Box::new(StructFieldWriteOptions { children }));
            } else {
                unreachable!()
            }
        },

        // Polars never produces Map or Union columns, so they cannot reach the writer.
        Map | Union => unreachable!(),
    }

    column_options
}

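/// Derive per-column write options for every field in the schema, matching
/// top-level overwrites to fields by name.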
pub fn get_column_write_options(
    schema: &ArrowSchema,
    field_overwrites: &[ParquetFieldOverwrites],
) -> Vec<ColumnWriteOptions> {
    let field_overwrites = PlHashMap::from_iter(
        field_overwrites
            .iter()
            .map(|f| (f.name.as_ref().unwrap(), f)),
    );
    schema
        .iter_values()
        .map(|f| to_column_write_options_rec(f, field_overwrites.get(&f.name).copied()))
        .collect()
}

/// Map an Arrow physical type to the Parquet encoding used for its data pages.
fn encoding_map(dtype: &ArrowDataType) -> Encoding {
    match dtype.to_physical_type() {
        // Dictionary and string/binary types typically compress well with dictionary encoding.
        PhysicalType::Dictionary(_)
        | PhysicalType::LargeBinary
        | PhysicalType::LargeUtf8
        | PhysicalType::Utf8View
        | PhysicalType::BinaryView => Encoding::RleDictionary,
        PhysicalType::Primitive(dt) => {
            use arrow::types::PrimitiveType::*;
            match dt {
                // Floats rarely repeat exactly, so dictionary encoding buys little.
                Float32 | Float64 | Float16 => Encoding::Plain,
                _ => Encoding::RleDictionary,
            }
        },
        // Everything else is written plain.
        _ => Encoding::Plain,
    }
}