1use std::io::Write;
2
3pub use Compression as AvroCompression;
4pub use arrow::io::avro::avro_schema::file::Compression;
5use arrow::io::avro::avro_schema::{self};
6use arrow::io::avro::write;
7use polars_core::error::to_compute_err;
8use polars_core::prelude::*;
9
10use crate::shared::{SerWriter, schema_to_arrow_checked};
11
12#[must_use]
32pub struct AvroWriter<W> {
33 writer: W,
34 compression: Option<AvroCompression>,
35 name: String,
36}
37
38impl<W> AvroWriter<W>
39where
40 W: Write,
41{
42 pub fn with_compression(mut self, compression: Option<AvroCompression>) -> Self {
44 self.compression = compression;
45 self
46 }
47
48 pub fn with_name(mut self, name: String) -> Self {
49 self.name = name;
50 self
51 }
52}
53
54impl<W> SerWriter<W> for AvroWriter<W>
55where
56 W: Write,
57{
58 fn new(writer: W) -> Self {
59 Self {
60 writer,
61 compression: None,
62 name: "".to_string(),
63 }
64 }
65
66 fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> {
67 let schema = schema_to_arrow_checked(df.schema(), CompatLevel::oldest(), "avro")?;
68 let record = write::to_record(&schema, self.name.clone())?;
69
70 let mut data = vec![];
71 let mut compressed_block = avro_schema::file::CompressedBlock::default();
72 for chunk in df.iter_chunks(CompatLevel::oldest(), true) {
73 let mut serializers = chunk
74 .iter()
75 .zip(record.fields.iter())
76 .map(|(array, field)| write::new_serializer(array.as_ref(), &field.schema))
77 .collect::<Vec<_>>();
78
79 let mut block =
80 avro_schema::file::Block::new(chunk.arrays()[0].len(), std::mem::take(&mut data));
81 write::serialize(&mut serializers, &mut block);
82 let _was_compressed =
83 avro_schema::write::compress(&mut block, &mut compressed_block, self.compression)
84 .map_err(to_compute_err)?;
85
86 avro_schema::write::write_metadata(&mut self.writer, record.clone(), self.compression)
87 .map_err(to_compute_err)?;
88
89 avro_schema::write::write_block(&mut self.writer, &compressed_block)
90 .map_err(to_compute_err)?;
91 data = block.data;
93 data.clear();
94
95 compressed_block.data.clear();
97 compressed_block.number_of_rows = 0
98 }
99
100 Ok(())
101 }
102}