use std::io::Write;
use std::sync::Mutex;

use arrow::record_batch::RecordBatch;
use polars_buffer::Buffer;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_parquet::read::{ParquetError, fallible_streaming_iterator};
use polars_parquet::write::{
    CompressedPage, Compressor, DynIter, DynStreamingIterator, Encoding, FallibleStreamingIterator,
    FileWriter, Page, ParquetType, RowGroupIterColumns, SchemaDescriptor, WriteOptions,
    array_to_columns, schema_to_metadata_key,
};
use rayon::prelude::*;

use super::{KeyValueMetadata, ParquetMetadataContext};

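/// Incrementally writes `DataFrame`s to a parquet file.
///
/// Each chunk of a written `DataFrame` is encoded as its own row group; the file
/// footer is only written when [`finish`](Self::finish) is called.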
pub struct BatchedWriter<W: Write> {
    // Wrapped in a `Mutex` so row groups can also be written through a shared
    // reference (see `get_writer` and `write_row_groups`).
    pub(super) writer: Mutex<FileWriter<W>>,
    pub(super) parquet_schema: SchemaDescriptor,
    pub(super) encodings: Buffer<Vec<Encoding>>,
    pub(super) options: WriteOptions,
    pub(super) parallel: bool,
    pub(super) key_value_metadata: Option<KeyValueMetadata>,
}

impl<W: Write> BatchedWriter<W> {
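    /// Creates a batched writer around an existing `FileWriter`.
    ///
    /// Note that `parquet_schema` is initialized empty here; callers that rely on
    /// [`write_batch`](Self::write_batch) presumably populate it elsewhere.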
    pub fn new(
        writer: Mutex<FileWriter<W>>,
        encodings: Buffer<Vec<Encoding>>,
        options: WriteOptions,
        parallel: bool,
        key_value_metadata: Option<KeyValueMetadata>,
    ) -> Self {
        Self {
            writer,
            parquet_schema: SchemaDescriptor::new(PlSmallStr::EMPTY, vec![]),
            encodings,
            options,
            parallel,
            key_value_metadata,
        }
    }

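    /// Encodes and compresses every non-empty chunk of `df` into a row group,
    /// without writing anything to the underlying file yet.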
    pub fn encode_and_compress<'a>(
        &'a self,
        df: &'a DataFrame,
    ) -> impl Iterator<Item = PolarsResult<RowGroupIterColumns<'static, PolarsError>>> + 'a {
        let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
        rb_iter.filter_map(move |batch| match batch.len() {
            0 => None,
            _ => {
                let row_group = create_eager_serializer(
                    batch,
                    self.parquet_schema.fields(),
                    self.encodings.as_ref(),
                    self.options,
                );

                Some(row_group)
            },
        })
    }

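    /// Writes a batch to the parquet writer.
    ///
    /// # Panics
    /// The caller must ensure the chunks in the given [`DataFrame`] are aligned.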
    pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
        let row_group_iter = prepare_rg_iter(
            df,
            &self.parquet_schema,
            &self.encodings,
            self.options,
            self.parallel,
        );
        // Hold the lock across the loop so row groups are written in order.
        let mut writer = self.writer.lock().unwrap();
        for (num_rows, group) in row_group_iter {
            writer.write(num_rows as u64, group?)?;
        }
        Ok(())
    }

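    /// Returns the parquet schema of the underlying `FileWriter`.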
    pub fn parquet_schema(&mut self) -> &SchemaDescriptor {
        let writer = self.writer.get_mut().unwrap();
        writer.parquet_schema()
    }

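    /// Writes a single row group of already compressed pages, one `Vec` per column.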
    pub fn write_row_group(
        &mut self,
        num_rows: u64,
        rg: &[Vec<CompressedPage>],
    ) -> PolarsResult<()> {
        let writer = self.writer.get_mut().unwrap();
        let rg = DynIter::new(rg.iter().map(|col_pages| {
            Ok(DynStreamingIterator::new(
                fallible_streaming_iterator::convert(col_pages.iter().map(PolarsResult::Ok)),
            ))
        }));
        writer.write(num_rows, rg)?;
        Ok(())
    }

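    /// Returns a reference to the mutex-guarded `FileWriter` for direct access.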
    pub fn get_writer(&self) -> &Mutex<FileWriter<W>> {
        &self.writer
    }

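    /// Writes pre-encoded row groups, e.g. those produced by
    /// [`encode_and_compress`](Self::encode_and_compress).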
    pub fn write_row_groups(
        &self,
        rgs: Vec<RowGroupIterColumns<'static, PolarsError>>,
    ) -> PolarsResult<()> {
        // Lock once before looping so the row groups are written in order.
        let mut writer = self.writer.lock().unwrap();
        for group in rgs {
            writer.write(u64::MAX, group)?;
        }
        Ok(())
    }

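    /// Writes the footer (including any key-value metadata) and returns the total
    /// size of the file in bytes.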
    pub fn finish(&self) -> PolarsResult<u64> {
        let mut writer = self.writer.lock().unwrap();

        let key_value_metadata = self
            .key_value_metadata
            .as_ref()
            .map(|meta| {
                let arrow_schema = schema_to_metadata_key(writer.schema());
                let ctx = ParquetMetadataContext {
                    arrow_schema: arrow_schema.value.as_ref().unwrap(),
                };
                let mut out = meta.collect(ctx)?;
                // Keep the serialized arrow schema as the first entry unless the
                // user already provided one under the same key.
                if !out.iter().any(|kv| kv.key == arrow_schema.key) {
                    out.insert(0, arrow_schema);
                }
                PolarsResult::Ok(out)
            })
            .transpose()?;

        let size = writer.end(key_value_metadata)?;
        Ok(size)
    }
}

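/// Prepares one lazily encoded row group per non-empty chunk of `df`, together with
/// the number of rows in that chunk.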
fn prepare_rg_iter<'a>(
    df: &'a DataFrame,
    parquet_schema: &'a SchemaDescriptor,
    encodings: &'a [Vec<Encoding>],
    options: WriteOptions,
    parallel: bool,
) -> impl Iterator<
    Item = (
        usize,
        PolarsResult<RowGroupIterColumns<'static, PolarsError>>,
    ),
> + 'a {
    let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
    rb_iter.filter_map(move |batch| match batch.len() {
        0 => None,
        num_rows => {
            let row_group =
                create_serializer(batch, parquet_schema.fields(), encodings, options, parallel);

            Some((num_rows, row_group))
        },
    })
}

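/// Wraps each column's stream of encoded pages in a `Compressor` so that pages are
/// compressed lazily as they are pulled.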
fn pages_iter_to_compressor(
    encoded_columns: Vec<DynIter<'static, PolarsResult<Page>>>,
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    encoded_columns
        .into_iter()
        .map(|encoded_pages| {
            let pages = DynStreamingIterator::new(
                Compressor::new_from_vec(
                    encoded_pages.map(|result| {
                        result.map_err(|e| {
                            ParquetError::FeatureNotSupported(format!("reraised in polars: {e}"))
                        })
                    }),
                    options.compression,
                    vec![],
                )
                .map_err(PolarsError::from),
            );

            Ok(pages)
        })
        .collect::<Vec<_>>()
}

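/// Encodes a single arrow array into parquet pages (possibly multiple leaf columns
/// for nested types) and sets up compression for each of them.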
fn array_to_pages_iter(
    array: &ArrayRef,
    type_: &ParquetType,
    encoding: &[Encoding],
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    let encoded_columns = array_to_columns(array, type_.clone(), options, encoding).unwrap();
    pages_iter_to_compressor(encoded_columns, options)
}

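/// Builds a lazy row-group serializer for one record batch, encoding the columns on
/// the rayon thread pool when `parallel` is set.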
fn create_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    encodings: &[Vec<Encoding>],
    options: WriteOptions,
    parallel: bool,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec<Encoding>)| {
        array_to_pages_iter(array, type_, encoding, options)
    };

    let columns = if parallel {
        POOL.install(|| {
            batch
                .columns()
                .par_iter()
                .zip(fields)
                .zip(encodings)
                .flat_map(func)
                .collect::<Vec<_>>()
        })
    } else {
        batch
            .columns()
            .iter()
            .zip(fields)
            .zip(encodings)
            .flat_map(func)
            .collect::<Vec<_>>()
    };

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}

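/// Like `create_serializer`, but always encodes the columns sequentially on the
/// calling thread.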
fn create_eager_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    encodings: &[Vec<Encoding>],
    options: WriteOptions,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec<Encoding>)| {
        array_to_pages_iter(array, type_, encoding, options)
    };

    let columns = batch
        .columns()
        .iter()
        .zip(fields)
        .zip(encodings)
        .flat_map(func)
        .collect::<Vec<_>>();

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}