polars_io/csv/read/schema_inference.rs

use std::borrow::Cow;

use polars_core::config::verbose;
use polars_core::prelude::*;
#[cfg(feature = "polars-time")]
use polars_time::chunkedarray::string::infer as date_infer;
#[cfg(feature = "polars-time")]
use polars_time::prelude::string::Pattern;
use polars_utils::format_pl_smallstr;

use super::parser::{SplitLines, is_comment_line, skip_bom, skip_line_ending};
use super::splitfields::SplitFields;
use super::{CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues};
use crate::csv::read::parser::skip_lines_naive;
use crate::mmap::ReaderBytes;
use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};

/// Result of schema inference over (a prefix of) a CSV file.
#[derive(Clone, Debug, Default)]
pub struct SchemaInferenceResult {
    inferred_schema: SchemaRef,
    rows_read: usize,
    bytes_read: usize,
    bytes_total: usize,
    n_threads: Option<usize>,
}

impl SchemaInferenceResult {
    pub fn try_from_reader_bytes_and_options(
        reader_bytes: &ReaderBytes,
        options: &CsvReadOptions,
    ) -> PolarsResult<Self> {
        let parse_options = options.get_parse_options();

        let infer_schema_length = options.infer_schema_length;
        let has_header = options.has_header;
        let schema_overwrite_arc = options.schema_overwrite.clone();
        let schema_overwrite = schema_overwrite_arc.as_ref().map(|x| x.as_ref());
        let skip_rows = options.skip_rows;
        let skip_lines = options.skip_lines;
        let skip_rows_after_header = options.skip_rows_after_header;
        let raise_if_empty = options.raise_if_empty;
        let mut n_threads = options.n_threads;

        let bytes_total = reader_bytes.len();

        let (inferred_schema, rows_read, bytes_read) = infer_file_schema(
            reader_bytes,
            &parse_options,
            infer_schema_length,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_lines,
            skip_rows_after_header,
            raise_if_empty,
            &mut n_threads,
        )?;

        let this = Self {
            inferred_schema: Arc::new(inferred_schema),
            rows_read,
            bytes_read,
            bytes_total,
            n_threads,
        };

        Ok(this)
    }

    pub fn with_inferred_schema(mut self, inferred_schema: SchemaRef) -> Self {
        self.inferred_schema = inferred_schema;
        self
    }

    pub fn get_inferred_schema(&self) -> SchemaRef {
        self.inferred_schema.clone()
    }

    /// Estimate the total row count by linearly extrapolating the rows read
    /// during inference over the total byte size of the file.
    pub fn get_estimated_n_rows(&self) -> usize {
        (self.rows_read as f64 / self.bytes_read as f64 * self.bytes_total as f64) as usize
    }
}
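
// Illustrative sketch (not part of the original file): `get_estimated_n_rows`
// linearly extrapolates the rows seen during inference over the full file
// size. The struct is constructed directly here, which works because the test
// module is a child of this module and can access private fields.
#[cfg(test)]
mod estimated_n_rows_tests {
    use super::*;

    #[test]
    fn linear_extrapolation() {
        let result = SchemaInferenceResult {
            inferred_schema: Arc::new(Schema::default()),
            rows_read: 100,
            bytes_read: 1_000,
            bytes_total: 10_000,
            n_threads: None,
        };
        // 100 rows per 1_000 bytes, extrapolated to 10_000 bytes -> 1_000 rows.
        assert_eq!(result.get_estimated_n_rows(), 1_000);
    }
}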

impl CsvReadOptions {
    /// Note: This does not update the schema from the inference result.
    pub fn update_with_inference_result(&mut self, si_result: &SchemaInferenceResult) {
        self.n_threads = si_result.n_threads;
    }
}

pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
    // determine the data type based on the possible types;
    // if there are incompatible types, use DataType::String
    match possibilities.len() {
        1 => possibilities.iter().next().unwrap().clone(),
        2 if possibilities.contains(&DataType::Int64)
            && possibilities.contains(&DataType::Float64) =>
        {
            // we have an integer and a double; fall back to double
            DataType::Float64
        },
        // default to String for conflicting datatypes (e.g. bool and int)
        _ => DataType::String,
    }
}
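
// Illustrative sketch (not part of the original file) of how conflicting type
// candidates collapse: a lone Int64/Float64 pair widens to Float64, anything
// else falls back to String. Assumes `PlHashSet` supports `FromIterator` like
// the std `HashSet` does.
#[cfg(test)]
mod finish_infer_field_schema_tests {
    use super::*;

    #[test]
    fn widening_and_fallback() {
        let int_and_float: PlHashSet<DataType> =
            [DataType::Int64, DataType::Float64].into_iter().collect();
        assert_eq!(finish_infer_field_schema(&int_and_float), DataType::Float64);

        // conflicting types, e.g. bool and int, resolve to String
        let conflicting: PlHashSet<DataType> =
            [DataType::Boolean, DataType::Int64].into_iter().collect();
        assert_eq!(finish_infer_field_schema(&conflicting), DataType::String);
    }
}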

/// Infer the data type of a single CSV field.
pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
    // When quoting is enabled in the reader, these quotes aren't stripped yet.
    // We try date parsing on the inner value and otherwise default to String.
    if string.starts_with('"') {
        if try_parse_dates {
            #[cfg(feature = "polars-time")]
            {
                match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
                    Some(pattern_with_offset) => match pattern_with_offset {
                        Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                            DataType::Datetime(TimeUnit::Microseconds, None)
                        },
                        Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                        Pattern::DatetimeYMDZ => DataType::Datetime(
                            TimeUnit::Microseconds,
                            Some(PlSmallStr::from_static("UTC")),
                        ),
                        Pattern::Time => DataType::Time,
                    },
                    None => DataType::String,
                }
            }
            #[cfg(not(feature = "polars-time"))]
            {
                panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
            }
        } else {
            DataType::String
        }
    }
    // match regexes in a particular order
    else if BOOLEAN_RE.is_match(string) {
        DataType::Boolean
    } else if !decimal_comma && FLOAT_RE.is_match(string)
        || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
    {
        DataType::Float64
    } else if INTEGER_RE.is_match(string) {
        DataType::Int64
    } else if try_parse_dates {
        #[cfg(feature = "polars-time")]
        {
            match date_infer::infer_pattern_single(string) {
                Some(pattern_with_offset) => match pattern_with_offset {
                    Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                        DataType::Datetime(TimeUnit::Microseconds, None)
                    },
                    Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                    Pattern::DatetimeYMDZ => DataType::Datetime(
                        TimeUnit::Microseconds,
                        Some(PlSmallStr::from_static("UTC")),
                    ),
                    Pattern::Time => DataType::Time,
                },
                None => DataType::String,
            }
        }
        #[cfg(not(feature = "polars-time"))]
        {
            panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
        }
    } else {
        DataType::String
    }
}
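
// Illustrative sketch (not part of the original file) of the regex-based
// branches of `infer_field_schema`; the date branches are feature-gated
// behind `polars-time` and not exercised here. Assumes the regexes in
// `crate::utils` cover these common literals.
#[cfg(test)]
mod infer_field_schema_tests {
    use super::*;

    #[test]
    fn regex_based_inference() {
        assert_eq!(infer_field_schema("true", false, false), DataType::Boolean);
        assert_eq!(infer_field_schema("-42", false, false), DataType::Int64);
        assert_eq!(infer_field_schema("3.14", false, false), DataType::Float64);
        // with `decimal_comma` enabled, a comma acts as the decimal separator
        assert_eq!(infer_field_schema("3,14", false, true), DataType::Float64);
        assert_eq!(infer_field_schema("hello", false, false), DataType::String);
    }
}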

#[inline]
fn parse_bytes_with_encoding(bytes: &[u8], encoding: CsvEncoding) -> PolarsResult<Cow<str>> {
    Ok(match encoding {
        CsvEncoding::Utf8 => simdutf8::basic::from_utf8(bytes)
            .map_err(|_| polars_err!(ComputeError: "invalid utf-8 sequence"))?
            .into(),
        CsvEncoding::LossyUtf8 => String::from_utf8_lossy(bytes),
    })
}

#[allow(clippy::too_many_arguments)]
fn infer_file_schema_inner(
    reader_bytes: &ReaderBytes,
    parse_options: &CsvParseOptions,
    max_read_rows: Option<usize>,
    has_header: bool,
    schema_overwrite: Option<&Schema>,
    // a `mut` binding, because we may need to skip more rows depending
    // on the schema inference
    mut skip_rows: usize,
    skip_rows_after_header: usize,
    recursion_count: u8,
    raise_if_empty: bool,
    n_threads: &mut Option<usize>,
) -> PolarsResult<(Schema, usize, usize)> {
    // keep track so that we can determine the number of bytes read
    let start_ptr = reader_bytes.as_ptr() as usize;

    // We use lossy utf8 here because we don't want the schema inference to fail on utf8.
    // It may fail later.
    let encoding = CsvEncoding::LossyUtf8;

    let bytes = skip_line_ending(skip_bom(reader_bytes), parse_options.eol_char);
    if raise_if_empty {
        polars_ensure!(!bytes.is_empty(), NoData: "empty CSV");
    };
    let mut lines = SplitLines::new(
        bytes,
        parse_options.quote_char,
        parse_options.eol_char,
        parse_options.comment_prefix.as_ref(),
    )
    .skip(skip_rows);

    // get or create header names; when `has_header` is false, default column
    // names with a `column_` prefix are created

    // skip lines that are comments
    let mut first_line = None;

    for (i, line) in (&mut lines).enumerate() {
        if !is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            first_line = Some(line);
            skip_rows += i;
            break;
        }
    }

    if first_line.is_none() {
        first_line = lines.next();
    }

    // now that we've found the first non-comment line, we parse the headers or create default ones
    let headers: Vec<PlSmallStr> = if let Some(mut header_line) = first_line {
        let len = header_line.len();
        if len > 1 {
            // remove carriage return
            let trailing_byte = header_line[len - 1];
            if trailing_byte == b'\r' {
                header_line = &header_line[..len - 1];
            }
        }

        let byterecord = SplitFields::new(
            header_line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );
        if has_header {
            let headers = byterecord
                .map(|(slice, needs_escaping)| {
                    let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                        &slice[1..(slice.len() - 1)]
                    } else {
                        slice
                    };
                    let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                    Ok(s)
                })
                .collect::<PolarsResult<Vec<_>>>()?;

            let mut final_headers = Vec::with_capacity(headers.len());

            let mut header_names = PlHashMap::with_capacity(headers.len());

            for name in &headers {
                let count = header_names.entry(name.as_ref()).or_insert(0usize);
                if *count != 0 {
                    final_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
                } else {
                    final_headers.push(PlSmallStr::from_str(name))
                }
                *count += 1;
            }
            final_headers
        } else {
            byterecord
                .enumerate()
                .map(|(i, _s)| format_pl_smallstr!("column_{}", i + 1))
                .collect::<Vec<PlSmallStr>>()
        }
    } else if has_header && !bytes.is_empty() && recursion_count == 0 {
        // There was no newline char, so we copy the whole buffer and add one;
        // this is likely to be cheap as there are no rows.
        let mut buf = Vec::with_capacity(bytes.len() + 2);
        buf.extend_from_slice(bytes);
        buf.push(parse_options.eol_char);

        return infer_file_schema_inner(
            &ReaderBytes::Owned(buf.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
            n_threads,
        );
    } else if !raise_if_empty {
        return Ok((Schema::default(), 0, 0));
    } else {
        polars_bail!(NoData: "empty CSV");
    };
    if !has_header {
        // re-init lines so that the header is included in type inference
        lines = SplitLines::new(
            bytes,
            parse_options.quote_char,
            parse_options.eol_char,
            parse_options.comment_prefix.as_ref(),
        )
        .skip(skip_rows);
    }

    let header_length = headers.len();
    // keep track of inferred field types
    let mut column_types: Vec<PlHashSet<DataType>> =
        vec![PlHashSet::with_capacity(4); header_length];
    // keep track of columns with nulls
    let mut nulls: Vec<bool> = vec![false; header_length];

    let mut rows_count = 0;
    let mut fields = Vec::with_capacity(header_length);

    // needed to prevent ownership from moving into the iterator loop
    let records_ref = &mut lines;

    let mut end_ptr = start_ptr;
    for mut line in records_ref
        .take(match max_read_rows {
            Some(max_read_rows) => {
                if max_read_rows <= (usize::MAX - skip_rows_after_header) {
                    // read `skip_rows_after_header` more rows for inferring
                    // the correct schema, as the first `skip_rows_after_header`
                    // rows will be skipped
                    max_read_rows + skip_rows_after_header
                } else {
                    max_read_rows
                }
            },
            None => usize::MAX,
        })
        .skip(skip_rows_after_header)
    {
        rows_count += 1;
        // keep track so that we can determine the number of bytes read
        end_ptr = line.as_ptr() as usize + line.len();

        if line.is_empty() {
            continue;
        }

        // line is a comment -> skip
        if is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            continue;
        }

        let len = line.len();
        if len > 1 {
            // remove carriage return
            let trailing_byte = line[len - 1];
            if trailing_byte == b'\r' {
                line = &line[..len - 1];
            }
        }

        let mut record = SplitFields::new(
            line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );

        for i in 0..header_length {
            if let Some((slice, needs_escaping)) = record.next() {
                if slice.is_empty() {
                    unsafe { *nulls.get_unchecked_mut(i) = true };
                } else {
                    let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                        &slice[1..(slice.len() - 1)]
                    } else {
                        slice
                    };
                    let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                    let dtype = match &parse_options.null_values {
                        None => Some(infer_field_schema(
                            &s,
                            parse_options.try_parse_dates,
                            parse_options.decimal_comma,
                        )),
                        Some(NullValues::AllColumns(names)) => {
                            if !names.iter().any(|nv| nv == s.as_ref()) {
                                Some(infer_field_schema(
                                    &s,
                                    parse_options.try_parse_dates,
                                    parse_options.decimal_comma,
                                ))
                            } else {
                                None
                            }
                        },
                        Some(NullValues::AllColumnsSingle(name)) => {
                            if s.as_ref() != name.as_str() {
                                Some(infer_field_schema(
                                    &s,
                                    parse_options.try_parse_dates,
                                    parse_options.decimal_comma,
                                ))
                            } else {
                                None
                            }
                        },
                        Some(NullValues::Named(names)) => {
                            // SAFETY:
                            // we iterate over the headers' length.
                            let current_name = unsafe { headers.get_unchecked(i) };
                            let null_name = &names.iter().find(|name| name.0 == current_name);

                            if let Some(null_name) = null_name {
                                if null_name.1.as_str() != s.as_ref() {
                                    Some(infer_field_schema(
                                        &s,
                                        parse_options.try_parse_dates,
                                        parse_options.decimal_comma,
                                    ))
                                } else {
                                    None
                                }
                            } else {
                                Some(infer_field_schema(
                                    &s,
                                    parse_options.try_parse_dates,
                                    parse_options.decimal_comma,
                                ))
                            }
                        },
                    };
                    if let Some(dtype) = dtype {
                        if matches!(&dtype, DataType::String)
                            && needs_escaping
                            && n_threads.unwrap_or(2) > 1
                        {
                            // The parser will chunk the file. However, chunking
                            // is increasingly unlikely to be correct if there are
                            // many newline characters in an escaped field, so we
                            // set a (somewhat arbitrary) upper bound on the number
                            // of escaped lines we accept. On the chunking side we
                            // also have logic to make this more robust.
                            if slice
                                .iter()
                                .filter(|b| **b == parse_options.eol_char)
                                .count()
                                > 8
                            {
                                if verbose() {
                                    eprintln!(
                                        "falling back to single core reading because of many escaped new line chars."
                                    )
                                }
                                *n_threads = Some(1);
                            }
                        }
                        unsafe { column_types.get_unchecked_mut(i).insert(dtype) };
                    }
                }
            }
        }
    }

    // build schema from inference results
    for i in 0..header_length {
        let field_name = &headers[i];

        if let Some(schema_overwrite) = schema_overwrite {
            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
                fields.push(Field::new(name.clone(), dtype.clone()));
                continue;
            }

            // the column might have been renamed;
            // execute this only if the overwrite schema is complete
            if schema_overwrite.len() == header_length {
                if let Some((name, dtype)) = schema_overwrite.get_at_index(i) {
                    fields.push(Field::new(name.clone(), dtype.clone()));
                    continue;
                }
            }
        }

        let possibilities = &column_types[i];
        let dtype = finish_infer_field_schema(possibilities);
        fields.push(Field::new(field_name.clone(), dtype));
    }
    // If there is a single line after the header without an eol,
    // we copy the bytes, add an eol, and rerun this function
    // so that the inference is consistent with and without an eol char.
    if rows_count == 0
        && !reader_bytes.is_empty()
        && reader_bytes[reader_bytes.len() - 1] != parse_options.eol_char
        && recursion_count == 0
    {
        let mut rb = Vec::with_capacity(reader_bytes.len() + 1);
        rb.extend_from_slice(reader_bytes);
        rb.push(parse_options.eol_char);
        return infer_file_schema_inner(
            &ReaderBytes::Owned(rb.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
            n_threads,
        );
    }

    Ok((Schema::from_iter(fields), rows_count, end_ptr - start_ptr))
}

pub(super) fn check_decimal_comma(decimal_comma: bool, separator: u8) -> PolarsResult<()> {
    if decimal_comma {
        polars_ensure!(b',' != separator, InvalidOperation: "'decimal_comma' argument cannot be combined with ',' separator")
    }
    Ok(())
}
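
// Illustrative sketch (not part of the original file): `check_decimal_comma`
// rejects only the one ambiguous combination, decimal commas together with a
// ',' separator.
#[cfg(test)]
mod check_decimal_comma_tests {
    use super::*;

    #[test]
    fn rejects_comma_separator() {
        assert!(check_decimal_comma(true, b',').is_err());
        assert!(check_decimal_comma(true, b';').is_ok());
        assert!(check_decimal_comma(false, b',').is_ok());
    }
}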

/// Infer the schema of a CSV file by reading through the first n rows of the file,
/// with `max_read_rows` controlling the maximum number of rows to read.
///
/// If `max_read_rows` is not set, the whole file is read to infer its schema.
///
/// Returns:
///     - the inferred schema
///     - the number of rows used for inference
///     - the number of bytes read
#[allow(clippy::too_many_arguments)]
pub fn infer_file_schema(
    reader_bytes: &ReaderBytes,
    parse_options: &CsvParseOptions,
    max_read_rows: Option<usize>,
    has_header: bool,
    schema_overwrite: Option<&Schema>,
    skip_rows: usize,
    skip_lines: usize,
    skip_rows_after_header: usize,
    raise_if_empty: bool,
    n_threads: &mut Option<usize>,
) -> PolarsResult<(Schema, usize, usize)> {
    check_decimal_comma(parse_options.decimal_comma, parse_options.separator)?;

    if skip_lines > 0 {
        polars_ensure!(skip_rows == 0, InvalidOperation: "only one of 'skip_rows'/'skip_lines' may be set");
        let bytes = skip_lines_naive(reader_bytes, parse_options.eol_char, skip_lines);
        let reader_bytes = ReaderBytes::Borrowed(bytes);
        infer_file_schema_inner(
            &reader_bytes,
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            0,
            raise_if_empty,
            n_threads,
        )
    } else {
        infer_file_schema_inner(
            reader_bytes,
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            0,
            raise_if_empty,
            n_threads,
        )
    }
}
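
// Illustrative end-to-end sketch (not part of the original file): inferring a
// schema from an in-memory CSV. Assumes `CsvParseOptions::default()` uses a
// ',' separator and '\n' eol, and that `Schema::get` looks up a dtype by
// column name.
#[cfg(test)]
mod infer_file_schema_tests {
    use super::*;

    #[test]
    fn infers_int_and_float_columns() -> PolarsResult<()> {
        let csv = b"a,b\n1,2.5\n3,4.0\n";
        let reader_bytes = ReaderBytes::Borrowed(csv.as_slice());
        let (schema, rows_read, _bytes_read) = infer_file_schema(
            &reader_bytes,
            &CsvParseOptions::default(),
            None,      // max_read_rows: read the whole buffer
            true,      // has_header
            None,      // schema_overwrite
            0,         // skip_rows
            0,         // skip_lines
            0,         // skip_rows_after_header
            true,      // raise_if_empty
            &mut None, // n_threads
        )?;
        assert_eq!(rows_read, 2);
        assert_eq!(schema.get("a"), Some(&DataType::Int64));
        assert_eq!(schema.get("b"), Some(&DataType::Float64));
        Ok(())
    }
}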