use std::io::Write;
use std::sync::Mutex;

use arrow::datatypes::PhysicalType;
use polars_core::frame::chunk_df_for_writing;
use polars_core::prelude::*;
use polars_parquet::write::{
    ChildWriteOptions, ColumnWriteOptions, CompressionOptions, Encoding, FieldWriteOptions,
    FileWriter, KeyValue, ListLikeFieldWriteOptions, StatisticsOptions, StructFieldWriteOptions,
    Version, WriteOptions, to_parquet_schema,
};

use super::batched_writer::BatchedWriter;
use super::options::ParquetCompression;
use super::{KeyValueMetadata, MetadataKeyValue, ParquetFieldOverwrites, ParquetWriteOptions};
use crate::prelude::ChildFieldOverwrites;
use crate::shared::schema_to_arrow_checked;

impl ParquetWriteOptions {
    pub fn to_writer<F>(&self, f: F) -> ParquetWriter<F>
    where
        F: Write,
    {
        ParquetWriter::new(f)
            .with_compression(self.compression)
            .with_statistics(self.statistics)
            .with_row_group_size(self.row_group_size)
            .with_data_page_size(self.data_page_size)
            .with_key_value_metadata(self.key_value_metadata.clone())
    }
}
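
/// Write a DataFrame to Parquet format.
///
/// A minimal usage sketch (hypothetical output path; assumes the `df!` macro from
/// `polars_core` is in scope and errors are propagated with `?`):
///
/// ```ignore
/// let mut df = df!("value" => [1i32, 2, 3])?;
/// let file = std::fs::File::create("out.parquet")?;
/// let file_size = ParquetWriter::new(file)
///     .with_row_group_size(Some(100_000))
///     .finish(&mut df)?;
/// ```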
#[must_use]
pub struct ParquetWriter<W> {
    writer: W,
    /// Data page compression.
    compression: CompressionOptions,
    /// Which column statistics to compute and write.
    statistics: StatisticsOptions,
    /// Number of rows per row group; if `None`, defaults to 512^2 rows.
    row_group_size: Option<usize>,
    /// Maximum size of a data page in bytes.
    data_page_size: Option<usize>,
    /// Serialize columns in parallel.
    parallel: bool,
    /// Per-field overrides for field ids, metadata, and nested write options.
    field_overwrites: Vec<ParquetFieldOverwrites>,
    /// Custom file-level key-value metadata.
    key_value_metadata: Option<KeyValueMetadata>,
    /// Context info for the Parquet file being written.
    context_info: Option<PlHashMap<String, String>>,
}

impl<W> ParquetWriter<W>
where
    W: Write,
{
    pub fn new(writer: W) -> Self
    where
        W: Write,
    {
        ParquetWriter {
            writer,
            compression: ParquetCompression::default().into(),
            statistics: StatisticsOptions::default(),
            row_group_size: None,
            data_page_size: None,
            parallel: true,
            field_overwrites: Vec::new(),
            key_value_metadata: None,
            context_info: None,
        }
    }

    /// Set the compression used. Defaults to `ParquetCompression::default()`.
    pub fn with_compression(mut self, compression: ParquetCompression) -> Self {
        self.compression = compression.into();
        self
    }

    /// Set which column statistics to compute and write.
    pub fn with_statistics(mut self, statistics: StatisticsOptions) -> Self {
        self.statistics = statistics;
        self
    }

    /// Set the row group size (in number of rows) used during writing. This can reduce memory
    /// pressure and improve writing performance.
    pub fn with_row_group_size(mut self, size: Option<usize>) -> Self {
        self.row_group_size = size;
        self
    }

    /// Set the maximum size of a data page in bytes.
    pub fn with_data_page_size(mut self, limit: Option<usize>) -> Self {
        self.data_page_size = limit;
        self
    }

    /// Serialize columns in parallel.
    pub fn set_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Set custom file-level key-value metadata.
    pub fn with_key_value_metadata(mut self, key_value_metadata: Option<KeyValueMetadata>) -> Self {
        self.key_value_metadata = key_value_metadata;
        self
    }

    /// Set context info for the Parquet file being written.
    pub fn with_context_info(mut self, context_info: Option<PlHashMap<String, String>>) -> Self {
        self.context_info = context_info;
        self
    }
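
    /// Create a [`BatchedWriter`], which can write multiple batches to the same file.
    ///
    /// A minimal sketch (hypothetical `frames` iterator; the `write_batch` and `finish`
    /// calls mirror how `finish` on this writer uses the batched API):
    ///
    /// ```ignore
    /// let mut batched = writer.batched(&schema)?;
    /// for df in frames {
    ///     batched.write_batch(&df)?;
    /// }
    /// let file_size = batched.finish()?;
    /// ```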
    pub fn batched(self, schema: &Schema) -> PolarsResult<BatchedWriter<W>> {
        let schema = schema_to_arrow_checked(schema, CompatLevel::newest(), "parquet")?;
        let column_options = get_column_write_options(&schema, &self.field_overwrites);
        let parquet_schema = to_parquet_schema(&schema, &column_options)?;
        let options = self.materialize_options();
        let writer = Mutex::new(FileWriter::try_new(
            self.writer,
            schema,
            options,
            &column_options,
        )?);

        Ok(BatchedWriter {
            writer,
            parquet_schema,
            column_options,
            options,
            parallel: self.parallel,
            key_value_metadata: self.key_value_metadata,
        })
    }

    fn materialize_options(&self) -> WriteOptions {
        WriteOptions {
            statistics: self.statistics,
            compression: self.compression,
            version: Version::V1,
            data_page_size: self.data_page_size,
        }
    }
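
    /// Write the given `DataFrame` to the underlying writer and return the total size of the
    /// written file in bytes.
    ///
    /// If no row group size was set, the frame is chunked into row groups of 512^2 rows.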
    pub fn finish(self, df: &mut DataFrame) -> PolarsResult<u64> {
        let chunked_df = chunk_df_for_writing(df, self.row_group_size.unwrap_or(512 * 512))?;
        let mut batched = self.batched(chunked_df.schema())?;
        batched.write_batch(&chunked_df)?;
        batched.finish()
    }
}
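
/// Convert user-facing metadata key-value pairs into Parquet [`KeyValue`]s.
/// Returns an empty `Vec` if no metadata was provided.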
fn convert_metadata(md: &Option<Vec<MetadataKeyValue>>) -> Vec<KeyValue> {
    md.as_ref()
        .map(|metadata| {
            metadata
                .iter()
                .map(|kv| KeyValue {
                    key: kv.key.to_string(),
                    value: kv.value.as_ref().map(|v| v.to_string()),
                })
                .collect()
        })
        .unwrap_or_default()
}
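
/// Recursively build the [`ColumnWriteOptions`] for a single field, applying any matching
/// user-provided overwrites and descending into list and struct children.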
fn to_column_write_options_rec(
    field: &ArrowField,
    overwrites: Option<&ParquetFieldOverwrites>,
) -> ColumnWriteOptions {
    let mut column_options = ColumnWriteOptions {
        field_id: None,
        metadata: Vec::new(),

        // Placeholder; overwritten below based on the field's physical type.
        children: ChildWriteOptions::Leaf(FieldWriteOptions {
            encoding: Encoding::Plain,
        }),
    };

    if let Some(overwrites) = overwrites {
        column_options.field_id = overwrites.field_id;
        column_options.metadata = convert_metadata(&overwrites.metadata);
    }

    use arrow::datatypes::PhysicalType::*;
    match field.dtype().to_physical_type() {
        Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8
        | Dictionary(_) | LargeUtf8 | BinaryView | Utf8View => {
            column_options.children = ChildWriteOptions::Leaf(FieldWriteOptions {
                encoding: encoding_map(field.dtype()),
            });
        },
        List | FixedSizeList | LargeList => {
            let child_overwrites = overwrites.map(|o| match &o.children {
                ChildFieldOverwrites::ListLike(child_overwrites) => child_overwrites.as_ref(),
                _ => unreachable!(),
            });

            let logical_type = field.dtype().to_logical_type();
            let child = if let ArrowDataType::List(inner) = logical_type {
                to_column_write_options_rec(inner, child_overwrites)
            } else if let ArrowDataType::LargeList(inner) = logical_type {
                to_column_write_options_rec(inner, child_overwrites)
            } else if let ArrowDataType::FixedSizeList(inner, _) = logical_type {
                to_column_write_options_rec(inner, child_overwrites)
            } else {
                unreachable!()
            };

            column_options.children =
                ChildWriteOptions::ListLike(Box::new(ListLikeFieldWriteOptions { child }));
        },
        Struct => {
            if let ArrowDataType::Struct(fields) = field.dtype().to_logical_type() {
                let children_overwrites = overwrites.map(|o| match &o.children {
                    ChildFieldOverwrites::Struct(child_overwrites) => PlHashMap::from_iter(
                        child_overwrites
                            .iter()
                            .map(|f| (f.name.as_ref().unwrap(), f)),
                    ),
                    _ => unreachable!(),
                });

                let children = fields
                    .iter()
                    .map(|f| {
                        let overwrites = children_overwrites
                            .as_ref()
                            .and_then(|o| o.get(&f.name).copied());
                        to_column_write_options_rec(f, overwrites)
                    })
                    .collect();

                column_options.children =
                    ChildWriteOptions::Struct(Box::new(StructFieldWriteOptions { children }));
            } else {
                unreachable!()
            }
        },

        Map | Union => unreachable!(),
    }

    column_options
}
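
/// Build the per-column write options for every top-level field in the schema, matching
/// [`ParquetFieldOverwrites`] to fields by name.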
pub fn get_column_write_options(
    schema: &ArrowSchema,
    field_overwrites: &[ParquetFieldOverwrites],
) -> Vec<ColumnWriteOptions> {
    let field_overwrites = PlHashMap::from_iter(
        field_overwrites
            .iter()
            .map(|f| (f.name.as_ref().unwrap(), f)),
    );
    schema
        .iter_values()
        .map(|f| to_column_write_options_rec(f, field_overwrites.get(&f.name).copied()))
        .collect()
}
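
/// Map an Arrow physical type to a default Parquet encoding: dictionary, large binary/UTF-8,
/// and view types are dictionary-encoded, as are non-float primitives; floats and all
/// remaining types use plain encoding.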
fn encoding_map(dtype: &ArrowDataType) -> Encoding {
    match dtype.to_physical_type() {
        PhysicalType::Dictionary(_)
        | PhysicalType::LargeBinary
        | PhysicalType::LargeUtf8
        | PhysicalType::Utf8View
        | PhysicalType::BinaryView => Encoding::RleDictionary,
        PhysicalType::Primitive(dt) => {
            use arrow::types::PrimitiveType::*;
            match dt {
                Float32 | Float64 | Float16 => Encoding::Plain,
                _ => Encoding::RleDictionary,
            }
        },
        // Remaining types fall back to plain encoding.
        _ => Encoding::Plain,
    }
}