polars_io/csv/read/
utils.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2#[cfg(feature = "decompress")]
3use std::io::Read;
4use std::mem::MaybeUninit;
5
6use super::parser::next_line_position;
7#[cfg(feature = "decompress")]
8use super::parser::next_line_position_naive;
9use super::splitfields::SplitFields;
10
/// Reads decompressed CSV bytes from `decoder`, optionally stopping after `n_rows` lines.
///
/// * `n_rows == None`: the whole stream is read to the end and returned.
/// * `n_rows == Some(n)`: the stream is pulled in chunks of `chunk_size` bytes until
///   `next_line_position` has located `n` line boundaries or the decoder is depleted.
///   If `n` boundaries were found, the output is truncated at the last one; otherwise
///   everything that was decoded is returned.
///
/// `separator`, `quote_char` and `eol_char` are forwarded to the line/field splitters.
///
/// Returns `None` if any read from `decoder` fails.
#[cfg(feature = "decompress")]
fn decompress_impl<R: Read>(
    decoder: &mut R,
    n_rows: Option<usize>,
    separator: u8,
    quote_char: Option<u8>,
    eol_char: u8,
) -> Option<Vec<u8>> {
    let chunk_size = 4096;
    Some(match n_rows {
        None => {
            // decompression in a preallocated buffer does not work with zlib-ng
            // and will put the original compressed data in the buffer.
            let mut out = Vec::new();
            decoder.read_to_end(&mut out).ok()?;
            out
        },
        Some(n_rows) => {
            // Only the first `n_rows` lines are needed, so decode incrementally.
            let mut out = vec![];
            let mut expected_fields = 0;
            // make sure that we have enough bytes to decode the header (even if it has embedded new line chars)
            // those extra bytes in the buffer don't matter, we don't need to track them
            loop {
                let read = decoder.take(chunk_size).read_to_end(&mut out).ok()?;
                if read == 0 {
                    break;
                }
                if next_line_position_naive(&out, eol_char).is_some() {
                    // an extra shot
                    let read = decoder.take(chunk_size).read_to_end(&mut out).ok()?;
                    // NOTE(review): when this extra read yields nothing (the whole input
                    // fit in the first chunk) `expected_fields` stays 0 - confirm that
                    // `next_line_position` handles that as intended.
                    if read == 0 {
                        break;
                    }
                    // now that we have enough, we compute the number of fields (also takes embedding into account)
                    expected_fields =
                        SplitFields::new(&out, separator, quote_char, eol_char).count();
                    break;
                }
            }

            let mut line_count = 0;
            let mut buf_pos = 0;
            // keep decoding bytes and count lines
            // keep track of the n_rows we read
            while line_count < n_rows {
                // Use `get` rather than indexing: if the decoder produced no bytes at all,
                // `out` is empty and `out[buf_pos + 1..]` would panic. An empty slice is
                // passed on instead, exactly as already happens for a one-byte buffer.
                let remaining = out.get(buf_pos + 1..).unwrap_or_default();
                match next_line_position(
                    remaining,
                    Some(expected_fields),
                    separator,
                    quote_char,
                    eol_char,
                ) {
                    Some(pos) => {
                        line_count += 1;
                        buf_pos += pos;
                    },
                    None => {
                        // take more bytes so that we might find a new line the next iteration
                        let read = decoder.take(chunk_size).read_to_end(&mut out).ok()?;
                        // we depleted the reader
                        if read == 0 {
                            break;
                        }
                    },
                };
            }
            if line_count == n_rows {
                out.truncate(buf_pos); // retain only first n_rows in out
            }
            out
        },
    })
}
86
/// Decompresses `bytes` when they are in a recognised compression format.
///
/// The format is detected with `SupportedCompression::check`. Gzip, zlib and zstd
/// input is wrapped in the matching decoder and handed to `decompress_impl` together
/// with the CSV options (`n_rows`, `separator`, `quote_char`, `eol_char`).
///
/// Returns `None` when no supported compression is detected, when the zstd decoder
/// cannot be constructed, or when `decompress_impl` itself returns `None`.
#[cfg(feature = "decompress")]
pub(crate) fn decompress(
    bytes: &[u8],
    n_rows: Option<usize>,
    separator: u8,
    quote_char: Option<u8>,
    eol_char: u8,
) -> Option<Vec<u8>> {
    use crate::utils::compression::SupportedCompression;

    // Bail out early for input that is not compressed with a supported algorithm.
    let algorithm = SupportedCompression::check(bytes)?;

    match algorithm {
        SupportedCompression::GZIP => {
            let mut gz = flate2::read::MultiGzDecoder::new(bytes);
            decompress_impl(&mut gz, n_rows, separator, quote_char, eol_char)
        },
        SupportedCompression::ZLIB => {
            let mut zlib = flate2::read::ZlibDecoder::new(bytes);
            decompress_impl(&mut zlib, n_rows, separator, quote_char, eol_char)
        },
        SupportedCompression::ZSTD => {
            // Constructing the zstd decoder can fail; that is reported as `None`.
            let mut zstd_reader = zstd::Decoder::with_buffer(bytes).ok()?;
            decompress_impl(&mut zstd_reader, n_rows, separator, quote_char, eol_char)
        },
    }
}
116
117/// replace double quotes by single ones
118///
119/// This function assumes that bytes is wrapped in the quoting character.
120///
121/// # Safety
122///
123/// The caller must ensure that:
124///     - Output buffer must have enough capacity to hold `bytes.len()`
125///     - bytes ends with the quote character e.g.: `"`
126///     - bytes length > 1.
127pub(super) unsafe fn escape_field(bytes: &[u8], quote: u8, buf: &mut [MaybeUninit<u8>]) -> usize {
128    debug_assert!(bytes.len() > 1);
129    let mut prev_quote = false;
130
131    let mut count = 0;
132    for c in bytes.get_unchecked(1..bytes.len() - 1) {
133        if *c == quote {
134            if prev_quote {
135                prev_quote = false;
136                buf.get_unchecked_mut(count).write(*c);
137                count += 1;
138            } else {
139                prev_quote = true;
140            }
141        } else {
142            prev_quote = false;
143            buf.get_unchecked_mut(count).write(*c);
144            count += 1;
145        }
146    }
147    count
148}