polars_io/csv/write/
write_impl.rs

1mod serializer;
2
3use arrow::array::NullArray;
4use arrow::legacy::time_zone::Tz;
5use polars_core::POOL;
6use polars_core::prelude::*;
7use polars_error::polars_ensure;
8use polars_utils::reuse_vec::reuse_vec;
9use rayon::prelude::*;
10use serializer::{serializer_for, string_serializer};
11
12use crate::csv::write::SerializeOptions;
13
14type ColumnSerializer<'a> =
15    dyn crate::csv::write::write_impl::serializer::Serializer<'a> + Send + 'a;
16
17/// Writes CSV from DataFrames.
18pub struct CsvSerializer {
19    serializers: Vec<Box<ColumnSerializer<'static>>>,
20    options: Arc<SerializeOptions>,
21    datetime_formats: Arc<[PlSmallStr]>,
22    time_zones: Arc<[Option<Tz>]>,
23}
24
25impl Clone for CsvSerializer {
26    fn clone(&self) -> Self {
27        Self {
28            serializers: vec![],
29            options: self.options.clone(),
30            datetime_formats: self.datetime_formats.clone(),
31            time_zones: self.time_zones.clone(),
32        }
33    }
34}
35
36impl CsvSerializer {
37    pub fn new(schema: SchemaRef, options: Arc<SerializeOptions>) -> PolarsResult<Self> {
38        for dtype in schema.iter_values() {
39            let nested = match dtype {
40                DataType::List(_) => true,
41                #[cfg(feature = "dtype-struct")]
42                DataType::Struct(_) => true,
43                #[cfg(feature = "object")]
44                DataType::Object(_) => {
45                    return Err(PolarsError::ComputeError(
46                        "csv writer does not support object dtype".into(),
47                    ));
48                },
49                _ => false,
50            };
51            polars_ensure!(
52                !nested,
53                ComputeError: "CSV format does not support nested data",
54            );
55        }
56
57        // Check that the double quote is valid UTF-8.
58        polars_ensure!(
59            std::str::from_utf8(&[options.quote_char, options.quote_char]).is_ok(),
60            ComputeError: "quote char results in invalid utf-8",
61        );
62
63        let (datetime_formats, time_zones): (Vec<PlSmallStr>, Vec<Option<Tz>>) = schema
64            .iter_values()
65            .map(|dtype| {
66                let (datetime_format_str, time_zone) = match dtype {
67                    DataType::Datetime(TimeUnit::Milliseconds, tz) => {
68                        let (format, tz_parsed) = match tz {
69                            #[cfg(feature = "timezones")]
70                            Some(tz) => (
71                                options
72                                    .datetime_format
73                                    .as_deref()
74                                    .unwrap_or("%FT%H:%M:%S.%3f%z"),
75                                tz.parse::<Tz>().ok(),
76                            ),
77                            _ => (
78                                options
79                                    .datetime_format
80                                    .as_deref()
81                                    .unwrap_or("%FT%H:%M:%S.%3f"),
82                                None,
83                            ),
84                        };
85                        (format, tz_parsed)
86                    },
87                    DataType::Datetime(TimeUnit::Microseconds, tz) => {
88                        let (format, tz_parsed) = match tz {
89                            #[cfg(feature = "timezones")]
90                            Some(tz) => (
91                                options
92                                    .datetime_format
93                                    .as_deref()
94                                    .unwrap_or("%FT%H:%M:%S.%6f%z"),
95                                tz.parse::<Tz>().ok(),
96                            ),
97                            _ => (
98                                options
99                                    .datetime_format
100                                    .as_deref()
101                                    .unwrap_or("%FT%H:%M:%S.%6f"),
102                                None,
103                            ),
104                        };
105                        (format, tz_parsed)
106                    },
107                    DataType::Datetime(TimeUnit::Nanoseconds, tz) => {
108                        let (format, tz_parsed) = match tz {
109                            #[cfg(feature = "timezones")]
110                            Some(tz) => (
111                                options
112                                    .datetime_format
113                                    .as_deref()
114                                    .unwrap_or("%FT%H:%M:%S.%9f%z"),
115                                tz.parse::<Tz>().ok(),
116                            ),
117                            _ => (
118                                options
119                                    .datetime_format
120                                    .as_deref()
121                                    .unwrap_or("%FT%H:%M:%S.%9f"),
122                                None,
123                            ),
124                        };
125                        (format, tz_parsed)
126                    },
127                    _ => ("", None),
128                };
129
130                (datetime_format_str.into(), time_zone)
131            })
132            .collect();
133
134        Ok(Self {
135            serializers: vec![],
136            options,
137            datetime_formats: Arc::from_iter(datetime_formats),
138            time_zones: Arc::from_iter(time_zones),
139        })
140    }
141
142    /// # Panics
143    /// Panics if a column has >1 chunk.
144    pub fn serialize_to_csv<'a>(
145        &'a mut self,
146        df: &'a DataFrame,
147        buffer: &mut Vec<u8>,
148    ) -> PolarsResult<()> {
149        if df.height() == 0 || df.width() == 0 {
150            return Ok(());
151        }
152
153        let options = Arc::clone(&self.options);
154        let options = options.as_ref();
155
156        let mut serializers_vec = reuse_vec(std::mem::take(&mut self.serializers));
157        let serializers = self.build_serializers(df.columns(), &mut serializers_vec)?;
158
159        for _ in 0..df.height() {
160            serializers[0].serialize(buffer, options);
161            for serializer in &mut serializers[1..] {
162                buffer.push(options.separator);
163                serializer.serialize(buffer, options);
164            }
165
166            buffer.extend_from_slice(options.line_terminator.as_bytes());
167        }
168
169        self.serializers = reuse_vec(serializers_vec);
170
171        Ok(())
172    }
173
174    /// # Panics
175    /// Panics if a column has >1 chunk.
176    fn build_serializers<'a, 'b>(
177        &'a mut self,
178        columns: &'a [Column],
179        serializers: &'b mut Vec<Box<ColumnSerializer<'a>>>,
180    ) -> PolarsResult<&'b mut [Box<ColumnSerializer<'a>>]> {
181        serializers.clear();
182        serializers.reserve(columns.len());
183
184        for (i, c) in columns.iter().enumerate() {
185            assert_eq!(c.n_chunks(), 1);
186
187            serializers.push(serializer_for(
188                c.as_materialized_series().chunks()[0].as_ref(),
189                Arc::as_ref(&self.options),
190                c.dtype(),
191                self.datetime_formats[i].as_str(),
192                self.time_zones[i],
193            )?)
194        }
195
196        Ok(serializers)
197    }
198}
199
200pub(crate) fn write(
201    mut writer: impl std::io::Write,
202    df: &DataFrame,
203    chunk_size: usize,
204    options: Arc<SerializeOptions>,
205    n_threads: usize,
206) -> PolarsResult<()> {
207    let len = df.height();
208    let total_rows_per_pool_iter = n_threads * chunk_size;
209
210    let mut n_rows_finished = 0;
211
212    let csv_serializer = CsvSerializer::new(Arc::clone(df.schema()), options)?;
213
214    let mut buffers: Vec<(Vec<u8>, CsvSerializer)> = (0..n_threads)
215        .map(|_| (Vec::new(), csv_serializer.clone()))
216        .collect();
217    while n_rows_finished < len {
218        let buf_writer =
219            |thread_no, write_buffer: &mut Vec<_>, csv_serializer: &mut CsvSerializer| {
220                let thread_offset = thread_no * chunk_size;
221                let total_offset = n_rows_finished + thread_offset;
222                let mut df = df.slice(total_offset as i64, chunk_size);
223                // the `series.iter` needs rechunked series.
224                // we don't do this on the whole as this probably needs much less rechunking
225                // so will be faster.
226                // and allows writing `pl.concat([df] * 100, rechunk=False).write_csv()` as the rechunk
227                // would go OOM
228                df.rechunk_mut();
229
230                csv_serializer.serialize_to_csv(&df, write_buffer)?;
231
232                Ok(())
233            };
234
235        if n_threads > 1 {
236            POOL.install(|| {
237                buffers
238                    .par_iter_mut()
239                    .enumerate()
240                    .map(|(i, (w, s))| buf_writer(i, w, s))
241                    .collect::<PolarsResult<()>>()
242            })?;
243        } else {
244            let (w, s) = &mut buffers[0];
245            buf_writer(0, w, s)?;
246        }
247
248        for (write_buffer, _) in &mut buffers {
249            writer.write_all(write_buffer)?;
250            write_buffer.clear();
251        }
252
253        n_rows_finished += total_rows_per_pool_iter;
254    }
255    Ok(())
256}
257
258/// Writes a CSV header to `writer`.
259pub fn write_csv_header(
260    mut writer: impl std::io::Write,
261    names: &[&str],
262    options: &SerializeOptions,
263) -> PolarsResult<()> {
264    let mut header = Vec::new();
265
266    // A hack, but it works for this case.
267    let fake_arr = NullArray::new(ArrowDataType::Null, 0);
268    let mut names_serializer = string_serializer(
269        |iter: &mut std::slice::Iter<&str>| iter.next().copied(),
270        options,
271        |_| names.iter(),
272        &fake_arr,
273    );
274    for i in 0..names.len() {
275        names_serializer.serialize(&mut header, options);
276        if i != names.len() - 1 {
277            header.push(options.separator);
278        }
279    }
280    header.extend_from_slice(options.line_terminator.as_bytes());
281    writer.write_all(&header)?;
282    Ok(())
283}
284
285/// Writes a UTF-8 BOM to `writer`.
286pub fn write_bom(mut writer: impl std::io::Write) -> PolarsResult<()> {
287    const BOM: [u8; 3] = [0xEF, 0xBB, 0xBF];
288    writer.write_all(&BOM)?;
289    Ok(())
290}