// polars_io/csv/write/writer.rs
use std::io::Write;
use std::num::NonZeroUsize;
use std::sync::Arc;

use polars_core::POOL;
use polars_core::frame::DataFrame;
use polars_core::schema::Schema;
use polars_error::PolarsResult;
use polars_utils::pl_str::PlSmallStr;

use super::write_impl::{UTF8_BOM, csv_header, write};
use super::{QuoteStyle, SerializeOptions};
use crate::shared::SerWriter;
14
15#[must_use]
19pub struct CsvWriter<W: Write> {
20 buffer: W,
22 options: Arc<SerializeOptions>,
23 header: bool,
24 bom: bool,
25 batch_size: NonZeroUsize,
26 n_threads: usize,
27}
28
29impl<W> SerWriter<W> for CsvWriter<W>
30where
31 W: Write,
32{
33 fn new(buffer: W) -> Self {
34 let options = SerializeOptions::default();
35
36 CsvWriter {
37 buffer,
38 options: options.into(),
39 header: true,
40 bom: false,
41 batch_size: NonZeroUsize::new(1024).unwrap(),
42 n_threads: POOL.current_num_threads(),
43 }
44 }
45
46 fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> {
47 if self.bom {
48 self.buffer.write_all(&UTF8_BOM)?;
49 }
50 let names = df
51 .get_column_names()
52 .into_iter()
53 .map(|x| x.as_str())
54 .collect::<Vec<_>>();
55 if self.header {
56 self.buffer
57 .write_all(&csv_header(names.as_slice(), &self.options)?)?;
58 }
59 write(
60 &mut self.buffer,
61 df,
62 self.batch_size.into(),
63 self.options.clone(),
64 self.n_threads,
65 )
66 }
67}
68
69impl<W> CsvWriter<W>
70where
71 W: Write,
72{
73 fn options_mut(&mut self) -> &mut SerializeOptions {
74 Arc::make_mut(&mut self.options)
75 }
76
77 pub fn include_bom(mut self, include_bom: bool) -> Self {
79 self.bom = include_bom;
80 self
81 }
82
83 pub fn include_header(mut self, include_header: bool) -> Self {
85 self.header = include_header;
86 self
87 }
88
89 pub fn with_separator(mut self, separator: u8) -> Self {
91 self.options_mut().separator = separator;
92 self
93 }
94
95 pub fn with_batch_size(mut self, batch_size: NonZeroUsize) -> Self {
97 self.batch_size = batch_size;
98 self
99 }
100
101 pub fn with_date_format(mut self, format: Option<PlSmallStr>) -> Self {
103 if format.is_some() {
104 self.options_mut().date_format = format;
105 }
106 self
107 }
108
109 pub fn with_time_format(mut self, format: Option<PlSmallStr>) -> Self {
111 if format.is_some() {
112 self.options_mut().time_format = format;
113 }
114 self
115 }
116
117 pub fn with_datetime_format(mut self, format: Option<PlSmallStr>) -> Self {
119 if format.is_some() {
120 self.options_mut().datetime_format = format;
121 }
122 self
123 }
124
125 pub fn with_float_scientific(mut self, scientific: Option<bool>) -> Self {
127 if scientific.is_some() {
128 self.options_mut().float_scientific = scientific;
129 }
130 self
131 }
132
133 pub fn with_float_precision(mut self, precision: Option<usize>) -> Self {
135 if precision.is_some() {
136 self.options_mut().float_precision = precision;
137 }
138 self
139 }
140
141 pub fn with_decimal_comma(mut self, decimal_comma: bool) -> Self {
143 self.options_mut().decimal_comma = decimal_comma;
144 self
145 }
146
147 pub fn with_quote_char(mut self, char: u8) -> Self {
149 self.options_mut().quote_char = char;
150 self
151 }
152
153 pub fn with_null_value(mut self, null_value: PlSmallStr) -> Self {
155 self.options_mut().null = null_value;
156 self
157 }
158
159 pub fn with_line_terminator(mut self, line_terminator: PlSmallStr) -> Self {
161 self.options_mut().line_terminator = line_terminator;
162 self
163 }
164
165 pub fn with_quote_style(mut self, quote_style: QuoteStyle) -> Self {
168 self.options_mut().quote_style = quote_style;
169 self
170 }
171
172 pub fn n_threads(mut self, n_threads: usize) -> Self {
173 self.n_threads = n_threads;
174 self
175 }
176
177 pub fn batched(self, schema: &Schema) -> PolarsResult<BatchedWriter<W>> {
178 let expects_bom = self.bom;
179 let expects_header = self.header;
180 Ok(BatchedWriter {
181 writer: self,
182 has_written_bom: !expects_bom,
183 has_written_header: !expects_header,
184 schema: schema.clone(),
185 })
186 }
187}
188
189pub struct BatchedWriter<W: Write> {
190 writer: CsvWriter<W>,
191 has_written_bom: bool,
192 has_written_header: bool,
193 schema: Schema,
194}
195
196impl<W: Write> BatchedWriter<W> {
197 pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
202 if !self.has_written_bom {
203 self.has_written_bom = true;
204 self.writer.buffer.write_all(&UTF8_BOM)?;
205 }
206
207 if !self.has_written_header {
208 self.has_written_header = true;
209 let names = df
210 .get_column_names()
211 .into_iter()
212 .map(|x| x.as_str())
213 .collect::<Vec<_>>();
214
215 self.writer
216 .buffer
217 .write_all(&csv_header(names.as_slice(), &self.writer.options)?)?;
218 }
219
220 write(
221 &mut self.writer.buffer,
222 df,
223 self.writer.batch_size.into(),
224 self.writer.options.clone(),
225 self.writer.n_threads,
226 )?;
227 Ok(())
228 }
229
230 pub fn finish(&mut self) -> PolarsResult<()> {
232 if !self.has_written_bom {
233 self.has_written_bom = true;
234 self.writer.buffer.write_all(&UTF8_BOM)?;
235 }
236
237 if !self.has_written_header {
238 self.has_written_header = true;
239 let names = self
240 .schema
241 .iter_names()
242 .map(|x| x.as_str())
243 .collect::<Vec<_>>();
244
245 self.writer
246 .buffer
247 .write_all(&csv_header(&names, &self.writer.options)?)?;
248 };
249
250 Ok(())
251 }
252}