use std::io::Write;
use std::path::PathBuf;
use arrow::io::ipc::write;
use arrow::io::ipc::write::WriteOptions;
use polars_core::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::prelude::*;
use crate::shared::WriterFactory;
/// Plain, (optionally serde-serializable) options describing how IPC output should be written.
///
/// The derived `Default` yields `compression: None` and `maintain_order: false`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct IpcWriterOptions {
    /// Compression codec for the written data.
    /// NOTE(review): `None` presumably means "write uncompressed", mirroring `IpcWriter` — confirm at the consumer.
    pub compression: Option<IpcCompression>,
    /// NOTE(review): not read anywhere in this section; presumably tells a sink/streaming
    /// caller whether the input row order must be preserved — confirm against its users.
    pub maintain_order: bool,
}
/// Writes a `DataFrame` to `writer` in the Arrow IPC file format (via arrow's `write::FileWriter`).
///
/// Configure it with the `with_*` builder methods, then either call `SerWriter::finish` for a
/// one-shot write or `batched` to obtain a [`BatchedWriter`] for incremental writes.
#[must_use]
pub struct IpcWriter<W> {
    /// Destination that receives the IPC bytes.
    pub(super) writer: W,
    /// Codec handed to arrow's `WriteOptions`; `None` passes no codec.
    pub(super) compression: Option<IpcCompression>,
    /// Forwarded to `Schema::to_arrow` and `DataFrame::iter_chunks`.
    /// NOTE(review): presumably selects polars-specific Arrow types instead of the most
    /// compatible ones — confirm in polars-core.
    pub(super) pl_flavor: bool,
}
impl<W: Write> IpcWriter<W> {
pub fn with_compression(mut self, compression: Option<IpcCompression>) -> Self {
self.compression = compression;
self
}
pub fn with_pl_flavor(mut self, pl_flavor: bool) -> Self {
self.pl_flavor = pl_flavor;
self
}
pub fn batched(self, schema: &Schema) -> PolarsResult<BatchedWriter<W>> {
let mut writer = write::FileWriter::new(
self.writer,
Arc::new(schema.to_arrow(self.pl_flavor)),
None,
WriteOptions {
compression: self.compression.map(|c| c.into()),
},
);
writer.start()?;
Ok(BatchedWriter {
writer,
pl_flavor: self.pl_flavor,
})
}
}
impl<W> SerWriter<W> for IpcWriter<W>
where
    W: Write,
{
    /// Wraps `writer` with compression disabled and `pl_flavor` turned off.
    fn new(writer: W) -> Self {
        Self {
            writer,
            compression: None,
            pl_flavor: false,
        }
    }

    /// Writes all of `df` as a complete IPC file and finalizes it.
    ///
    /// `df` is mutated: its chunks are aligned (after the file writer is created) so that
    /// every emitted record batch covers the same rows for every column.
    ///
    /// # Errors
    /// Propagates errors from creating the file writer, writing any batch, or finishing.
    fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> {
        let pl_flavor = self.pl_flavor;
        let arrow_schema = Arc::new(df.schema().to_arrow(pl_flavor));
        let options = WriteOptions {
            compression: self.compression.map(Into::into),
        };

        let mut file_writer =
            write::FileWriter::try_new(&mut self.writer, arrow_schema, None, options)?;

        df.align_chunks();
        for record_batch in df.iter_chunks(pl_flavor) {
            file_writer.write(&record_batch, None)?;
        }

        file_writer.finish()?;
        Ok(())
    }
}
/// Incremental IPC file writer produced by `IpcWriter::batched`.
///
/// Call `write_batch` any number of times, then `finish` once.
pub struct BatchedWriter<W: Write> {
    /// Arrow IPC file writer; already `start`ed by `IpcWriter::batched`.
    writer: write::FileWriter<W>,
    /// Forwarded to `DataFrame::iter_chunks` for every batch.
    pl_flavor: bool,
}
impl<W: Write> BatchedWriter<W> {
    /// Appends every chunk of `df` to the open IPC file as a record batch.
    ///
    /// NOTE(review): unlike `IpcWriter::finish`, this takes `&DataFrame` and cannot call
    /// `align_chunks`; callers presumably must pass a frame whose columns have aligned
    /// chunks — confirm.
    ///
    /// # Errors
    /// Stops at, and propagates, the first batch write error.
    pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
        df.iter_chunks(self.pl_flavor)
            .try_for_each(|record_batch| self.writer.write(&record_batch, None))?;
        Ok(())
    }

    /// Finalizes the underlying arrow `FileWriter`.
    /// NOTE(review): presumably this writes the IPC footer; no more batches should follow.
    ///
    /// # Errors
    /// Propagates any error reported by the arrow writer.
    pub fn finish(&mut self) -> PolarsResult<()> {
        Ok(self.writer.finish()?)
    }
}
/// Compression codecs available for IPC output; `ZSTD` is the `Default`.
///
/// Converted into arrow's `write::Compression` when a writer is built.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IpcCompression {
    /// LZ4 compression (maps to `write::Compression::LZ4`).
    LZ4,
    /// Zstandard compression (maps to `write::Compression::ZSTD`); the default variant.
    #[default]
    ZSTD,
}
/// Maps the polars-facing codec choice onto arrow's IPC writer codec, variant for variant.
impl From<IpcCompression> for write::Compression {
    fn from(value: IpcCompression) -> Self {
        let codec = value;
        match codec {
            IpcCompression::ZSTD => Self::ZSTD,
            IpcCompression::LZ4 => Self::LZ4,
        }
    }
}
/// Factory configuration used to create IPC writers through `WriterFactory`.
///
/// Derives `Clone` and `Debug` (like the sibling `IpcWriterOptions`) so the configuration can be
/// duplicated across writers and inspected in logs and test failures.
#[derive(Clone, Debug)]
pub struct IpcWriterOption {
    /// Codec applied to every writer created by the factory (`None` = no codec).
    compression: Option<IpcCompression>,
    /// File extension reported by `WriterFactory::extension`.
    extension: PathBuf,
}
impl IpcWriterOption {
    /// Creates a factory config with no compression and the `.ipc` extension.
    pub fn new() -> Self {
        let extension = PathBuf::from(".ipc");
        Self {
            extension,
            compression: None,
        }
    }

    /// Returns this config with the given compression codec (`None` = no codec).
    pub fn with_compression(self, compression: Option<IpcCompression>) -> Self {
        Self { compression, ..self }
    }

    /// Returns this config with the given file extension.
    pub fn with_extension(self, extension: PathBuf) -> Self {
        Self { extension, ..self }
    }
}
impl Default for IpcWriterOption {
fn default() -> Self {
Self::new()
}
}
impl WriterFactory for IpcWriterOption {
    /// Builds a boxed `IpcWriter` around `writer` using this config's compression.
    /// `pl_flavor` is left at the `IpcWriter::new` default.
    fn create_writer<W: Write + 'static>(&self, writer: W) -> Box<dyn SerWriter<W>> {
        let ipc_writer = IpcWriter::new(writer).with_compression(self.compression);
        Box::new(ipc_writer)
    }

    /// Returns an owned copy of the configured file extension.
    fn extension(&self) -> PathBuf {
        self.extension.clone()
    }
}