1use std::io::Write;
2
3use arrow::datatypes::Metadata;
4use arrow::io::ipc::IpcField;
5use arrow::io::ipc::write::{self, EncodedData, WriteOptions};
6use polars_core::prelude::*;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
10use crate::prelude::*;
11use crate::shared::schema_to_arrow_checked;
12
/// Options controlling how a DataFrame is written to the Arrow IPC file format.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct IpcWriterOptions {
    /// Compression codec applied to the IPC buffers; `None` writes uncompressed.
    pub compression: Option<IpcCompression>,
    /// Compatibility level used when converting the schema/chunks to Arrow.
    pub compat_level: CompatLevel,
    /// Target size for written record batches; `None` leaves batching to the writer.
    pub record_batch_size: Option<usize>,
    /// Whether to emit per-record-batch statistics.
    // NOTE(review): semantics not visible in this file — forwarded verbatim to
    // `IpcWriter::with_record_batch_statistics`; confirm against the writer impl.
    pub record_batch_statistics: bool,
    /// Row count per chunk (default `1 << 18`).
    // NOTE(review): not consumed anywhere in this file — presumably used by a
    // streaming/sink path elsewhere; `to_writer` has no builder to forward it.
    pub chunk_size: IdxSize,
}
28
29impl Default for IpcWriterOptions {
30 fn default() -> Self {
31 Self {
32 compression: None,
33 compat_level: CompatLevel::newest(),
34 record_batch_size: None,
35 record_batch_statistics: false,
36 chunk_size: 1 << 18,
37 }
38 }
39}
40
41impl IpcWriterOptions {
42 pub fn to_writer<W: Write>(&self, writer: W) -> IpcWriter<W> {
43 IpcWriter::new(writer)
44 .with_compression(self.compression)
45 .with_record_batch_size(self.record_batch_size)
46 .with_record_batch_statistics(self.record_batch_statistics)
47 }
48}
49
/// Writer that serializes DataFrames to the Arrow IPC file format.
///
/// Configure via the `with_*` builder methods, then either call
/// `SerWriter::finish` for a one-shot write or `batched` for incremental
/// writing.
#[must_use]
pub struct IpcWriter<W> {
    // Destination sink; consumed by `batched`, borrowed mutably by `finish`.
    pub(super) writer: W,
    // Compression codec for IPC buffers; `None` = uncompressed.
    pub(super) compression: Option<IpcCompression>,
    // Compatibility level used for schema/chunk conversion to Arrow.
    pub(super) compat_level: CompatLevel,
    // Target record-batch size.
    // NOTE(review): stored by the builder but not read in this file — presumably
    // consumed by another writer path; confirm.
    pub(super) record_batch_size: Option<usize>,
    // Whether to emit per-record-batch statistics (not read in this file).
    pub(super) record_batch_statistics: bool,
    // If true, `finish` aligns chunks in parallel (`align_chunks_par`).
    pub(super) parallel: bool,
    // Extra key/value metadata to embed in the file's schema message.
    pub(super) custom_schema_metadata: Option<Arc<Metadata>>,
}
87
88impl<W: Write> IpcWriter<W> {
89 pub fn with_compression(mut self, compression: Option<IpcCompression>) -> Self {
91 self.compression = compression;
92 self
93 }
94
95 pub fn with_compat_level(mut self, compat_level: CompatLevel) -> Self {
96 self.compat_level = compat_level;
97 self
98 }
99
100 pub fn with_record_batch_size(mut self, record_batch_size: Option<usize>) -> Self {
101 self.record_batch_size = record_batch_size;
102 self
103 }
104
105 pub fn with_record_batch_statistics(mut self, record_batch_statistics: bool) -> Self {
106 self.record_batch_statistics = record_batch_statistics;
107 self
108 }
109
110 pub fn with_parallel(mut self, parallel: bool) -> Self {
111 self.parallel = parallel;
112 self
113 }
114
115 pub fn batched(
116 self,
117 schema: &Schema,
118 ipc_fields: Vec<IpcField>,
119 ) -> PolarsResult<BatchedWriter<W>> {
120 let schema = schema_to_arrow_checked(schema, self.compat_level, "ipc")?;
121 let mut writer = write::FileWriter::new(
122 self.writer,
123 Arc::new(schema),
124 Some(ipc_fields),
125 WriteOptions {
126 compression: self.compression.map(|c| c.into()),
127 },
128 );
129 writer.start()?;
130
131 Ok(BatchedWriter {
132 writer,
133 compat_level: self.compat_level,
134 })
135 }
136
137 pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc<Metadata>) {
139 self.custom_schema_metadata = Some(custom_metadata);
140 }
141}
142
143impl<W> SerWriter<W> for IpcWriter<W>
144where
145 W: Write,
146{
147 fn new(writer: W) -> Self {
148 IpcWriter {
149 writer,
150 compression: None,
151 compat_level: CompatLevel::newest(),
152 record_batch_size: None,
153 record_batch_statistics: false,
154 parallel: true,
155 custom_schema_metadata: None,
156 }
157 }
158
159 fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> {
160 let schema = schema_to_arrow_checked(df.schema(), self.compat_level, "ipc")?;
161 let mut ipc_writer = write::FileWriter::try_new(
162 &mut self.writer,
163 Arc::new(schema),
164 None,
165 WriteOptions {
166 compression: self.compression.map(|c| c.into()),
167 },
168 )?;
169 if let Some(custom_metadata) = &self.custom_schema_metadata {
170 ipc_writer.set_custom_schema_metadata(Arc::clone(custom_metadata));
171 }
172
173 if self.parallel {
174 df.align_chunks_par();
175 } else {
176 df.align_chunks();
177 }
178 let iter = df.iter_chunks(self.compat_level, true);
179
180 for batch in iter {
181 ipc_writer.write(&batch, None)?
182 }
183 ipc_writer.finish()?;
184 Ok(())
185 }
186}
187
/// Incremental IPC writer produced by `IpcWriter::batched`.
///
/// Append data with `write_batch`/`write_encoded*`, then call `finish` to
/// write the file footer.
pub struct BatchedWriter<W: Write> {
    // Underlying Arrow IPC file writer; already `start()`ed by `IpcWriter::batched`.
    writer: write::FileWriter<W>,
    // Compatibility level used when converting DataFrame chunks to batches.
    compat_level: CompatLevel,
}
192
193impl<W: Write> BatchedWriter<W> {
194 pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
199 let iter = df.iter_chunks(self.compat_level, true);
200 for batch in iter {
201 self.writer.write(&batch, None)?
202 }
203 Ok(())
204 }
205
206 pub fn write_encoded(
211 &mut self,
212 dictionaries: &[EncodedData],
213 message: &EncodedData,
214 ) -> PolarsResult<()> {
215 self.writer.write_encoded(dictionaries, message)
216 }
217
218 pub fn write_encoded_dictionaries(
219 &mut self,
220 encoded_dictionaries: &[EncodedData],
221 ) -> PolarsResult<()> {
222 self.writer.write_encoded_dictionaries(encoded_dictionaries)
223 }
224
225 pub fn finish(&mut self) -> PolarsResult<()> {
227 self.writer.finish()?;
228 Ok(())
229 }
230}
231
/// Compression codec for Arrow IPC buffers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum IpcCompression {
    /// LZ4 (frame format) compression.
    LZ4,
    /// ZSTD compression at the given level.
    ZSTD(polars_utils::compression::ZstdLevel),
}
242
243impl Default for IpcCompression {
244 fn default() -> Self {
245 Self::ZSTD(Default::default())
246 }
247}
248
249impl From<IpcCompression> for write::Compression {
250 fn from(value: IpcCompression) -> Self {
251 match value {
252 IpcCompression::LZ4 => write::Compression::LZ4,
253 IpcCompression::ZSTD(level) => write::Compression::ZSTD(level),
254 }
255 }
256}