polars_io/csv/read/
splitfields.rs

1#![allow(unsafe_op_in_unsafe_fn)]
#[cfg(not(feature = "simd"))]
mod inner {
    /// An adapted version of `std::iter::Split` that yields the fields of one line.
    ///
    /// This exists solely because we cannot split the line naively on the
    /// separator: a quoted field may itself contain separator or end-of-line
    /// bytes, and those must not terminate the field.
    ///
    /// Iteration stops after the field that is terminated by the first
    /// unquoted `eol_char`, or after the field that runs to the end of the
    /// input, whichever comes first.
    pub(crate) struct SplitFields<'a> {
        /// Input that has not been handed out yet.
        v: &'a [u8],
        separator: u8,
        /// Set once the final field has been produced; `next` then returns `None`.
        finished: bool,
        quote_char: u8,
        /// Whether a leading `quote_char` opens a quoted field.
        quoting: bool,
        eol_char: u8,
    }

    impl<'a> SplitFields<'a> {
        /// Creates a splitter over `slice`.
        ///
        /// Passing `None` for `quote_char` disables quote handling entirely;
        /// the stored quote byte then defaults to `"` but is never consulted.
        pub(crate) fn new(
            slice: &'a [u8],
            separator: u8,
            quote_char: Option<u8>,
            eol_char: u8,
        ) -> Self {
            Self {
                v: slice,
                separator,
                finished: false,
                quote_char: quote_char.unwrap_or(b'"'),
                quoting: quote_char.is_some(),
                eol_char,
            }
        }

        /// Marks the iterator as finished and returns the bytes before `idx`
        /// as the last field.
        ///
        /// # Safety
        /// `idx` must not exceed `self.v.len()`.
        unsafe fn finish_eol(
            &mut self,
            need_escaping: bool,
            idx: usize,
        ) -> Option<(&'a [u8], bool)> {
            self.finished = true;
            debug_assert!(idx <= self.v.len());
            Some((self.v.get_unchecked(..idx), need_escaping))
        }

        /// Marks the iterator as finished and returns everything that is left
        /// as the last field.
        fn finish(&mut self, need_escaping: bool) -> Option<(&'a [u8], bool)> {
            self.finished = true;
            Some((self.v, need_escaping))
        }

        /// Returns `true` when `current_ch` is the separator or the end-of-line byte.
        fn eof_eol(&self, current_ch: u8) -> bool {
            current_ch == self.separator || current_ch == self.eol_char
        }

        /// Returns the index of the byte that terminates the current field,
        /// or `None` when the field extends to the end of the input.
        ///
        /// For a quoted field every `quote_char` toggles the "inside quotes"
        /// state, so terminators are only honoured outside of quotes. An
        /// embedded quote written as a doubled quote toggles twice and leaves
        /// the state unchanged:
        /// e.g. 1997,Ford,E350,"Super, ""luxurious"" truck",20020
        #[inline]
        fn field_end(&self, quoted: bool) -> Option<usize> {
            if !quoted {
                return self.v.iter().position(|&c| self.eof_eol(c));
            }

            // The index is a `usize` taken from `enumerate`, so it cannot
            // overflow regardless of how long the field is.
            let mut in_field = false;
            for (idx, &c) in self.v.iter().enumerate() {
                if c == self.quote_char {
                    // opening quote -> in_field = true, closing quote -> in_field = false
                    in_field = !in_field;
                }
                if !in_field && self.eof_eol(c) {
                    return Some(idx);
                }
            }
            None
        }
    }

    impl<'a> Iterator for SplitFields<'a> {
        // the bool is used to indicate that it requires escaping
        type Item = (&'a [u8], bool);

        /// Returns the next field together with a flag telling whether the
        /// field starts with the quote character (quotes are not stripped).
        #[inline]
        fn next(&mut self) -> Option<(&'a [u8], bool)> {
            if self.finished {
                return None;
            }
            // Nothing left, e.g. after a trailing separator: emit one empty field.
            let first = match self.v.first() {
                Some(&byte) => byte,
                None => return self.finish(false),
            };

            // There can be strings with separators:
            // "Street, City",
            let needs_escaping = self.quoting && first == self.quote_char;

            let pos = match self.field_end(needs_escaping) {
                Some(pos) => pos,
                None => return self.finish(needs_escaping),
            };

            // SAFETY:
            // `field_end` only returns the index of a byte that exists in
            // `self.v`, hence `pos < self.v.len()` and `pos + 1 <= self.v.len()`.
            unsafe {
                debug_assert!(pos < self.v.len());
                if *self.v.get_unchecked(pos) == self.eol_char {
                    // The line ends here; whatever follows is not ours.
                    return self.finish_eol(needs_escaping, pos);
                }
                let field = self.v.get_unchecked(..pos);
                // Skip the separator itself.
                self.v = self.v.get_unchecked(pos + 1..);
                Some((field, needs_escaping))
            }
        }
    }
}
137
#[cfg(feature = "simd")]
mod inner {
    use std::simd::prelude::*;

