Skip to main content

polars_io/csv/read/
utils.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2#[cfg(feature = "decompress")]
3use std::io::Read;
4use std::mem::MaybeUninit;
5
6use super::parser::next_line_position;
7#[cfg(feature = "decompress")]
8use super::parser::next_line_position_naive;
9use super::splitfields::SplitFields;
10
/// Decompress the bytes produced by `decoder`.
///
/// * `n_rows == None`: the entire stream is decompressed and returned.
/// * `n_rows == Some(n)`: data is decoded in chunks of `chunk_size` bytes only
///   until `n` line positions have been located with `next_line_position`.
///   Candidate line breaks are validated against the field count of the first
///   buffered line. Once `n` positions are found, the output is truncated at
///   the accumulated offset `buf_pos`. If the stream runs out first, everything
///   that was decoded is returned untruncated.
///
/// Returns `None` if reading from `decoder` fails.
#[cfg(feature = "decompress")]
fn decompress_impl<R: Read>(
    decoder: &mut R,
    n_rows: Option<usize>,
    separator: u8,
    quote_char: Option<u8>,
    eol_char: u8,
) -> Option<Vec<u8>> {
    let chunk_size = 4096;
    Some(match n_rows {
        None => {
            // decompression in a preallocated buffer does not work with zlib-ng
            // and will put the original compressed data in the buffer.
            let mut out = Vec::new();
            decoder.read_to_end(&mut out).ok()?;
            out
        },
        Some(n_rows) => {
            // Only decode as much of the stream as is needed to cover the first
            // `n_rows` lines.
            let mut out = vec![];
            let mut expected_fields = 0;
            // make sure that we have enough bytes to decode the header (even if it has embedded new line chars)
            // those extra bytes in the buffer don't matter, we don't need to track them
            loop {
                let read = decoder.take(chunk_size).read_to_end(&mut out).ok()?;
                if read == 0 {
                    break;
                }
                if next_line_position_naive(&out, eol_char).is_some() {
                    // An extra chunk, so that a header containing embedded eol chars is
                    // fully buffered. Reaching EOF here is not a reason to skip the field
                    // count: the buffer then simply holds the complete stream.
                    decoder.take(chunk_size).read_to_end(&mut out).ok()?;
                    // Compute the number of fields (also takes embedding into account).
                    // Previously this was skipped when the extra read hit EOF, which
                    // left `expected_fields == 0` even though a full line was buffered.
                    expected_fields =
                        SplitFields::new(&out, separator, quote_char, eol_char).count();
                    break;
                }
            }

            let mut line_count = 0;
            let mut buf_pos = 0;
            // keep decoding bytes and count lines
            // keep track of the n_rows we read
            while line_count < n_rows {
                // `out.get(..)` instead of indexing: when nothing was decoded (empty
                // stream) `buf_pos + 1` is past the end and slicing would panic. Treat
                // that like "no line found" so we try to read more and stop at EOF.
                let next = out.get(buf_pos + 1..).and_then(|rest| {
                    next_line_position(rest, Some(expected_fields), separator, quote_char, eol_char)
                });
                match next {
                    Some(pos) => {
                        line_count += 1;
                        buf_pos += pos;
                    },
                    None => {
                        // take more bytes so that we might find a new line the next iteration
                        let read = decoder.take(chunk_size).read_to_end(&mut out).ok()?;
                        // we depleted the reader
                        if read == 0 {
                            break;
                        }
                    },
                }
            }
            if line_count == n_rows {
                out.truncate(buf_pos); // retain only first n_rows in out
            }
            out
        },
    })
}
86
/// Detect the compression format of `bytes` and decompress it.
///
/// The format is determined by `SupportedCompression::check`; if that does not
/// recognise the input, `None` is returned. The matching decoder is then handed
/// to `decompress_impl`, which honours `n_rows` and the CSV dialect
/// (`separator`, `quote_char`, `eol_char`) when `n_rows` limits the output.
///
/// Returns `None` when the format is unrecognised, when the zstd decoder cannot
/// be created, or when decoding fails.
#[cfg(feature = "decompress")]
pub(crate) fn decompress(
    bytes: &[u8],
    n_rows: Option<usize>,
    separator: u8,
    quote_char: Option<u8>,
    eol_char: u8,
) -> Option<Vec<u8>> {
    use crate::utils::compression::SupportedCompression;

    match SupportedCompression::check(bytes)? {
        SupportedCompression::GZIP => decompress_impl(
            &mut flate2::read::MultiGzDecoder::new(bytes),
            n_rows,
            separator,
            quote_char,
            eol_char,
        ),
        SupportedCompression::ZLIB => decompress_impl(
            &mut flate2::read::ZlibDecoder::new(bytes),
            n_rows,
            separator,
            quote_char,
            eol_char,
        ),
        SupportedCompression::ZSTD => decompress_impl(
            &mut zstd::Decoder::with_buffer(bytes).ok()?,
            n_rows,
            separator,
            quote_char,
            eol_char,
        ),
    }
}
114
115/// replace double quotes by single ones
116///
117/// This function assumes that bytes is wrapped in the quoting character.
118///
119/// # Safety
120///
121/// The caller must ensure that:
122///     - Output buffer must have enough capacity to hold `bytes.len()`
123///     - bytes ends with the quote character e.g.: `"`
124///     - bytes length > 1.
125pub(super) unsafe fn escape_field(bytes: &[u8], quote: u8, buf: &mut [MaybeUninit<u8>]) -> usize {
126    debug_assert!(bytes.len() > 1);
127    let mut prev_quote = false;
128
129    let mut count = 0;
130    for c in bytes.get_unchecked(1..bytes.len() - 1) {
131        if *c == quote {
132            if prev_quote {
133                prev_quote = false;
134                buf.get_unchecked_mut(count).write(*c);
135                count += 1;
136            } else {
137                prev_quote = true;
138            }
139        } else {
140            prev_quote = false;
141            buf.get_unchecked_mut(count).write(*c);
142            count += 1;
143        }
144    }
145    count
146}