use std::io::Write;
use std::sync::Mutex;

use arrow::record_batch::RecordBatch;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_parquet::read::{ParquetError, fallible_streaming_iterator};
use polars_parquet::write::{
    CompressedPage, Compressor, DynIter, DynStreamingIterator, Encoding, FallibleStreamingIterator,
    FileWriter, Page, ParquetType, RowGroupIterColumns, SchemaDescriptor, WriteOptions,
    array_to_columns,
};
use rayon::prelude::*;

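/// Batched Parquet writer: serializes `DataFrame` chunks into row groups and
/// writes them through a shared `FileWriter`.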
pub struct BatchedWriter<W: Write> {
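    // Behind a `Mutex` so the `&self` methods (`write_row_groups`, `finish`)
    // can write through a shared reference.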
    pub(super) writer: Mutex<FileWriter<W>>,
    pub(super) parquet_schema: SchemaDescriptor,
    pub(super) encodings: Vec<Vec<Encoding>>,
    pub(super) options: WriteOptions,
    pub(super) parallel: bool,
}

impl<W: Write> BatchedWriter<W> {
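    /// Creates a batched writer around an already-constructed `FileWriter`.
    /// The `parquet_schema` field starts out as an empty placeholder.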
    pub fn new(
        writer: Mutex<FileWriter<W>>,
        encodings: Vec<Vec<Encoding>>,
        options: WriteOptions,
        parallel: bool,
    ) -> Self {
        Self {
            writer,
            parquet_schema: SchemaDescriptor::new(PlSmallStr::EMPTY, vec![]),
            encodings,
            options,
            parallel,
        }
    }

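    /// Builds, for every non-empty chunk of `df`, a row group of column
    /// serializers that encode and compress pages as they are iterated.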
    pub fn encode_and_compress<'a>(
        &'a self,
        df: &'a DataFrame,
    ) -> impl Iterator<Item = PolarsResult<RowGroupIterColumns<'static, PolarsError>>> + 'a {
        let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
        rb_iter.filter_map(move |batch| match batch.len() {
            0 => None,
            _ => {
                let row_group = create_eager_serializer(
                    batch,
                    self.parquet_schema.fields(),
                    self.encodings.as_ref(),
                    self.options,
                );

                Some(row_group)
            },
        })
    }

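    /// Write a batch to the parquet writer.
    ///
    /// # Panics
    /// The caller must ensure the chunks in the given [`DataFrame`] are aligned.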
    pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
        let row_group_iter = prepare_rg_iter(
            df,
            &self.parquet_schema,
            &self.encodings,
            self.options,
            self.parallel,
        );
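        // Lock once for the whole loop so the row groups are written in order.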
        let mut writer = self.writer.lock().unwrap();
        for group in row_group_iter {
            writer.write(group?)?;
        }
        Ok(())
    }

    pub fn parquet_schema(&mut self) -> &SchemaDescriptor {
        let writer = self.writer.get_mut().unwrap();
        writer.parquet_schema()
    }

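    /// Writes a single row group from already-compressed pages,
    /// one `Vec<CompressedPage>` per column.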
    pub fn write_row_group(&mut self, rg: &[Vec<CompressedPage>]) -> PolarsResult<()> {
        let writer = self.writer.get_mut().unwrap();
        let rg = DynIter::new(rg.iter().map(|col_pages| {
            Ok(DynStreamingIterator::new(
                fallible_streaming_iterator::convert(col_pages.iter().map(PolarsResult::Ok)),
            ))
        }));
        writer.write(rg)?;
        Ok(())
    }

    pub fn get_writer(&self) -> &Mutex<FileWriter<W>> {
        &self.writer
    }

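    /// Writes pre-built row groups, e.g. those produced by
    /// [`Self::encode_and_compress`].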
    pub fn write_row_groups(
        &self,
        rgs: Vec<RowGroupIterColumns<'static, PolarsError>>,
    ) -> PolarsResult<()> {
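        // Lock once for the whole loop so the row groups are written in order.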
        let mut writer = self.writer.lock().unwrap();
        for group in rgs {
            writer.write(group)?;
        }
        Ok(())
    }

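    /// Writes the parquet footer and returns the total file size in bytes.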
    pub fn finish(&self) -> PolarsResult<u64> {
        let mut writer = self.writer.lock().unwrap();
        let size = writer.end(None)?;
        Ok(size)
    }
}

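/// Builds one row-group serializer per non-empty chunk of `df`;
/// empty chunks are skipped.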
fn prepare_rg_iter<'a>(
    df: &'a DataFrame,
    parquet_schema: &'a SchemaDescriptor,
    encodings: &'a [Vec<Encoding>],
    options: WriteOptions,
    parallel: bool,
) -> impl Iterator<Item = PolarsResult<RowGroupIterColumns<'static, PolarsError>>> + 'a {
    let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
    rb_iter.filter_map(move |batch| match batch.len() {
        0 => None,
        _ => {
            let row_group =
                create_serializer(batch, parquet_schema.fields(), encodings, options, parallel);

            Some(row_group)
        },
    })
}

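/// Wraps each column's encoded-page iterator in a `Compressor` so that pages
/// are compressed as they are pulled.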
fn pages_iter_to_compressor(
    encoded_columns: Vec<DynIter<'static, PolarsResult<Page>>>,
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    encoded_columns
        .into_iter()
        .map(|encoded_pages| {
            let pages = DynStreamingIterator::new(
                Compressor::new_from_vec(
                    encoded_pages.map(|result| {
                        result.map_err(|e| {
                            ParquetError::FeatureNotSupported(format!("reraised in polars: {e}"))
                        })
                    }),
                    options.compression,
                    vec![],
                )
                .map_err(PolarsError::from),
            );

            Ok(pages)
        })
        .collect::<Vec<_>>()
}

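/// Encodes a single array into page iterators and wires them into compressors.
/// Note: panics if encoding fails (`unwrap` on `array_to_columns`).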
fn array_to_pages_iter(
    array: &ArrayRef,
    type_: &ParquetType,
    encoding: &[Encoding],
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    let encoded_columns = array_to_columns(array, type_.clone(), options, encoding).unwrap();
    pages_iter_to_compressor(encoded_columns, options)
}

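/// Turns a `RecordBatch` into a row group of column serializers, optionally
/// encoding the columns in parallel on the rayon thread pool.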
fn create_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    encodings: &[Vec<Encoding>],
    options: WriteOptions,
    parallel: bool,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec<Encoding>)| {
        array_to_pages_iter(array, type_, encoding, options)
    };

    let columns = if parallel {
        POOL.install(|| {
            batch
                .columns()
                .par_iter()
                .zip(fields)
                .zip(encodings)
                .flat_map(func)
                .collect::<Vec<_>>()
        })
    } else {
        batch
            .columns()
            .iter()
            .zip(fields)
            .zip(encodings)
            .flat_map(func)
            .collect::<Vec<_>>()
    };

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}

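/// Single-threaded variant of [`create_serializer`], used by
/// [`BatchedWriter::encode_and_compress`] to separate compute from IO.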
fn create_eager_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    encodings: &[Vec<Encoding>],
    options: WriteOptions,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec<Encoding>)| {
        array_to_pages_iter(array, type_, encoding, options)
    };

    let columns = batch
        .columns()
        .iter()
        .zip(fields)
        .zip(encodings)
        .flat_map(func)
        .collect::<Vec<_>>();

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}