1mod serializer;
2
3use std::io::Write;
4
5use arrow::array::NullArray;
6use arrow::legacy::time_zone::Tz;
7use polars_core::POOL;
8use polars_core::prelude::*;
9use polars_error::polars_ensure;
10use rayon::prelude::*;
11use serializer::{serializer_for, string_serializer};
12
13use crate::csv::write::SerializeOptions;
14
15pub(crate) fn write<W: Write>(
16 writer: &mut W,
17 df: &DataFrame,
18 chunk_size: usize,
19 options: &SerializeOptions,
20 n_threads: usize,
21) -> PolarsResult<()> {
22 for s in df.get_columns() {
23 let nested = match s.dtype() {
24 DataType::List(_) => true,
25 #[cfg(feature = "dtype-struct")]
26 DataType::Struct(_) => true,
27 #[cfg(feature = "object")]
28 DataType::Object(_) => {
29 return Err(PolarsError::ComputeError(
30 "csv writer does not support object dtype".into(),
31 ));
32 },
33 _ => false,
34 };
35 polars_ensure!(
36 !nested,
37 ComputeError: "CSV format does not support nested data",
38 );
39 }
40
41 polars_ensure!(
43 std::str::from_utf8(&[options.quote_char, options.quote_char]).is_ok(),
44 ComputeError: "quote char results in invalid utf-8",
45 );
46
47 let (datetime_formats, time_zones): (Vec<&str>, Vec<Option<Tz>>) = df
48 .get_columns()
49 .iter()
50 .map(|column| match column.dtype() {
51 DataType::Datetime(TimeUnit::Milliseconds, tz) => {
52 let (format, tz_parsed) = match tz {
53 #[cfg(feature = "timezones")]
54 Some(tz) => (
55 options
56 .datetime_format
57 .as_deref()
58 .unwrap_or("%FT%H:%M:%S.%3f%z"),
59 tz.parse::<Tz>().ok(),
60 ),
61 _ => (
62 options
63 .datetime_format
64 .as_deref()
65 .unwrap_or("%FT%H:%M:%S.%3f"),
66 None,
67 ),
68 };
69 (format, tz_parsed)
70 },
71 DataType::Datetime(TimeUnit::Microseconds, tz) => {
72 let (format, tz_parsed) = match tz {
73 #[cfg(feature = "timezones")]
74 Some(tz) => (
75 options
76 .datetime_format
77 .as_deref()
78 .unwrap_or("%FT%H:%M:%S.%6f%z"),
79 tz.parse::<Tz>().ok(),
80 ),
81 _ => (
82 options
83 .datetime_format
84 .as_deref()
85 .unwrap_or("%FT%H:%M:%S.%6f"),
86 None,
87 ),
88 };
89 (format, tz_parsed)
90 },
91 DataType::Datetime(TimeUnit::Nanoseconds, tz) => {
92 let (format, tz_parsed) = match tz {
93 #[cfg(feature = "timezones")]
94 Some(tz) => (
95 options
96 .datetime_format
97 .as_deref()
98 .unwrap_or("%FT%H:%M:%S.%9f%z"),
99 tz.parse::<Tz>().ok(),
100 ),
101 _ => (
102 options
103 .datetime_format
104 .as_deref()
105 .unwrap_or("%FT%H:%M:%S.%9f"),
106 None,
107 ),
108 };
109 (format, tz_parsed)
110 },
111 _ => ("", None),
112 })
113 .unzip();
114
115 let len = df.height();
116 let total_rows_per_pool_iter = n_threads * chunk_size;
117
118 let mut n_rows_finished = 0;
119
120 let col_dtypes: Vec<_> = df.get_columns().iter().map(|c| c.dtype()).collect();
125
126 let mut buffers: Vec<_> = (0..n_threads).map(|_| (Vec::new(), Vec::new())).collect();
127 while n_rows_finished < len {
128 let buf_writer = |thread_no, write_buffer: &mut Vec<_>, serializers_vec: &mut Vec<_>| {
129 let thread_offset = thread_no * chunk_size;
130 let total_offset = n_rows_finished + thread_offset;
131 let mut df = df.slice(total_offset as i64, chunk_size);
132 df.as_single_chunk();
138 let cols = df.get_columns();
139
140 let cols = unsafe { std::mem::transmute::<&[Column], &[Column]>(cols) };
145
146 if df.is_empty() {
147 return Ok(());
148 }
149
150 if serializers_vec.is_empty() {
151 debug_assert_eq!(cols.len(), col_dtypes.len());
152 *serializers_vec = std::iter::zip(cols, &col_dtypes)
153 .enumerate()
154 .map(|(i, (col, &col_dtype))| {
155 serializer_for(
156 &*col.as_materialized_series().chunks()[0],
157 options,
158 col_dtype,
159 datetime_formats[i],
160 time_zones[i],
161 )
162 })
163 .collect::<Result<_, _>>()?;
164 } else {
165 debug_assert_eq!(serializers_vec.len(), cols.len());
166 for (col_iter, col) in std::iter::zip(serializers_vec.iter_mut(), cols) {
167 col_iter.update_array(&*col.as_materialized_series().chunks()[0]);
168 }
169 }
170
171 let serializers = serializers_vec.as_mut_slice();
172
173 let len = std::cmp::min(cols[0].len(), chunk_size);
174
175 for _ in 0..len {
176 serializers[0].serialize(write_buffer, options);
177 for serializer in &mut serializers[1..] {
178 write_buffer.push(options.separator);
179 serializer.serialize(write_buffer, options);
180 }
181
182 write_buffer.extend_from_slice(options.line_terminator.as_bytes());
183 }
184
185 Ok(())
186 };
187
188 if n_threads > 1 {
189 POOL.install(|| {
190 buffers
191 .par_iter_mut()
192 .enumerate()
193 .map(|(i, (w, s))| buf_writer(i, w, s))
194 .collect::<PolarsResult<()>>()
195 })?;
196 } else {
197 let (w, s) = &mut buffers[0];
198 buf_writer(0, w, s)?;
199 }
200
201 for (write_buffer, _) in &mut buffers {
202 writer.write_all(write_buffer)?;
203 write_buffer.clear();
204 }
205
206 n_rows_finished += total_rows_per_pool_iter;
207 }
208 Ok(())
209}
210
211pub(crate) fn write_header<W: Write>(
213 writer: &mut W,
214 names: &[&str],
215 options: &SerializeOptions,
216) -> PolarsResult<()> {
217 let mut header = Vec::new();
218
219 let fake_arr = NullArray::new(ArrowDataType::Null, 0);
221 let mut names_serializer = string_serializer(
222 |iter: &mut std::slice::Iter<&str>| iter.next().copied(),
223 options,
224 |_| names.iter(),
225 &fake_arr,
226 );
227 for i in 0..names.len() {
228 names_serializer.serialize(&mut header, options);
229 if i != names.len() - 1 {
230 header.push(options.separator);
231 }
232 }
233 header.extend_from_slice(options.line_terminator.as_bytes());
234 writer.write_all(&header)?;
235 Ok(())
236}
237
238pub(crate) fn write_bom<W: Write>(writer: &mut W) -> PolarsResult<()> {
240 const BOM: [u8; 3] = [0xEF, 0xBB, 0xBF];
241 writer.write_all(&BOM)?;
242 Ok(())
243}