polars_io/csv/write/
writer.rs

1use std::io::Write;
2use std::num::NonZeroUsize;
3use std::sync::Arc;
4
5use polars_core::POOL;
6use polars_core::frame::DataFrame;
7use polars_core::schema::Schema;
8use polars_error::PolarsResult;
9use polars_utils::pl_str::PlSmallStr;
10
11use super::write_impl::{UTF8_BOM, csv_header, write};
12use super::{QuoteStyle, SerializeOptions};
13use crate::shared::SerWriter;
14
15/// Write a DataFrame to csv.
16///
17/// Don't use a `Buffered` writer, the `CsvWriter` internally already buffers writes.
18#[must_use]
19pub struct CsvWriter<W: Write> {
20    /// File or Stream handler
21    buffer: W,
22    options: Arc<SerializeOptions>,
23    header: bool,
24    bom: bool,
25    batch_size: NonZeroUsize,
26    n_threads: usize,
27}
28
29impl<W> SerWriter<W> for CsvWriter<W>
30where
31    W: Write,
32{
33    fn new(buffer: W) -> Self {
34        let options = SerializeOptions::default();
35
36        CsvWriter {
37            buffer,
38            options: options.into(),
39            header: true,
40            bom: false,
41            batch_size: NonZeroUsize::new(1024).unwrap(),
42            n_threads: POOL.current_num_threads(),
43        }
44    }
45
46    fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> {
47        if self.bom {
48            self.buffer.write_all(&UTF8_BOM)?;
49        }
50        let names = df
51            .get_column_names()
52            .into_iter()
53            .map(|x| x.as_str())
54            .collect::<Vec<_>>();
55        if self.header {
56            self.buffer
57                .write_all(&csv_header(names.as_slice(), &self.options)?)?;
58        }
59        write(
60            &mut self.buffer,
61            df,
62            self.batch_size.into(),
63            self.options.clone(),
64            self.n_threads,
65        )
66    }
67}
68
69impl<W> CsvWriter<W>
70where
71    W: Write,
72{
73    fn options_mut(&mut self) -> &mut SerializeOptions {
74        Arc::make_mut(&mut self.options)
75    }
76
77    /// Set whether to write UTF-8 UTF8_BOM.
78    pub fn include_bom(mut self, include_bom: bool) -> Self {
79        self.bom = include_bom;
80        self
81    }
82
83    /// Set whether to write headers.
84    pub fn include_header(mut self, include_header: bool) -> Self {
85        self.header = include_header;
86        self
87    }
88
89    /// Set the CSV file's column separator as a byte character.
90    pub fn with_separator(mut self, separator: u8) -> Self {
91        self.options_mut().separator = separator;
92        self
93    }
94
95    /// Set the batch size to use while writing the CSV.
96    pub fn with_batch_size(mut self, batch_size: NonZeroUsize) -> Self {
97        self.batch_size = batch_size;
98        self
99    }
100
101    /// Set the CSV file's date format.
102    pub fn with_date_format(mut self, format: Option<PlSmallStr>) -> Self {
103        if format.is_some() {
104            self.options_mut().date_format = format;
105        }
106        self
107    }
108
109    /// Set the CSV file's time format.
110    pub fn with_time_format(mut self, format: Option<PlSmallStr>) -> Self {
111        if format.is_some() {
112            self.options_mut().time_format = format;
113        }
114        self
115    }
116
117    /// Set the CSV file's datetime format.
118    pub fn with_datetime_format(mut self, format: Option<PlSmallStr>) -> Self {
119        if format.is_some() {
120            self.options_mut().datetime_format = format;
121        }
122        self
123    }
124
125    /// Set the CSV file's forced scientific notation for floats.
126    pub fn with_float_scientific(mut self, scientific: Option<bool>) -> Self {
127        if scientific.is_some() {
128            self.options_mut().float_scientific = scientific;
129        }
130        self
131    }
132
133    /// Set the CSV file's float precision.
134    pub fn with_float_precision(mut self, precision: Option<usize>) -> Self {
135        if precision.is_some() {
136            self.options_mut().float_precision = precision;
137        }
138        self
139    }
140
141    /// Set the CSV decimal separator.
142    pub fn with_decimal_comma(mut self, decimal_comma: bool) -> Self {
143        self.options_mut().decimal_comma = decimal_comma;
144        self
145    }
146
147    /// Set the single byte character used for quoting.
148    pub fn with_quote_char(mut self, char: u8) -> Self {
149        self.options_mut().quote_char = char;
150        self
151    }
152
153    /// Set the CSV file's null value representation.
154    pub fn with_null_value(mut self, null_value: PlSmallStr) -> Self {
155        self.options_mut().null = null_value;
156        self
157    }
158
159    /// Set the CSV file's line terminator.
160    pub fn with_line_terminator(mut self, line_terminator: PlSmallStr) -> Self {
161        self.options_mut().line_terminator = line_terminator;
162        self
163    }
164
165    /// Set the CSV file's quoting behavior.
166    /// See more on [`QuoteStyle`].
167    pub fn with_quote_style(mut self, quote_style: QuoteStyle) -> Self {
168        self.options_mut().quote_style = quote_style;
169        self
170    }
171
172    pub fn n_threads(mut self, n_threads: usize) -> Self {
173        self.n_threads = n_threads;
174        self
175    }
176
177    pub fn batched(self, schema: &Schema) -> PolarsResult<BatchedWriter<W>> {
178        let expects_bom = self.bom;
179        let expects_header = self.header;
180        Ok(BatchedWriter {
181            writer: self,
182            has_written_bom: !expects_bom,
183            has_written_header: !expects_header,
184            schema: schema.clone(),
185        })
186    }
187}
188
189pub struct BatchedWriter<W: Write> {
190    writer: CsvWriter<W>,
191    has_written_bom: bool,
192    has_written_header: bool,
193    schema: Schema,
194}
195
196impl<W: Write> BatchedWriter<W> {
197    /// Write a batch to the csv writer.
198    ///
199    /// # Panics
200    /// The caller must ensure the chunks in the given [`DataFrame`] are aligned.
201    pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
202        if !self.has_written_bom {
203            self.has_written_bom = true;
204            self.writer.buffer.write_all(&UTF8_BOM)?;
205        }
206
207        if !self.has_written_header {
208            self.has_written_header = true;
209            let names = df
210                .get_column_names()
211                .into_iter()
212                .map(|x| x.as_str())
213                .collect::<Vec<_>>();
214
215            self.writer
216                .buffer
217                .write_all(&csv_header(names.as_slice(), &self.writer.options)?)?;
218        }
219
220        write(
221            &mut self.writer.buffer,
222            df,
223            self.writer.batch_size.into(),
224            self.writer.options.clone(),
225            self.writer.n_threads,
226        )?;
227        Ok(())
228    }
229
230    /// Writes the header of the csv file if not done already. Returns the total size of the file.
231    pub fn finish(&mut self) -> PolarsResult<()> {
232        if !self.has_written_bom {
233            self.has_written_bom = true;
234            self.writer.buffer.write_all(&UTF8_BOM)?;
235        }
236
237        if !self.has_written_header {
238            self.has_written_header = true;
239            let names = self
240                .schema
241                .iter_names()
242                .map(|x| x.as_str())
243                .collect::<Vec<_>>();
244
245            self.writer
246                .buffer
247                .write_all(&csv_header(&names, &self.writer.options)?)?;
248        };
249
250        Ok(())
251    }
252}