    use polars_utils::clmul::prefix_xorsum_inclusive;

    /// Number of bytes inspected per vectorized step.
    const SIMD_SIZE: usize = 64;
    type SimdVec = u8x64;

    /// Copies the first `SIMD_SIZE` bytes of `bytes` into a vector.
    ///
    /// # Safety
    /// `bytes` must hold at least `SIMD_SIZE` bytes.
    #[inline]
    unsafe fn load_block(bytes: &[u8]) -> SimdVec {
        let lane: [u8; SIMD_SIZE] = bytes
            .get_unchecked(0..SIMD_SIZE)
            .try_into()
            .unwrap_unchecked();
        SimdVec::from(lane)
    }

    /// Field splitter modelled on `std::iter::Split`.
    ///
    /// A plain split on the separator is not possible because a quoted field
    /// may itself contain separator or end-of-line bytes, which must not
    /// terminate the field.
    pub(crate) struct SplitFields<'a> {
        /// Input that has not been handed out yet.
        pub v: &'a [u8],
        separator: u8,
        /// `true` once the final field has been produced.
        pub finished: bool,
        quote_char: u8,
        /// Whether a leading `quote_char` opens a quoted field.
        quoting: bool,
        eol_char: u8,
        simd_separator: SimdVec,
        simd_eol_char: SimdVec,
        simd_quote_char: SimdVec,
        /// Terminators already located by the last vectorized scan of a
        /// quoted field. Bit `i` refers to `v[i]`; zero means nothing is cached.
        previous_valid_ends: u64,
    }

    impl<'a> SplitFields<'a> {
        /// Builds a splitter over `slice`.
        ///
        /// `quote_char == None` switches quote handling off; the stored quote
        /// byte then falls back to `"`.
        pub(crate) fn new(
            slice: &'a [u8],
            separator: u8,
            quote_char: Option<u8>,
            eol_char: u8,
        ) -> Self {
            let effective_quote = quote_char.unwrap_or(b'"');

            Self {
                v: slice,
                separator,
                finished: false,
                quote_char: effective_quote,
                quoting: quote_char.is_some(),
                eol_char,
                simd_separator: SimdVec::splat(separator),
                simd_eol_char: SimdVec::splat(eol_char),
                simd_quote_char: SimdVec::splat(effective_quote),
                previous_valid_ends: 0,
            }
        }

        /// Ends iteration and hands out the bytes in front of `pos`.
        ///
        /// # Safety
        /// `pos` must not exceed `self.v.len()`.
        unsafe fn finish_eol(
            &mut self,
            need_escaping: bool,
            pos: usize,
        ) -> Option<(&'a [u8], bool)> {
            debug_assert!(pos <= self.v.len());
            self.finished = true;
            Some((self.v.get_unchecked(..pos), need_escaping))
        }

        /// Ends iteration and hands out all remaining bytes.
        #[inline]
        fn finish(&mut self, need_escaping: bool) -> Option<(&'a [u8], bool)> {
            self.finished = true;
            Some((self.v, need_escaping))
        }

        /// Is `current_ch` a field terminator (separator or end-of-line)?
        fn eof_eol(&self, current_ch: u8) -> bool {
            current_ch == self.separator || current_ch == self.eol_char
        }

        /// Produces the field that is terminated by the byte at `pos`.
        ///
        /// An end-of-line terminator finishes the iterator; a separator is
        /// skipped and the remaining input is kept for the next call.
        ///
        /// # Safety
        /// `pos` must be smaller than `self.v.len()`.
        #[inline]
        unsafe fn take_field(
            &mut self,
            pos: usize,
            needs_escaping: bool,
        ) -> Option<(&'a [u8], bool)> {
            debug_assert!(pos < self.v.len());
            // Make sure the iterator is done when EOL.
            if *self.v.get_unchecked(pos) == self.eol_char {
                return self.finish_eol(needs_escaping, pos);
            }

            let field = self.v.get_unchecked(..pos);
            self.v = self.v.get_unchecked(pos + 1..);
            Some((field, needs_escaping))
        }

        /// Locates the terminator of a field that starts with a quote.
        ///
        /// Returns `None` when no unquoted terminator exists in the rest of
        /// the input. When the terminator is found by a vectorized step, the
        /// other unquoted terminators of that same block are stored in
        /// `previous_valid_ends` for the following calls to `next`.
        #[inline]
        fn quoted_field_end(&mut self) -> Option<usize> {
            let mut offset = 0;
            // Quote state carried from one block into the next; the scan
            // starts outside of any quotes.
            let mut outside_quotes = true;

            loop {
                // SAFETY:
                // `offset` only moves past bytes that were inspected, so it
                // never exceeds `self.v.len()`.
                let rest = unsafe { self.v.get_unchecked(offset..) };

                if rest.len() <= SIMD_SIZE {
                    // Scalar tail. Every quote toggles the enclosure state, so
                    // an embedded quote written as a doubled quote toggles twice:
                    // e.g. 1997,Ford,E350,"Super, ""luxurious"" truck",20020
                    let mut in_field = !outside_quotes;
                    for (i, &c) in rest.iter().enumerate() {
                        if c == self.quote_char {
                            in_field = !in_field;
                        }
                        if !in_field && self.eof_eol(c) {
                            let end = offset + i;
                            debug_assert!(
                                self.v[end] == self.eol_char || self.v[end] == self.separator
                            );
                            return Some(end);
                        }
                    }
                    return None;
                }

                // SAFETY:
                // `rest` holds more than `SIMD_SIZE` bytes.
                let block = unsafe { load_block(rest) };
                let eol_hits = block.simd_eq(self.simd_eol_char);
                let sep_hits = block.simd_eq(self.simd_separator);
                let quote_mask = block.simd_eq(self.simd_quote_char).to_bitmask();
                let terminators = (sep_hits | eol_hits).to_bitmask();

                // NOTE(review): presumably bit `i` of the xorsum is the XOR of
                // quote bits `0..=i`; it is flipped when the block was entered
                // outside of quotes so that a set bit means "outside quotes".
                let xorsum = prefix_xorsum_inclusive(quote_mask);
                let outside_mask = if outside_quotes { !xorsum } else { xorsum };
                outside_quotes = (outside_mask & (1 << (SIMD_SIZE - 1))) > 0;

                let end_mask = terminators & outside_mask;
                if end_mask == 0 {
                    offset += SIMD_SIZE;
                    continue;
                }

                let pos = end_mask.trailing_zeros() as usize;
                let end = offset + pos;
                debug_assert!(self.v[end] == self.eol_char || self.v[end] == self.separator);

                // Shifting a u64 by 64 is not allowed, hence the special case
                // for a hit in the last lane.
                self.previous_valid_ends = if pos == SIMD_SIZE - 1 {
                    0
                } else {
                    end_mask >> (pos + 1) as u64
                };
                return Some(end);
            }
        }

        /// Locates the terminator of a field that does not start with a quote.
        ///
        /// Returns `None` when the rest of the input holds no terminator.
        #[inline]
        fn unquoted_field_end(&self) -> Option<usize> {
            let mut offset = 0;

            loop {
                // SAFETY:
                // `offset` only moves past bytes that were inspected, so it
                // never exceeds `self.v.len()`.
                let rest = unsafe { self.v.get_unchecked(offset..) };

                if rest.len() <= SIMD_SIZE {
                    return rest
                        .iter()
                        .position(|&c| self.eof_eol(c))
                        .map(|i| offset + i);
                }

                // SAFETY:
                // `rest` holds more than `SIMD_SIZE` bytes.
                let block = unsafe { load_block(rest) };
                let eol_hits = block.simd_eq(self.simd_eol_char);
                let sep_hits = block.simd_eq(self.simd_separator);
                let hits = (sep_hits | eol_hits).to_bitmask();

                if hits != 0 {
                    return Some(offset + hits.trailing_zeros() as usize);
                }
                offset += SIMD_SIZE;
            }
        }
    }

    impl<'a> Iterator for SplitFields<'a> {
        // the bool is used to indicate that it requires escaping
        type Item = (&'a [u8], bool);

        /// Returns the next field and whether it starts with the quote
        /// character (quotes are not stripped).
        #[inline]
        fn next(&mut self) -> Option<(&'a [u8], bool)> {
            // This must come before the cached value is consulted.
            if self.finished {
                return None;
            }

            // Hot path: a previous vectorized scan already located this
            // field's terminator.
            if self.previous_valid_ends != 0 {
                let pos = self.previous_valid_ends.trailing_zeros() as usize;
                self.previous_valid_ends >>= (pos + 1) as u64;

                let needs_escaping = self.quoting && self.v.first() == Some(&self.quote_char);

                // SAFETY:
                // cached bits refer to bytes of the block that was scanned,
                // all of which lie inside `self.v`.
                return unsafe { self.take_field(pos, needs_escaping) };
            }

            if self.v.is_empty() {
                return self.finish(false);
            }

            // There can be strings with separators:
            // "Street, City",

            // SAFETY:
            // `self.v` was just checked to be non-empty.
            let needs_escaping =
                self.quoting && unsafe { *self.v.get_unchecked(0) } == self.quote_char;

            let found = if needs_escaping {
                self.quoted_field_end()
            } else {
                self.unquoted_field_end()
            };

            match found {
                None => self.finish(needs_escaping),
                // SAFETY:
                // both scans only return indices of bytes that exist in `self.v`.
                Some(pos) => unsafe { self.take_field(pos, needs_escaping) },
            }
        }
    }
}
405
406pub(crate) use inner::SplitFields;
407
#[cfg(test)]
mod test {
    use super::SplitFields;

    /// Quoted fields are returned with their quotes and flagged for escaping;
    /// separators and end-of-line bytes inside quotes do not split the field.
    #[test]
    fn test_splitfields() {
        // Comma separated, both fields quoted.
        let line = b"\"foo\",\"bar\"";
        let got: Vec<(&[u8], bool)> = SplitFields::new(line, b',', Some(b'"'), b'\n').collect();
        let want: Vec<(&[u8], bool)> = vec![(&b"\"foo\""[..], true), (&b"\"bar\""[..], true)];
        assert_eq!(got, want);

        // Semicolon separated, with an end-of-line byte inside the first quoted field.
        let line2 = b"\"foo\n bar\";\"baz\";12345";
        let got2: Vec<(&[u8], bool)> = SplitFields::new(line2, b';', Some(b'"'), b'\n').collect();
        let want2: Vec<(&[u8], bool)> = vec![
            (&b"\"foo\n bar\""[..], true),
            (&b"\"baz\""[..], true),
            (&b"12345"[..], false),
        ];
        assert_eq!(got2, want2);
    }
}