//! CSV write implementation (`polars_io/csv/write/write_impl.rs`).

1mod serializer;
2
3use std::io::Write;
4
5use arrow::array::NullArray;
6use arrow::legacy::time_zone::Tz;
7use polars_core::POOL;
8use polars_core::prelude::*;
9use polars_error::polars_ensure;
10use rayon::prelude::*;
11use serializer::{serializer_for, string_serializer};
12
13use crate::csv::write::SerializeOptions;
14
/// Writes the rows of `df` to `writer` in CSV format.
///
/// Work is split into slices of `chunk_size` rows that are serialized in
/// parallel by up to `n_threads` rayon tasks, each into its own byte buffer;
/// the buffers are then flushed to `writer` in thread order so the overall
/// row order is preserved.
///
/// # Errors
/// - `ComputeError` if a column has a nested (`List`/`Struct`) dtype or an
///   `Object` dtype — CSV cannot represent those.
/// - `ComputeError` if doubling `options.quote_char` does not form valid UTF-8.
/// - Any I/O error from `writer`.
pub(crate) fn write<W: Write>(
    writer: &mut W,
    df: &DataFrame,
    chunk_size: usize,
    options: &SerializeOptions,
    n_threads: usize,
) -> PolarsResult<()> {
    // CSV is a flat text format: reject nested/object dtypes up front so we
    // fail before any output has been written.
    for s in df.get_columns() {
        let nested = match s.dtype() {
            DataType::List(_) => true,
            #[cfg(feature = "dtype-struct")]
            DataType::Struct(_) => true,
            #[cfg(feature = "object")]
            DataType::Object(_) => {
                return Err(PolarsError::ComputeError(
                    "csv writer does not support object dtype".into(),
                ));
            },
            _ => false,
        };
        polars_ensure!(
            !nested,
            ComputeError: "CSV format does not support nested data",
        );
    }

    // Check that the double quote is valid UTF-8. (Two quote bytes in a row
    // is the CSV escape for an embedded quote, so the doubled form must
    // itself be valid UTF-8.)
    polars_ensure!(
        std::str::from_utf8(&[options.quote_char, options.quote_char]).is_ok(),
        ComputeError: "quote char results in invalid utf-8",
    );

    // Resolve, per column, the datetime format string and the parsed time
    // zone (only when the "timezones" feature is enabled). Non-datetime
    // columns get an empty format and `None`. The default format's
    // fractional-second width matches the time unit: %3f (ms), %6f (us),
    // %9f (ns); tz-aware columns additionally get a %z offset suffix.
    let (datetime_formats, time_zones): (Vec<&str>, Vec<Option<Tz>>) = df
        .get_columns()
        .iter()
        .map(|column| match column.dtype() {
            DataType::Datetime(TimeUnit::Milliseconds, tz) => {
                let (format, tz_parsed) = match tz {
                    #[cfg(feature = "timezones")]
                    Some(tz) => (
                        options
                            .datetime_format
                            .as_deref()
                            .unwrap_or("%FT%H:%M:%S.%3f%z"),
                        tz.parse::<Tz>().ok(),
                    ),
                    _ => (
                        options
                            .datetime_format
                            .as_deref()
                            .unwrap_or("%FT%H:%M:%S.%3f"),
                        None,
                    ),
                };
                (format, tz_parsed)
            },
            DataType::Datetime(TimeUnit::Microseconds, tz) => {
                let (format, tz_parsed) = match tz {
                    #[cfg(feature = "timezones")]
                    Some(tz) => (
                        options
                            .datetime_format
                            .as_deref()
                            .unwrap_or("%FT%H:%M:%S.%6f%z"),
                        tz.parse::<Tz>().ok(),
                    ),
                    _ => (
                        options
                            .datetime_format
                            .as_deref()
                            .unwrap_or("%FT%H:%M:%S.%6f"),
                        None,
                    ),
                };
                (format, tz_parsed)
            },
            DataType::Datetime(TimeUnit::Nanoseconds, tz) => {
                let (format, tz_parsed) = match tz {
                    #[cfg(feature = "timezones")]
                    Some(tz) => (
                        options
                            .datetime_format
                            .as_deref()
                            .unwrap_or("%FT%H:%M:%S.%9f%z"),
                        tz.parse::<Tz>().ok(),
                    ),
                    _ => (
                        options
                            .datetime_format
                            .as_deref()
                            .unwrap_or("%FT%H:%M:%S.%9f"),
                        None,
                    ),
                };
                (format, tz_parsed)
            },
            _ => ("", None),
        })
        .unzip();

    let len = df.height();
    // Rows consumed per pass of the `while` loop below: one `chunk_size`
    // slice per thread.
    let total_rows_per_pool_iter = n_threads * chunk_size;

    let mut n_rows_finished = 0;

    // To comply with the safety requirements for the buf_writer closure, we need to make sure
    // the column dtype references have a lifetime that exceeds the scope of the serializer, i.e.
    // the full dataframe. If not, we can run into use-after-free memory issues for types that
    // allocate, such as Enum or Categorical dtype (see GH issue #23939).
    let col_dtypes: Vec<_> = df.get_columns().iter().map(|c| c.dtype()).collect();

    // One (output buffer, serializer cache) pair per thread, reused across
    // loop passes so allocations are amortized.
    let mut buffers: Vec<_> = (0..n_threads).map(|_| (Vec::new(), Vec::new())).collect();
    while n_rows_finished < len {
        // Serializes this thread's `chunk_size`-row slice into `write_buffer`.
        // `serializers_vec` caches one serializer per column; later passes only
        // swap in the new backing array instead of rebuilding the serializers.
        let buf_writer = |thread_no, write_buffer: &mut Vec<_>, serializers_vec: &mut Vec<_>| {
            let thread_offset = thread_no * chunk_size;
            let total_offset = n_rows_finished + thread_offset;
            // NOTE(review): trailing threads may receive a short or empty
            // slice near the end of the frame; the empty case returns early
            // below.
            let mut df = df.slice(total_offset as i64, chunk_size);
            // the `series.iter` needs rechunked series.
            // we don't do this on the whole as this probably needs much less rechunking
            // so will be faster.
            // and allows writing `pl.concat([df] * 100, rechunk=False).write_csv()` as the rechunk
            // would go OOM
            df.as_single_chunk();
            let cols = df.get_columns();

            // SAFETY:
            // the bck thinks the lifetime is bounded to write_buffer_pool, but at the time we return
            // the vectors the buffer pool, the series have already been removed from the buffers
            // in other words, the lifetime does not leave this scope
            let cols = unsafe { std::mem::transmute::<&[Column], &[Column]>(cols) };

            if df.is_empty() {
                return Ok(());
            }

            if serializers_vec.is_empty() {
                // First pass: build one serializer per column from the slice's
                // single chunk, the long-lived dtype refs, and the precomputed
                // datetime format / time zone for that column.
                debug_assert_eq!(cols.len(), col_dtypes.len());
                *serializers_vec = std::iter::zip(cols, &col_dtypes)
                    .enumerate()
                    .map(|(i, (col, &col_dtype))| {
                        serializer_for(
                            &*col.as_materialized_series().chunks()[0],
                            options,
                            col_dtype,
                            datetime_formats[i],
                            time_zones[i],
                        )
                    })
                    .collect::<Result<_, _>>()?;
            } else {
                // Later passes: point each cached serializer at this slice's
                // freshly rechunked array.
                debug_assert_eq!(serializers_vec.len(), cols.len());
                for (col_iter, col) in std::iter::zip(serializers_vec.iter_mut(), cols) {
                    col_iter.update_array(&*col.as_materialized_series().chunks()[0]);
                }
            }

            let serializers = serializers_vec.as_mut_slice();

            // Shadowed `len`: the number of rows in *this* slice, which may be
            // shorter than `chunk_size` for the final slice.
            let len = std::cmp::min(cols[0].len(), chunk_size);

            // Emit each row: first field bare, every following field prefixed
            // with the separator, then the line terminator.
            for _ in 0..len {
                serializers[0].serialize(write_buffer, options);
                for serializer in &mut serializers[1..] {
                    write_buffer.push(options.separator);
                    serializer.serialize(write_buffer, options);
                }

                write_buffer.extend_from_slice(options.line_terminator.as_bytes());
            }

            Ok(())
        };

        if n_threads > 1 {
            // Serialize all per-thread slices in parallel on the global pool.
            POOL.install(|| {
                buffers
                    .par_iter_mut()
                    .enumerate()
                    .map(|(i, (w, s))| buf_writer(i, w, s))
                    .collect::<PolarsResult<()>>()
            })?;
        } else {
            let (w, s) = &mut buffers[0];
            buf_writer(0, w, s)?;
        }

        // Flush buffers in thread order so rows keep their original order,
        // then clear them for reuse (capacity is retained).
        for (write_buffer, _) in &mut buffers {
            writer.write_all(write_buffer)?;
            write_buffer.clear();
        }

        // May overshoot `len` on the final pass; the loop condition handles it.
        n_rows_finished += total_rows_per_pool_iter;
    }
    Ok(())
}
210
211/// Writes a CSV header to `writer`.
212pub(crate) fn write_header<W: Write>(
213    writer: &mut W,
214    names: &[&str],
215    options: &SerializeOptions,
216) -> PolarsResult<()> {
217    let mut header = Vec::new();
218
219    // A hack, but it works for this case.
220    let fake_arr = NullArray::new(ArrowDataType::Null, 0);
221    let mut names_serializer = string_serializer(
222        |iter: &mut std::slice::Iter<&str>| iter.next().copied(),
223        options,
224        |_| names.iter(),
225        &fake_arr,
226    );
227    for i in 0..names.len() {
228        names_serializer.serialize(&mut header, options);
229        if i != names.len() - 1 {
230            header.push(options.separator);
231        }
232    }
233    header.extend_from_slice(options.line_terminator.as_bytes());
234    writer.write_all(&header)?;
235    Ok(())
236}
237
238/// Writes a UTF-8 BOM to `writer`.
239pub(crate) fn write_bom<W: Write>(writer: &mut W) -> PolarsResult<()> {
240    const BOM: [u8; 3] = [0xEF, 0xBB, 0xBF];
241    writer.write_all(&BOM)?;
242    Ok(())
243}