polars_io/csv/read/schema_inference.rs

use std::borrow::Cow;

use polars_core::prelude::*;
#[cfg(feature = "polars-time")]
use polars_time::chunkedarray::string::infer as date_infer;
#[cfg(feature = "polars-time")]
use polars_time::prelude::string::Pattern;
use polars_utils::format_pl_smallstr;

use super::parser::{SplitLines, is_comment_line, skip_bom, skip_line_ending};
use super::splitfields::SplitFields;
use super::{CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues};
use crate::csv::read::parser::skip_lines_naive;
use crate::mmap::ReaderBytes;
use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};

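/// Result of inferring a schema from (a prefix of) a CSV file, together with
/// bookkeeping about how much of the file was read. The row/byte counts are used
/// to extrapolate the total row count, and `n_threads` is carried back into the
/// reader options via `update_with_inference_result`.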
#[derive(Clone, Debug, Default)]
pub struct SchemaInferenceResult {
    inferred_schema: SchemaRef,
    rows_read: usize,
    bytes_read: usize,
    bytes_total: usize,
    n_threads: Option<usize>,
}

impl SchemaInferenceResult {
    pub fn try_from_reader_bytes_and_options(
        reader_bytes: &ReaderBytes,
        options: &CsvReadOptions,
    ) -> PolarsResult<Self> {
        let parse_options = options.get_parse_options();

        let infer_schema_length = options.infer_schema_length;
        let has_header = options.has_header;
        let schema_overwrite_arc = options.schema_overwrite.clone();
        let schema_overwrite = schema_overwrite_arc.as_ref().map(|x| x.as_ref());
        let skip_rows = options.skip_rows;
        let skip_lines = options.skip_lines;
        let skip_rows_after_header = options.skip_rows_after_header;
        let raise_if_empty = options.raise_if_empty;
        let n_threads = options.n_threads;

        let bytes_total = reader_bytes.len();

        let (inferred_schema, rows_read, bytes_read) = infer_file_schema(
            reader_bytes,
            &parse_options,
            infer_schema_length,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_lines,
            skip_rows_after_header,
            raise_if_empty,
        )?;

        let this = Self {
            inferred_schema: Arc::new(inferred_schema),
            rows_read,
            bytes_read,
            bytes_total,
            n_threads,
        };

        Ok(this)
    }

    pub fn with_inferred_schema(mut self, inferred_schema: SchemaRef) -> Self {
        self.inferred_schema = inferred_schema;
        self
    }

    pub fn get_inferred_schema(&self) -> SchemaRef {
        self.inferred_schema.clone()
    }

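    /// Estimate the total number of rows in the file by linearly extrapolating the
    /// rows-per-byte ratio observed during inference over the total byte size,
    /// e.g. 100 rows read from 1_000 of 10_000 bytes gives an estimate of 1_000 rows.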
    pub fn get_estimated_n_rows(&self) -> usize {
        (self.rows_read as f64 / self.bytes_read as f64 * self.bytes_total as f64) as usize
    }
}

impl CsvReadOptions {
    /// Note: This does not update the schema from the inference result.
    pub fn update_with_inference_result(&mut self, si_result: &SchemaInferenceResult) {
        self.n_threads = si_result.n_threads;
    }
}

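/// Resolve the set of candidate data types observed for a column to a single
/// data type. A rough sketch of the widening behaviour (marked `ignore` since the
/// module path is abbreviated here):
///
/// ```ignore
/// let mut possibilities = PlHashSet::default();
/// possibilities.insert(DataType::Int64);
/// possibilities.insert(DataType::Float64);
/// // An Int64/Float64 conflict widens to Float64 ...
/// assert_eq!(finish_infer_field_schema(&possibilities), DataType::Float64);
///
/// // ... while any other conflict (e.g. Boolean vs Int64) falls back to String.
/// possibilities.insert(DataType::Boolean);
/// assert_eq!(finish_infer_field_schema(&possibilities), DataType::String);
/// ```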
pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
    // determine data type based on possible types
    // if there are incompatible types, use DataType::String
    match possibilities.len() {
        1 => possibilities.iter().next().unwrap().clone(),
        2 if possibilities.contains(&DataType::Int64)
            && possibilities.contains(&DataType::Float64) =>
        {
            // we have an integer and a double; fall back to double
            DataType::Float64
        },
        // default to String for conflicting datatypes (e.g. bool and int)
        _ => DataType::String,
    }
}

/// Infer the data type of a record
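///
/// A rough sketch of the expected mappings, assuming the default value regexes
/// accept these literals (marked `ignore` since the module path is abbreviated here):
///
/// ```ignore
/// assert_eq!(infer_field_schema("true", false, false), DataType::Boolean);
/// assert_eq!(infer_field_schema("-7", false, false), DataType::Int64);
/// assert_eq!(infer_field_schema("3.14", false, false), DataType::Float64);
/// // Quoted fields stay String unless date parsing is requested.
/// assert_eq!(infer_field_schema("\"2024-01-01\"", false, false), DataType::String);
/// ```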
pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
    // when quoting is enabled in the reader, these quotes aren't escaped; we default
    // such fields to String (unless date parsing is requested)
    let bytes = string.as_bytes();
    if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
        if try_parse_dates {
            #[cfg(feature = "polars-time")]
            {
                match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
                    Some(pattern_with_offset) => match pattern_with_offset {
                        Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                            DataType::Datetime(TimeUnit::Microseconds, None)
                        },
                        Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                        Pattern::DatetimeYMDZ => {
                            DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                        },
                        Pattern::Time => DataType::Time,
                    },
                    None => DataType::String,
                }
            }
            #[cfg(not(feature = "polars-time"))]
            {
                panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
            }
        } else {
            DataType::String
        }
    }
    // match regex in a particular order
    else if BOOLEAN_RE.is_match(string) {
        DataType::Boolean
    } else if !decimal_comma && FLOAT_RE.is_match(string)
        || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
    {
        DataType::Float64
    } else if INTEGER_RE.is_match(string) {
        DataType::Int64
    } else if try_parse_dates {
        #[cfg(feature = "polars-time")]
        {
            match date_infer::infer_pattern_single(string) {
                Some(pattern_with_offset) => match pattern_with_offset {
                    Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                        DataType::Datetime(TimeUnit::Microseconds, None)
                    },
                    Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                    Pattern::DatetimeYMDZ => {
                        DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                    },
                    Pattern::Time => DataType::Time,
                },
                None => DataType::String,
            }
        }
        #[cfg(not(feature = "polars-time"))]
        {
            panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
        }
    } else {
        DataType::String
    }
}

#[inline]
fn parse_bytes_with_encoding(bytes: &[u8], encoding: CsvEncoding) -> PolarsResult<Cow<str>> {
    Ok(match encoding {
        CsvEncoding::Utf8 => simdutf8::basic::from_utf8(bytes)
            .map_err(|_| polars_err!(ComputeError: "invalid utf-8 sequence"))?
            .into(),
        CsvEncoding::LossyUtf8 => String::from_utf8_lossy(bytes),
    })
}

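/// Default name for the `i`-th column when the file has no header row;
/// columns are numbered from 1 (`column_1`, `column_2`, ...).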
fn column_name(i: usize) -> PlSmallStr {
    format_pl_smallstr!("column_{}", i + 1)
}

#[allow(clippy::too_many_arguments)]
fn infer_file_schema_inner(
    reader_bytes: &ReaderBytes,
    parse_options: &CsvParseOptions,
    max_read_rows: Option<usize>,
    has_header: bool,
    schema_overwrite: Option<&Schema>,
    // we take `mut` because we may need to skip more rows depending
    // on the schema inference
    mut skip_rows: usize,
    skip_rows_after_header: usize,
    recursion_count: u8,
    raise_if_empty: bool,
) -> PolarsResult<(Schema, usize, usize)> {
    // keep track so that we can determine the number of bytes read
    let start_ptr = reader_bytes.as_ptr() as usize;

    // We use lossy utf8 here because we don't want schema inference to fail on invalid
    // utf-8; the actual read may still fail later.
    let encoding = CsvEncoding::LossyUtf8;

    let bytes = skip_line_ending(skip_bom(reader_bytes), parse_options.eol_char);
    if raise_if_empty {
        polars_ensure!(!bytes.is_empty(), NoData: "empty CSV");
    };
    let mut lines = SplitLines::new(
        bytes,
        parse_options.quote_char,
        parse_options.eol_char,
        parse_options.comment_prefix.as_ref(),
    )
    .skip(skip_rows);

    // get or create header names
    // when has_header is false, creates default column names with column_ prefix

    // skip lines that are comments
    let mut first_line = None;

    for (i, line) in (&mut lines).enumerate() {
        if !is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            first_line = Some(line);
            skip_rows += i;
            break;
        }
    }

    if first_line.is_none() {
        first_line = lines.next();
    }

    // now that we've found the first non-comment line, we parse the headers or create default ones
    let mut headers: Vec<PlSmallStr> = if let Some(mut header_line) = first_line {
        let len = header_line.len();
        if len > 1 {
            // remove carriage return
            let trailing_byte = header_line[len - 1];
            if trailing_byte == b'\r' {
                header_line = &header_line[..len - 1];
            }
        }

        let byterecord = SplitFields::new(
            header_line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );
        if has_header {
            let headers = byterecord
                .map(|(slice, needs_escaping)| {
                    let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                        &slice[1..(slice.len() - 1)]
                    } else {
                        slice
                    };
                    let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                    Ok(s)
                })
                .collect::<PolarsResult<Vec<_>>>()?;

            let mut final_headers = Vec::with_capacity(headers.len());

            let mut header_names = PlHashMap::with_capacity(headers.len());

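            // Disambiguate duplicate header names as "<name>_duplicated_<n>",
            // e.g. headers "a", "a", "a" become "a", "a_duplicated_0", "a_duplicated_1".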
            for name in &headers {
                let count = header_names.entry(name.as_ref()).or_insert(0usize);
                if *count != 0 {
                    final_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
                } else {
                    final_headers.push(PlSmallStr::from_str(name))
                }
                *count += 1;
            }
            final_headers
        } else {
            byterecord
                .enumerate()
                .map(|(i, _s)| column_name(i))
                .collect::<Vec<PlSmallStr>>()
        }
    } else if has_header && !bytes.is_empty() && recursion_count == 0 {
        // there was no newline char, so we copy the whole buffer and append one;
        // this is likely cheap as there are no rows.
        let mut buf = Vec::with_capacity(bytes.len() + 2);
        buf.extend_from_slice(bytes);
        buf.push(parse_options.eol_char);

        return infer_file_schema_inner(
            &ReaderBytes::Owned(buf.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    } else if !raise_if_empty {
        return Ok((Schema::default(), 0, 0));
    } else {
        polars_bail!(NoData: "empty CSV");
    };
    if !has_header {
        // re-init lines so that the header is included in type inference.
        lines = SplitLines::new(
            bytes,
            parse_options.quote_char,
            parse_options.eol_char,
            parse_options.comment_prefix.as_ref(),
        )
        .skip(skip_rows);
    }

    // keep track of inferred field types
    let mut column_types: Vec<PlHashSet<DataType>> =
        vec![PlHashSet::with_capacity(4); headers.len()];
    // keep track of columns with nulls
    let mut nulls: Vec<bool> = vec![false; headers.len()];

    let mut rows_count = 0;
    let mut fields = Vec::with_capacity(headers.len());

    // needed to prevent ownership going into the iterator loop
    let records_ref = &mut lines;

    let mut end_ptr = start_ptr;
    for mut line in records_ref
        .take(match max_read_rows {
            Some(max_read_rows) => {
                if max_read_rows <= (usize::MAX - skip_rows_after_header) {
                    // read skip_rows_after_header more rows for inferring
                    // the correct schema as the first skip_rows_after_header
                    // rows will be skipped
                    max_read_rows + skip_rows_after_header
                } else {
                    max_read_rows
                }
            },
            None => usize::MAX,
        })
        .skip(skip_rows_after_header)
    {
        rows_count += 1;
        // keep track so that we can determine the number of bytes read
        end_ptr = line.as_ptr() as usize + line.len();

        if line.is_empty() {
            continue;
        }

        // line is a comment -> skip
        if is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            continue;
        }

        let len = line.len();
        if len > 1 {
            // remove carriage return
            let trailing_byte = line[len - 1];
            if trailing_byte == b'\r' {
                line = &line[..len - 1];
            }
        }

        let record = SplitFields::new(
            line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );

        for (i, (slice, needs_escaping)) in record.enumerate() {
            // When `has_header = false`, grow the schema if this line has more
            // columns than the first line did.
            if i >= headers.len() {
                if !has_header {
                    headers.push(column_name(i));
                    column_types.push(Default::default());
                    nulls.push(false);
                } else {
                    break;
                }
            }

            if slice.is_empty() {
                unsafe { *nulls.get_unchecked_mut(i) = true };
            } else {
                let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                    &slice[1..(slice.len() - 1)]
                } else {
                    slice
                };
                let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
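                // A field only contributes a candidate dtype when it does not match one of
                // the configured null markers (global or per-column).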
                let dtype = match &parse_options.null_values {
                    None => Some(infer_field_schema(
                        &s,
                        parse_options.try_parse_dates,
                        parse_options.decimal_comma,
                    )),
                    Some(NullValues::AllColumns(names)) => {
                        if !names.iter().any(|nv| nv == s.as_ref()) {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::AllColumnsSingle(name)) => {
                        if s.as_ref() != name.as_str() {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::Named(names)) => {
                        // SAFETY:
                        // we iterate over headers length.
                        let current_name = unsafe { headers.get_unchecked(i) };
                        let null_name = &names.iter().find(|name| name.0 == current_name);

                        if let Some(null_name) = null_name {
                            if null_name.1.as_str() != s.as_ref() {
                                Some(infer_field_schema(
                                    &s,
                                    parse_options.try_parse_dates,
                                    parse_options.decimal_comma,
                                ))
                            } else {
                                None
                            }
                        } else {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        }
                    },
                };
                if let Some(dtype) = dtype {
                    unsafe { column_types.get_unchecked_mut(i).insert(dtype) };
                }
            }
        }
    }

    // build schema from inference results
    for i in 0..headers.len() {
        let field_name = &headers[i];

        if let Some(schema_overwrite) = schema_overwrite {
            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
                fields.push(Field::new(name.clone(), dtype.clone()));
                continue;
            }

            // column might have been renamed
            // execute only if schema is complete
            if schema_overwrite.len() == headers.len() {
                if let Some((name, dtype)) = schema_overwrite.get_at_index(i) {
                    fields.push(Field::new(name.clone(), dtype.clone()));
                    continue;
                }
            }
        }

        let possibilities = &column_types[i];
        let dtype = finish_infer_field_schema(possibilities);
        fields.push(Field::new(field_name.clone(), dtype));
    }
    // if there is a single line after the header without an eol,
    // we copy the bytes, add an eol, and rerun this function
    // so that the inference is consistent with and without an eol char
    if rows_count == 0
        && !reader_bytes.is_empty()
        && reader_bytes[reader_bytes.len() - 1] != parse_options.eol_char
        && recursion_count == 0
    {
        let mut rb = Vec::with_capacity(reader_bytes.len() + 1);
        rb.extend_from_slice(reader_bytes);
        rb.push(parse_options.eol_char);
        return infer_file_schema_inner(
            &ReaderBytes::Owned(rb.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    }

    Ok((Schema::from_iter(fields), rows_count, end_ptr - start_ptr))
}

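/// With `decimal_comma` enabled, `,` is interpreted as the decimal point, so it cannot
/// also be used as the field separator; e.g. `check_decimal_comma(true, b';')` is fine,
/// while `check_decimal_comma(true, b',')` raises an error.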
pub(super) fn check_decimal_comma(decimal_comma: bool, separator: u8) -> PolarsResult<()> {
    if decimal_comma {
        polars_ensure!(b',' != separator, InvalidOperation: "'decimal_comma' argument cannot be combined with ',' separator")
    }
    Ok(())
}

/// Infer the schema of a CSV file by reading through the first n rows of the file,
/// with `max_read_rows` controlling the maximum number of rows to read.
///
/// If `max_read_rows` is not set, the whole file is read to infer its schema.
///
/// Returns
///     - the inferred schema
///     - the number of rows used for inference
///     - the number of bytes read
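///
/// A minimal usage sketch, assuming `CsvParseOptions::default()` describes a plain
/// comma-separated file (marked `ignore` since the paths here are abbreviated):
///
/// ```ignore
/// let reader_bytes = ReaderBytes::Borrowed(b"a,b\n1,2.5\n3,4.0\n".as_slice());
/// let (schema, _rows_read, _bytes_read) = infer_file_schema(
///     &reader_bytes,
///     &CsvParseOptions::default(),
///     Some(100), // max_read_rows
///     true,      // has_header
///     None,      // schema_overwrite
///     0,         // skip_rows
///     0,         // skip_lines
///     0,         // skip_rows_after_header
///     true,      // raise_if_empty
/// ).unwrap();
/// assert_eq!(schema.get("a"), Some(&DataType::Int64));
/// assert_eq!(schema.get("b"), Some(&DataType::Float64));
/// ```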
#[allow(clippy::too_many_arguments)]
pub fn infer_file_schema(
    reader_bytes: &ReaderBytes,
    parse_options: &CsvParseOptions,
    max_read_rows: Option<usize>,
    has_header: bool,
    schema_overwrite: Option<&Schema>,
    skip_rows: usize,
    skip_lines: usize,
    skip_rows_after_header: usize,
    raise_if_empty: bool,
) -> PolarsResult<(Schema, usize, usize)> {
    check_decimal_comma(parse_options.decimal_comma, parse_options.separator)?;

    if skip_lines > 0 {
        polars_ensure!(skip_rows == 0, InvalidOperation: "only one of 'skip_rows'/'skip_lines' may be set");
        let bytes = skip_lines_naive(reader_bytes, parse_options.eol_char, skip_lines);
        let reader_bytes = ReaderBytes::Borrowed(bytes);
        infer_file_schema_inner(
            &reader_bytes,
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            0,
            raise_if_empty,
        )
    } else {
        infer_file_schema_inner(
            reader_bytes,
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            0,
            raise_if_empty,
        )
    }
}