polars_io/csv/read/utils.rs

#![allow(unsafe_op_in_unsafe_fn)]
#[cfg(feature = "decompress")]
use std::io::Read;
use std::mem::MaybeUninit;

use super::parser::next_line_position;
#[cfg(feature = "decompress")]
use super::parser::next_line_position_naive;
use super::splitfields::SplitFields;

/// Split `bytes` into roughly `n_chunks` contiguous `(start, end)` byte ranges, each ending on
/// a line boundary; the last range always runs to the end of the input.
///
/// TODO: Remove this in favor of parallel CountLines::analyze_chunk
/// (see https://github.com/pola-rs/polars/issues/19078)
pub(crate) fn get_file_chunks(
    bytes: &[u8],
    n_chunks: usize,
    expected_fields: Option<usize>,
    separator: u8,
    quote_char: Option<u8>,
    eol_char: u8,
) -> Vec<(usize, usize)> {
    let mut last_pos = 0;
    let total_len = bytes.len();
    let chunk_size = total_len / n_chunks;
    let mut offsets = Vec::with_capacity(n_chunks);
    for _ in 0..n_chunks {
        let search_pos = last_pos + chunk_size;

        if search_pos >= bytes.len() {
            break;
        }

        let end_pos = match next_line_position(
            &bytes[search_pos..],
            expected_fields,
            separator,
            quote_char,
            eol_char,
        ) {
            Some(pos) => search_pos + pos,
            None => {
                break;
            },
        };
        offsets.push((last_pos, end_pos));
        last_pos = end_pos;
    }
    offsets.push((last_pos, total_len));
    offsets
}

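// A small inline sketch (test-only) of the contract above: the ranges returned by
// `get_file_chunks` are contiguous and cover the input from the first byte to the last.
// The CSV content is made up for this check.
#[cfg(test)]
#[test]
fn get_file_chunks_ranges_cover_input() {
    let bytes = b"a,b\n1,2\n3,4\n5,6\n7,8\n";
    let chunks = get_file_chunks(bytes, 2, Some(2), b',', None, b'\n');
    assert_eq!(chunks.first().unwrap().0, 0);
    assert_eq!(chunks.last().unwrap().1, bytes.len());
    for w in chunks.windows(2) {
        // each chunk ends exactly where the next one starts
        assert_eq!(w[0].1, w[1].0);
    }
}
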
#[cfg(feature = "decompress")]
fn decompress_impl<R: Read>(
    decoder: &mut R,
    n_rows: Option<usize>,
    separator: u8,
    quote_char: Option<u8>,
    eol_char: u8,
) -> Option<Vec<u8>> {
    let chunk_size = 4096;
    Some(match n_rows {
        None => {
            // decompression in a preallocated buffer does not work with zlib-ng
            // and will put the original compressed data in the buffer.
            let mut out = Vec::new();
            decoder.read_to_end(&mut out).ok()?;
            out
        },
        Some(n_rows) => {
            // we only need the first n_rows rows, so decompress incrementally and stop
            // once we have seen enough line endings (eol_char)
            let mut out = vec![];
            let mut expected_fields = 0;
            // make sure that we have enough bytes to decode the header (even if it has embedded new line chars)
            // those extra bytes in the buffer don't matter, we don't need to track them
            loop {
                let read = decoder.take(chunk_size).read_to_end(&mut out).ok()?;
                if read == 0 {
                    break;
                }
                if next_line_position_naive(&out, eol_char).is_some() {
                    // an extra shot
                    let read = decoder.take(chunk_size).read_to_end(&mut out).ok()?;
                    if read == 0 {
                        break;
                    }
                    // now that we have enough, we compute the number of fields (also takes embedding into account)
                    expected_fields =
                        SplitFields::new(&out, separator, quote_char, eol_char).count();
                    break;
                }
            }

            let mut line_count = 0;
            let mut buf_pos = 0;
            // keep decoding bytes and count lines
            // keep track of the n_rows we read
            while line_count < n_rows {
                match next_line_position(
                    &out[buf_pos + 1..],
                    Some(expected_fields),
                    separator,
                    quote_char,
                    eol_char,
                ) {
                    Some(pos) => {
                        line_count += 1;
                        buf_pos += pos;
                    },
                    None => {
                        // take more bytes so that we might find a new line the next iteration
                        let read = decoder.take(chunk_size).read_to_end(&mut out).ok()?;
                        // we depleted the reader
                        if read == 0 {
                            break;
                        }
                        continue;
                    },
                };
            }
            if line_count == n_rows {
                out.truncate(buf_pos); // retain only the first n_rows in out
            }
            out
        },
    })
}

#[cfg(feature = "decompress")]
pub(crate) fn decompress(
    bytes: &[u8],
    n_rows: Option<usize>,
    separator: u8,
    quote_char: Option<u8>,
    eol_char: u8,
) -> Option<Vec<u8>> {
    use crate::utils::compression::SupportedCompression;

    if let Some(algo) = SupportedCompression::check(bytes) {
        match algo {
            SupportedCompression::GZIP => {
                let mut decoder = flate2::read::MultiGzDecoder::new(bytes);
                decompress_impl(&mut decoder, n_rows, separator, quote_char, eol_char)
            },
            SupportedCompression::ZLIB => {
                let mut decoder = flate2::read::ZlibDecoder::new(bytes);
                decompress_impl(&mut decoder, n_rows, separator, quote_char, eol_char)
            },
            SupportedCompression::ZSTD => {
                let mut decoder = zstd::Decoder::with_buffer(bytes).ok()?;
                decompress_impl(&mut decoder, n_rows, separator, quote_char, eol_char)
            },
        }
    } else {
        None
    }
}

/// Replace double quotes by single ones.
///
/// This function assumes that `bytes` is wrapped in the quoting character.
///
/// # Safety
///
/// The caller must ensure that:
///     - the output buffer `buf` is at least `bytes.len()` long
///     - `bytes` ends with the quote character, e.g. `"`
///     - `bytes.len() > 1`
pub(super) unsafe fn escape_field(bytes: &[u8], quote: u8, buf: &mut [MaybeUninit<u8>]) -> usize {
    debug_assert!(bytes.len() > 1);
    let mut prev_quote = false;

    let mut count = 0;
    for c in bytes.get_unchecked(1..bytes.len() - 1) {
        if *c == quote {
            if prev_quote {
                prev_quote = false;
                buf.get_unchecked_mut(count).write(*c);
                count += 1;
            } else {
                prev_quote = true;
            }
        } else {
            prev_quote = false;
            buf.get_unchecked_mut(count).write(*c);
            count += 1;
        }
    }
    count
}

#[cfg(test)]
mod test {
    use super::get_file_chunks;

    #[test]
    fn test_get_file_chunks() {
        let path = "../../examples/datasets/foods1.csv";
        let s = std::fs::read_to_string(path).unwrap();
        let bytes = s.as_bytes();
        // can be within -1 / +1 bounds.
        assert!(
            (get_file_chunks(bytes, 10, Some(4), b',', None, b'\n').len() as i32 - 10).abs() <= 1
        );
        assert!(
            (get_file_chunks(bytes, 8, Some(4), b',', None, b'\n').len() as i32 - 8).abs() <= 1
        );
    }
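
    // A minimal sketch exercising `escape_field`: the input is the full quoted field
    // (surrounding quotes included) and doubled quote characters collapse into one.
    // The example field below is made up for this test.
    #[test]
    fn test_escape_field_collapses_doubled_quotes() {
        use std::mem::MaybeUninit;

        use super::escape_field;

        let field = br#""he said ""hi""""#;
        let mut buf = vec![MaybeUninit::<u8>::uninit(); field.len()];
        // SAFETY: `buf` is `field.len()` long, `field` is wrapped in the quote character
        // and is longer than one byte.
        let written = unsafe { escape_field(field, b'"', &mut buf) };
        let out = unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const u8, written) };
        assert_eq!(out, &b"he said \"hi\""[..]);
    }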
}
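
// A minimal sketch of the `decompress` round trip, kept behind the same feature gate as the
// function itself: gzip a small, made-up CSV in memory with flate2 (already pulled in by the
// "decompress" feature) and check that decompressing without a row limit returns the
// original bytes.
#[cfg(all(test, feature = "decompress"))]
mod test_decompress {
    use std::io::Write;

    use super::decompress;

    #[test]
    fn test_decompress_gzip_roundtrip() {
        let csv = b"a,b\n1,2\n3,4\n";
        let mut encoder =
            flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        encoder.write_all(csv).unwrap();
        let compressed = encoder.finish().unwrap();

        let decompressed = decompress(&compressed, None, b',', Some(b'"'), b'\n').unwrap();
        assert_eq!(decompressed, csv);
    }
}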