use std::io::Write;
use std::sync::Mutex;

use arrow::record_batch::RecordBatch;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_parquet::read::{ParquetError, fallible_streaming_iterator};
use polars_parquet::write::{
    ColumnWriteOptions, CompressedPage, Compressor, DynIter, DynStreamingIterator,
    FallibleStreamingIterator, FileWriter, Page, ParquetType, RowGroupIterColumns,
    SchemaDescriptor, WriteOptions, array_to_columns, schema_to_metadata_key,
};
use rayon::prelude::*;

use super::{KeyValueMetadata, ParquetMetadataContext};

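/// Writes a parquet file incrementally: each `DataFrame` passed to
/// [`BatchedWriter::write_batch`] is encoded into one row group per non-empty chunk.
/// The inner [`FileWriter`] is kept behind a `Mutex` so methods that take `&self`
/// (`write_row_groups`, `finish`, ...) can still reach it.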
pub struct BatchedWriter<W: Write> {
    pub(super) writer: Mutex<FileWriter<W>>,
    pub(super) parquet_schema: SchemaDescriptor,
    pub(super) column_options: Vec<ColumnWriteOptions>,
    pub(super) options: WriteOptions,
    pub(super) parallel: bool,
    pub(super) key_value_metadata: Option<KeyValueMetadata>,
}

impl<W: Write> BatchedWriter<W> {
    pub fn new(
        writer: Mutex<FileWriter<W>>,
        column_options: Vec<ColumnWriteOptions>,
        options: WriteOptions,
        parallel: bool,
        key_value_metadata: Option<KeyValueMetadata>,
    ) -> Self {
        Self {
            writer,
            parquet_schema: SchemaDescriptor::new(PlSmallStr::EMPTY, vec![]),
            column_options,
            options,
            parallel,
            key_value_metadata,
        }
    }

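    /// Returns an iterator that encodes and compresses the non-empty chunks of `df`
    /// into row groups without writing them; the produced row groups can later be
    /// flushed with [`Self::write_row_groups`].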
    pub fn encode_and_compress<'a>(
        &'a self,
        df: &'a DataFrame,
    ) -> impl Iterator<Item = PolarsResult<RowGroupIterColumns<'static, PolarsError>>> + 'a {
        let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
        rb_iter.filter_map(move |batch| match batch.len() {
            0 => None,
            _ => {
                let row_group = create_eager_serializer(
                    batch,
                    self.parquet_schema.fields(),
                    self.column_options.as_ref(),
                    self.options,
                );

                Some(row_group)
            },
        })
    }

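    /// Encode the chunks of `df` into row groups and write them to the underlying
    /// file. Encoding runs on the rayon thread pool when `parallel` is set.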
    pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
        let row_group_iter = prepare_rg_iter(
            df,
            &self.parquet_schema,
            &self.column_options,
            self.options,
            self.parallel,
        );
        let mut writer = self.writer.lock().unwrap();
        for group in row_group_iter {
            writer.write(group?)?;
        }
        Ok(())
    }

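    /// Returns the parquet schema of the underlying [`FileWriter`].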
    pub fn parquet_schema(&mut self) -> &SchemaDescriptor {
        let writer = self.writer.get_mut().unwrap();
        writer.parquet_schema()
    }

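    /// Write a single row group from columns of already compressed pages.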
    pub fn write_row_group(&mut self, rg: &[Vec<CompressedPage>]) -> PolarsResult<()> {
        let writer = self.writer.get_mut().unwrap();
        let rg = DynIter::new(rg.iter().map(|col_pages| {
            Ok(DynStreamingIterator::new(
                fallible_streaming_iterator::convert(col_pages.iter().map(PolarsResult::Ok)),
            ))
        }));
        writer.write(rg)?;
        Ok(())
    }

    pub fn get_writer(&self) -> &Mutex<FileWriter<W>> {
        &self.writer
    }

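    /// Write pre-encoded row groups, e.g. those produced by
    /// [`Self::encode_and_compress`].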
    pub fn write_row_groups(
        &self,
        rgs: Vec<RowGroupIterColumns<'static, PolarsError>>,
    ) -> PolarsResult<()> {
        let mut writer = self.writer.lock().unwrap();
        for group in rgs {
            writer.write(group)?;
        }
        Ok(())
    }

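    /// Write the parquet footer and return the total file size in bytes. If key-value
    /// metadata was supplied, the serialized Arrow schema is prepended to it unless an
    /// entry with the same key is already present.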
    pub fn finish(&self) -> PolarsResult<u64> {
        let mut writer = self.writer.lock().unwrap();

        let key_value_metadata = self
            .key_value_metadata
            .as_ref()
            .map(|meta| {
                let arrow_schema = schema_to_metadata_key(writer.schema(), &self.column_options);
                let ctx = ParquetMetadataContext {
                    arrow_schema: arrow_schema.value.as_ref().unwrap(),
                };
                let mut out = meta.collect(ctx)?;
                if !out.iter().any(|kv| kv.key == arrow_schema.key) {
                    out.insert(0, arrow_schema);
                }
                PolarsResult::Ok(out)
            })
            .transpose()?;

        let size = writer.end(key_value_metadata, &self.column_options)?;
        Ok(size)
    }
}

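/// Lazily turn the non-empty chunks of `df` into row-group serializers.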
fn prepare_rg_iter<'a>(
    df: &'a DataFrame,
    parquet_schema: &'a SchemaDescriptor,
    column_options: &'a [ColumnWriteOptions],
    options: WriteOptions,
    parallel: bool,
) -> impl Iterator<Item = PolarsResult<RowGroupIterColumns<'static, PolarsError>>> + 'a {
    let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
    rb_iter.filter_map(move |batch| match batch.len() {
        0 => None,
        _ => {
            let row_group = create_serializer(
                batch,
                parquet_schema.fields(),
                column_options,
                options,
                parallel,
            );

            Some(row_group)
        },
    })
}

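/// Wrap each column's encoded page stream in a [`Compressor`] configured with the
/// compression from `options`.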
fn pages_iter_to_compressor(
    encoded_columns: Vec<DynIter<'static, PolarsResult<Page>>>,
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    encoded_columns
        .into_iter()
        .map(|encoded_pages| {
            let pages = DynStreamingIterator::new(
                Compressor::new_from_vec(
                    encoded_pages.map(|result| {
                        result.map_err(|e| {
                            ParquetError::FeatureNotSupported(format!("reraised in polars: {e}"))
                        })
                    }),
                    options.compression,
                    vec![],
                )
                .map_err(PolarsError::from),
            );

            Ok(pages)
        })
        .collect::<Vec<_>>()
}

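/// Encode a single arrow array into its parquet leaf columns (nested types expand
/// into several) and hand the resulting page streams to the compressor.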
fn array_to_pages_iter(
    array: &ArrayRef,
    type_: &ParquetType,
    column_options: &ColumnWriteOptions,
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    let encoded_columns = array_to_columns(array, type_.clone(), column_options, options).unwrap();
    pages_iter_to_compressor(encoded_columns, options)
}

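/// Build the row-group serializer for one [`RecordBatch`], encoding columns on the
/// rayon thread pool when `parallel` is set.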
fn create_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    column_options: &[ColumnWriteOptions],
    options: WriteOptions,
    parallel: bool,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), column_options): (
        (&ArrayRef, &ParquetType),
        &ColumnWriteOptions,
    )| { array_to_pages_iter(array, type_, column_options, options) };

    let columns = if parallel {
        POOL.install(|| {
            batch
                .columns()
                .par_iter()
                .zip(fields)
                .zip(column_options)
                .flat_map(func)
                .collect::<Vec<_>>()
        })
    } else {
        batch
            .columns()
            .iter()
            .zip(fields)
            .zip(column_options)
            .flat_map(func)
            .collect::<Vec<_>>()
    };

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}

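/// Like [`create_serializer`], but always encodes the columns on the current thread.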
fn create_eager_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    column_options: &[ColumnWriteOptions],
    options: WriteOptions,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), column_options): (
        (&ArrayRef, &ParquetType),
        &ColumnWriteOptions,
    )| { array_to_pages_iter(array, type_, column_options, options) };

    let columns = batch
        .columns()
        .iter()
        .zip(fields)
        .zip(column_options)
        .flat_map(func)
        .collect::<Vec<_>>();

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}