polars_io/parquet/write/writer.rs

use std::io::Write;
use std::sync::Mutex;

use arrow::datatypes::PhysicalType;
use polars_core::frame::chunk_df_for_writing;
use polars_core::prelude::*;
use polars_parquet::write::{
    ChildWriteOptions, ColumnWriteOptions, CompressionOptions, Encoding, FieldWriteOptions,
    FileWriter, KeyValue, ListLikeFieldWriteOptions, StatisticsOptions, StructFieldWriteOptions,
    Version, WriteOptions, to_parquet_schema,
};

use super::batched_writer::BatchedWriter;
use super::options::ParquetCompression;
use super::{KeyValueMetadata, MetadataKeyValue, ParquetFieldOverwrites, ParquetWriteOptions};
use crate::prelude::ChildFieldOverwrites;
use crate::shared::schema_to_arrow_checked;

impl ParquetWriteOptions {
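    /// Materialize these options into a [`ParquetWriter`] that wraps `f`.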
    pub fn to_writer<F>(&self, f: F) -> ParquetWriter<F>
    where
        F: Write,
    {
        ParquetWriter::new(f)
            .with_compression(self.compression)
            .with_statistics(self.statistics)
            .with_row_group_size(self.row_group_size)
            .with_data_page_size(self.data_page_size)
            .with_key_value_metadata(self.key_value_metadata.clone())
    }
}

/// Write a DataFrame to Parquet format.
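///
/// A minimal usage sketch; it assumes the `polars_io` prelude re-exports this
/// writer and that an `example.parquet` path is writable:
///
/// ```no_run
/// use std::fs::File;
///
/// use polars_core::prelude::*;
/// use polars_io::prelude::*;
///
/// fn example(df: &mut DataFrame) -> PolarsResult<u64> {
///     let file = File::create("example.parquet")?;
///     ParquetWriter::new(file)
///         .with_row_group_size(Some(100_000))
///         .finish(df)
/// }
/// ```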
#[must_use]
pub struct ParquetWriter<W> {
    writer: W,
    /// Data page compression.
    compression: CompressionOptions,
    /// Compute and write column statistics.
    statistics: StatisticsOptions,
    /// Number of rows per row group; if `None`, defaults to 512^2 rows.
    row_group_size: Option<usize>,
    /// Maximum data page size; if `None`, defaults to 1024^2 bytes.
    data_page_size: Option<usize>,
    /// Serialize columns in parallel.
    parallel: bool,
    field_overwrites: Vec<ParquetFieldOverwrites>,
    /// Custom file-level key-value metadata.
    key_value_metadata: Option<KeyValueMetadata>,
    /// Context info for the Parquet file being written.
    context_info: Option<PlHashMap<String, String>>,
}

impl<W> ParquetWriter<W>
where
    W: Write,
{
    /// Create a new writer.
    pub fn new(writer: W) -> Self
    where
        W: Write,
    {
        ParquetWriter {
            writer,
            compression: ParquetCompression::default().into(),
            statistics: StatisticsOptions::default(),
            row_group_size: None,
            data_page_size: None,
            parallel: true,
            field_overwrites: Vec::new(),
            key_value_metadata: None,
            context_info: None,
        }
    }

    /// Set the compression used. Defaults to `Zstd`.
    ///
    /// The default compression `Zstd` has very good performance, but may not yet be supported
    /// by older readers. If you want more compatibility guarantees, consider using `Snappy`.
    pub fn with_compression(mut self, compression: ParquetCompression) -> Self {
        self.compression = compression.into();
        self
    }

    /// Compute and write statistics.
    pub fn with_statistics(mut self, statistics: StatisticsOptions) -> Self {
        self.statistics = statistics;
        self
    }

    /// Set the row group size (in number of rows) during writing. This can reduce memory pressure
    /// and improve writing performance.
    pub fn with_row_group_size(mut self, size: Option<usize>) -> Self {
        self.row_group_size = size;
        self
    }

    /// Set the maximum size (in bytes) of a data page. If `None`, defaults to 1024^2 bytes.
    pub fn with_data_page_size(mut self, limit: Option<usize>) -> Self {
        self.data_page_size = limit;
        self
    }

    /// Serialize columns in parallel.
    pub fn set_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Set custom file-level key-value metadata for the Parquet file.
    pub fn with_key_value_metadata(mut self, key_value_metadata: Option<KeyValueMetadata>) -> Self {
        self.key_value_metadata = key_value_metadata;
        self
    }

    /// Set context information for the writer.
    pub fn with_context_info(mut self, context_info: Option<PlHashMap<String, String>>) -> Self {
        self.context_info = context_info;
        self
    }

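    /// Start a batched writer that can write multiple record batches to the same file.
    ///
    /// A minimal sketch; it assumes `df` is already chunked appropriately (see
    /// [`Self::finish`], which re-chunks to the row group size first):
    ///
    /// ```no_run
    /// # use polars_core::prelude::*;
    /// # use polars_io::prelude::*;
    /// # fn example(df: &DataFrame, file: std::fs::File) -> PolarsResult<u64> {
    /// let mut batched = ParquetWriter::new(file).batched(df.schema())?;
    /// batched.write_batch(df)?;
    /// batched.finish()
    /// # }
    /// ```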
    pub fn batched(self, schema: &Schema) -> PolarsResult<BatchedWriter<W>> {
        let schema = schema_to_arrow_checked(schema, CompatLevel::newest(), "parquet")?;
        let column_options = get_column_write_options(&schema, &self.field_overwrites);
        let parquet_schema = to_parquet_schema(&schema, &column_options)?;
        let options = self.materialize_options();
        let writer = Mutex::new(FileWriter::try_new(
            self.writer,
            schema,
            options,
            &column_options,
        )?);

        Ok(BatchedWriter {
            writer,
            parquet_schema,
            column_options,
            options,
            parallel: self.parallel,
            key_value_metadata: self.key_value_metadata,
        })
    }

    fn materialize_options(&self) -> WriteOptions {
        WriteOptions {
            statistics: self.statistics,
            compression: self.compression,
            version: Version::V1,
            data_page_size: self.data_page_size,
        }
    }

    /// Write the given DataFrame to the writer `W`.
    /// Returns the total size of the file.
    pub fn finish(self, df: &mut DataFrame) -> PolarsResult<u64> {
        let chunked_df = chunk_df_for_writing(df, self.row_group_size.unwrap_or(512 * 512))?;
        let mut batched = self.batched(chunked_df.schema())?;
        batched.write_batch(&chunked_df)?;
        batched.finish()
    }
}

fn convert_metadata(md: &Option<Vec<MetadataKeyValue>>) -> Vec<KeyValue> {
    md.as_ref()
        .map(|metadata| {
            metadata
                .iter()
                .map(|kv| KeyValue {
                    key: kv.key.to_string(),
                    value: kv.value.as_ref().map(|v| v.to_string()),
                })
                .collect()
        })
        .unwrap_or_default()
}

fn to_column_write_options_rec(
    field: &ArrowField,
    overwrites: Option<&ParquetFieldOverwrites>,
) -> ColumnWriteOptions {
    let mut column_options = ColumnWriteOptions {
        field_id: None,
        metadata: Vec::new(),
        required: None,

        // Dummy value.
        children: ChildWriteOptions::Leaf(FieldWriteOptions {
            encoding: Encoding::Plain,
        }),
    };

    if let Some(overwrites) = overwrites {
        column_options.field_id = overwrites.field_id;
        column_options.metadata = convert_metadata(&overwrites.metadata);
        column_options.required = overwrites.required;
    }

    use arrow::datatypes::PhysicalType::*;
    match field.dtype().to_physical_type() {
        Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8
        | Dictionary(_) | LargeUtf8 | BinaryView | Utf8View => {
            column_options.children = ChildWriteOptions::Leaf(FieldWriteOptions {
                encoding: encoding_map(field.dtype()),
            });
        },
        List | FixedSizeList | LargeList => {
            let child_overwrites = overwrites.and_then(|o| match &o.children {
                ChildFieldOverwrites::None => None,
                ChildFieldOverwrites::ListLike(child_overwrites) => Some(child_overwrites.as_ref()),
                _ => unreachable!(),
            });

            let logical_type = field.dtype().to_logical_type();
            let child = if let ArrowDataType::List(inner) = logical_type {
                to_column_write_options_rec(inner, child_overwrites)
            } else if let ArrowDataType::LargeList(inner) = logical_type {
                to_column_write_options_rec(inner, child_overwrites)
            } else if let ArrowDataType::FixedSizeList(inner, _) = logical_type {
                to_column_write_options_rec(inner, child_overwrites)
            } else {
                unreachable!()
            };

            column_options.children =
                ChildWriteOptions::ListLike(Box::new(ListLikeFieldWriteOptions { child }));
        },
        Struct => {
            if let ArrowDataType::Struct(fields) = field.dtype().to_logical_type() {
                let children_overwrites = overwrites.and_then(|o| match &o.children {
                    ChildFieldOverwrites::None => None,
                    ChildFieldOverwrites::Struct(child_overwrites) => Some(PlHashMap::from_iter(
                        child_overwrites
                            .iter()
                            .map(|f| (f.name.as_ref().unwrap(), f)),
                    )),
                    _ => unreachable!(),
                });

                let children = fields
                    .iter()
                    .map(|f| {
                        let overwrites = children_overwrites
                            .as_ref()
                            .and_then(|o| o.get(&f.name).copied());
                        to_column_write_options_rec(f, overwrites)
                    })
                    .collect();

                column_options.children =
                    ChildWriteOptions::Struct(Box::new(StructFieldWriteOptions { children }));
            } else {
                unreachable!()
            }
        },

        Map | Union => unreachable!(),
    }

    column_options
}

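/// Build per-column write options for every field in `schema`, applying any
/// [`ParquetFieldOverwrites`] whose name matches a top-level field.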
pub fn get_column_write_options(
    schema: &ArrowSchema,
    field_overwrites: &[ParquetFieldOverwrites],
) -> Vec<ColumnWriteOptions> {
    let field_overwrites = PlHashMap::from_iter(
        field_overwrites
            .iter()
            .map(|f| (f.name.as_ref().unwrap(), f)),
    );
    schema
        .iter_values()
        .map(|f| to_column_write_options_rec(f, field_overwrites.get(&f.name).copied()))
        .collect()
}

/// Map a physical type to the Parquet encoding used for its data pages.
fn encoding_map(dtype: &ArrowDataType) -> Encoding {
    match dtype.to_physical_type() {
        PhysicalType::Dictionary(_)
        | PhysicalType::LargeBinary
        | PhysicalType::LargeUtf8
        | PhysicalType::Utf8View
        | PhysicalType::BinaryView => Encoding::RleDictionary,
        PhysicalType::Primitive(dt) => {
            use arrow::types::PrimitiveType::*;
            match dt {
                Float32 | Float64 | Float16 => Encoding::Plain,
                _ => Encoding::RleDictionary,
            }
        },
        // The remaining types are written plain.
        _ => Encoding::Plain,
    }
}