use std::io::Write;
use std::sync::Mutex;

use arrow::datatypes::PhysicalType;
use polars_core::prelude::*;
use polars_parquet::write::{
    to_parquet_schema, transverse, CompressionOptions, Encoding, FileWriter, StatisticsOptions,
    Version, WriteOptions,
};

use super::batched_writer::BatchedWriter;
use super::options::ParquetCompression;
use crate::prelude::chunk_df_for_writing;
/// Write a `DataFrame` to Parquet format.
#[must_use]
pub struct ParquetWriter<W> {
    writer: W,
    /// Data page compression.
    compression: CompressionOptions,
    /// Which column statistics to compute and write.
    statistics: StatisticsOptions,
    /// If `None`, defaults to 512^2 rows per row group.
    row_group_size: Option<usize>,
    /// If `None`, a default of ~1 MB is used.
    data_page_size: Option<usize>,
    /// Serialize columns in parallel.
    parallel: bool,
}

impl<W> ParquetWriter<W>
where
W: Write,
{
    /// Create a new writer. The `W: Write` bound on the impl block already covers
    /// this method, so no extra `where` clause is needed.
    pub fn new(writer: W) -> Self {
ParquetWriter {
writer,
compression: ParquetCompression::default().into(),
statistics: StatisticsOptions::default(),
row_group_size: None,
data_page_size: None,
parallel: true,
}
}
    /// Set the compression used. Defaults to `Zstd`.
    ///
    /// The default `Zstd` has very good performance, but may not yet be supported by older
    /// readers. If you want more compatibility guarantees, consider using `Snappy`.
    pub fn with_compression(mut self, compression: ParquetCompression) -> Self {
self.compression = compression.into();
self
}
    /// Configure which column statistics are computed and written to the file.
    pub fn with_statistics(mut self, statistics: StatisticsOptions) -> Self {
self.statistics = statistics;
self
}
    /// Set the row group size (in number of rows) during writing. This can reduce memory
    /// pressure and improve writing performance.
    pub fn with_row_group_size(mut self, size: Option<usize>) -> Self {
self.row_group_size = size;
self
}
    /// Set the maximum size in bytes of a data page. If `None`, a default of ~1 MB is used.
    pub fn with_data_page_size(mut self, limit: Option<usize>) -> Self {
self.data_page_size = limit;
self
}
    /// Serialize columns in parallel.
    pub fn set_parallel(mut self, parallel: bool) -> Self {
self.parallel = parallel;
self
}
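
    /// Create a [`BatchedWriter`] that can write multiple `DataFrame`s to the same file.
    ///
    /// A minimal sketch of the batched path; `df_a` and `df_b` are hypothetical frames
    /// sharing one schema, so the example is marked `ignore`:
    ///
    /// ```ignore
    /// let file = std::fs::File::create("out.parquet")?;
    /// let mut batched = ParquetWriter::new(file).batched(&df_a.schema())?;
    /// batched.write_batch(&df_a)?;
    /// batched.write_batch(&df_b)?;
    /// batched.finish()?;
    /// ```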
pub fn batched(self, schema: &Schema) -> PolarsResult<BatchedWriter<W>> {
        // Convert the Polars schema to Arrow, then derive the matching Parquet
        // schema and per-column encodings from it.
        let fields = schema.to_arrow(true).fields;
        let schema = ArrowSchema::from(fields);
        let parquet_schema = to_parquet_schema(&schema)?;
        let encodings = get_encodings(&schema);
        let options = self.materialize_options();
        // The sink lives behind a `Mutex` so columns can be serialized in parallel
        // while access to the underlying writer stays synchronized.
        let writer = Mutex::new(FileWriter::try_new(self.writer, schema, options)?);
Ok(BatchedWriter {
writer,
parquet_schema,
encodings,
options,
parallel: self.parallel,
})
}

    /// Collect the builder settings into the `WriteOptions` handed to the parquet writer.
    fn materialize_options(&self) -> WriteOptions {
WriteOptions {
statistics: self.statistics,
compression: self.compression,
version: Version::V1,
data_pagesize_limit: self.data_page_size,
}
}
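
    /// Write the given `DataFrame` in one shot and return the total file size in bytes.
    ///
    /// A minimal sketch of this path; `df` and the output path are hypothetical, so the
    /// example is marked `ignore`:
    ///
    /// ```ignore
    /// let file = std::fs::File::create("out.parquet")?;
    /// let size_in_bytes = ParquetWriter::new(file)
    ///     .with_row_group_size(Some(100_000))
    ///     .finish(&mut df)?;
    /// ```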
pub fn finish(self, df: &mut DataFrame) -> PolarsResult<u64> {
        // Rechunk so chunk boundaries line up with the row group size
        // (default: 512 * 512 = 262,144 rows per row group).
        let chunked_df = chunk_df_for_writing(df, self.row_group_size.unwrap_or(512 * 512))?;
let mut batched = self.batched(&chunked_df.schema())?;
batched.write_batch(&chunked_df)?;
batched.finish()
}
}

/// Compute the encodings for every field in the schema; `transverse` yields one
/// encoding per leaf column, so nested fields map to more than one entry.
fn get_encodings(schema: &ArrowSchema) -> Vec<Vec<Encoding>> {
schema
.fields
.iter()
.map(|f| transverse(&f.data_type, encoding_map))
.collect()
}

/// Map a physical type to an encoding: RLE dictionary encoding for string-, binary- and
/// dictionary-like columns and for non-float primitives; plain encoding otherwise.
fn encoding_map(data_type: &ArrowDataType) -> Encoding {
match data_type.to_physical_type() {
PhysicalType::Dictionary(_)
| PhysicalType::LargeBinary
| PhysicalType::LargeUtf8
| PhysicalType::Utf8View
| PhysicalType::BinaryView => Encoding::RleDictionary,
PhysicalType::Primitive(dt) => {
use arrow::types::PrimitiveType::*;
match dt {
                // Floating-point data rarely repeats, so a dictionary seldom pays off.
                Float32 | Float64 | Float16 => Encoding::Plain,
_ => Encoding::RleDictionary,
}
},
_ => Encoding::Plain,
}
}
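
// A small usage sketch kept as a unit test (the frame contents and settings below are
// illustrative assumptions, not fixtures from the real test suite): write a `DataFrame`
// to an in-memory buffer via the one-shot `finish` path.
#[cfg(test)]
mod tests {
    use polars_core::df;

    use super::*;

    #[test]
    fn write_df_to_in_memory_buffer() -> PolarsResult<()> {
        let mut df = df!(
            "ints" => [1i64, 2, 3],
            "strs" => ["a", "b", "c"],
        )?;
        // `Vec<u8>` implements `Write`, so it can stand in for a file.
        let mut buffer: Vec<u8> = Vec::new();
        let size_in_bytes = ParquetWriter::new(&mut buffer)
            .with_row_group_size(Some(2))
            .finish(&mut df)?;
        // `finish` reports the total file size; the buffer now holds the Parquet bytes.
        assert!(size_in_bytes > 0);
        assert!(!buffer.is_empty());
        Ok(())
    }
}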