polars_io/csv/read/
schema_inference.rs

1use std::borrow::Cow;
2
3use polars_core::prelude::*;
4#[cfg(feature = "polars-time")]
5use polars_time::chunkedarray::string::infer as date_infer;
6#[cfg(feature = "polars-time")]
7use polars_time::prelude::string::Pattern;
8use polars_utils::format_pl_smallstr;
9
10use super::parser::{SplitLines, is_comment_line, skip_bom, skip_line_ending};
11use super::splitfields::SplitFields;
12use super::{CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues};
13use crate::csv::read::parser::skip_lines_naive;
14use crate::mmap::ReaderBytes;
15use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};
16
/// Outcome of running CSV schema inference over a byte buffer, together with
/// the statistics needed to extrapolate the total row count of the file.
#[derive(Clone, Debug, Default)]
pub struct SchemaInferenceResult {
    // Schema produced by inference (possibly merged with user overwrites later).
    inferred_schema: SchemaRef,
    // Number of rows actually consumed during inference.
    rows_read: usize,
    // Number of bytes covered by the rows that were read.
    bytes_read: usize,
    // Total size of the input buffer; used to extrapolate `rows_read`.
    bytes_total: usize,
    // Thread-count setting carried over from the read options.
    n_threads: Option<usize>,
}
25
26impl SchemaInferenceResult {
27    pub fn try_from_reader_bytes_and_options(
28        reader_bytes: &ReaderBytes,
29        options: &CsvReadOptions,
30    ) -> PolarsResult<Self> {
31        let parse_options = options.get_parse_options();
32
33        let infer_schema_length = options.infer_schema_length;
34        let has_header = options.has_header;
35        let schema_overwrite_arc = options.schema_overwrite.clone();
36        let schema_overwrite = schema_overwrite_arc.as_ref().map(|x| x.as_ref());
37        let skip_rows = options.skip_rows;
38        let skip_lines = options.skip_lines;
39        let skip_rows_after_header = options.skip_rows_after_header;
40        let raise_if_empty = options.raise_if_empty;
41        let n_threads = options.n_threads;
42
43        let bytes_total = reader_bytes.len();
44
45        let (inferred_schema, rows_read, bytes_read) = infer_file_schema(
46            reader_bytes,
47            &parse_options,
48            infer_schema_length,
49            has_header,
50            schema_overwrite,
51            skip_rows,
52            skip_lines,
53            skip_rows_after_header,
54            raise_if_empty,
55        )?;
56
57        let this = Self {
58            inferred_schema: Arc::new(inferred_schema),
59            rows_read,
60            bytes_read,
61            bytes_total,
62            n_threads,
63        };
64
65        Ok(this)
66    }
67
68    pub fn with_inferred_schema(mut self, inferred_schema: SchemaRef) -> Self {
69        self.inferred_schema = inferred_schema;
70        self
71    }
72
73    pub fn get_inferred_schema(&self) -> SchemaRef {
74        self.inferred_schema.clone()
75    }
76
77    pub fn get_estimated_n_rows(&self) -> usize {
78        (self.rows_read as f64 / self.bytes_read as f64 * self.bytes_total as f64) as usize
79    }
80}
81
impl CsvReadOptions {
    /// Note: This does not update the schema from the inference result.
    /// Only the thread-count setting is copied back into the options.
    pub fn update_with_inference_result(&mut self, si_result: &SchemaInferenceResult) {
        self.n_threads = si_result.n_threads;
    }
}
88
89pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
90    // determine data type based on possible types
91    // if there are incompatible types, use DataType::String
92    match possibilities.len() {
93        1 => possibilities.iter().next().unwrap().clone(),
94        2 if possibilities.contains(&DataType::Int64)
95            && possibilities.contains(&DataType::Float64) =>
96        {
97            // we have an integer and double, fall down to double
98            DataType::Float64
99        },
100        // default to String for conflicting datatypes (e.g bool and int)
101        _ => DataType::String,
102    }
103}
104
105/// Infer the data type of a record
106pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
107    // when quoting is enabled in the reader, these quotes aren't escaped, we default to
108    // String for them
109    let bytes = string.as_bytes();
110    if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
111        if try_parse_dates {
112            #[cfg(feature = "polars-time")]
113            {
114                match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
115                    Some(pattern_with_offset) => match pattern_with_offset {
116                        Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
117                            DataType::Datetime(TimeUnit::Microseconds, None)
118                        },
119                        Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
120                        Pattern::DatetimeYMDZ => DataType::Datetime(
121                            TimeUnit::Microseconds,
122                            Some(PlSmallStr::from_static("UTC")),
123                        ),
124                        Pattern::Time => DataType::Time,
125                    },
126                    None => DataType::String,
127                }
128            }
129            #[cfg(not(feature = "polars-time"))]
130            {
131                panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
132            }
133        } else {
134            DataType::String
135        }
136    }
137    // match regex in a particular order
138    else if BOOLEAN_RE.is_match(string) {
139        DataType::Boolean
140    } else if !decimal_comma && FLOAT_RE.is_match(string)
141        || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
142    {
143        DataType::Float64
144    } else if INTEGER_RE.is_match(string) {
145        DataType::Int64
146    } else if try_parse_dates {
147        #[cfg(feature = "polars-time")]
148        {
149            match date_infer::infer_pattern_single(string) {
150                Some(pattern_with_offset) => match pattern_with_offset {
151                    Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
152                        DataType::Datetime(TimeUnit::Microseconds, None)
153                    },
154                    Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
155                    Pattern::DatetimeYMDZ => DataType::Datetime(
156                        TimeUnit::Microseconds,
157                        Some(PlSmallStr::from_static("UTC")),
158                    ),
159                    Pattern::Time => DataType::Time,
160                },
161                None => DataType::String,
162            }
163        }
164        #[cfg(not(feature = "polars-time"))]
165        {
166            panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
167        }
168    } else {
169        DataType::String
170    }
171}
172
173#[inline]
174fn parse_bytes_with_encoding(bytes: &[u8], encoding: CsvEncoding) -> PolarsResult<Cow<str>> {
175    Ok(match encoding {
176        CsvEncoding::Utf8 => simdutf8::basic::from_utf8(bytes)
177            .map_err(|_| polars_err!(ComputeError: "invalid utf-8 sequence"))?
178            .into(),
179        CsvEncoding::LossyUtf8 => String::from_utf8_lossy(bytes),
180    })
181}
182
// Default name for the i-th (0-based) column: "column_1", "column_2", ...
fn column_name(i: usize) -> PlSmallStr {
    format_pl_smallstr!("column_{}", i + 1)
}
186
/// Single inference pass over `reader_bytes`.
///
/// Parses (or synthesizes) the header, then scans up to `max_read_rows` data
/// rows, collecting the set of candidate dtypes per column, and finally builds
/// the schema (applying `schema_overwrite` where it matches).
///
/// Returns `(schema, rows_read, bytes_read)`.
///
/// `recursion_count` guards the self-recursive retries below: when the input
/// has no trailing eol char, the buffer is copied, an eol is appended, and the
/// function calls itself exactly once more.
#[allow(clippy::too_many_arguments)]
fn infer_file_schema_inner(
    reader_bytes: &ReaderBytes,
    parse_options: &CsvParseOptions,
    max_read_rows: Option<usize>,
    has_header: bool,
    schema_overwrite: Option<&Schema>,
    // taken by value as `mut` because we may need to skip more rows dependent
    // on the schema inference (comment lines before the header)
    mut skip_rows: usize,
    skip_rows_after_header: usize,
    recursion_count: u8,
    raise_if_empty: bool,
) -> PolarsResult<(Schema, usize, usize)> {
    // keep track so that we can determine the amount of bytes read
    let start_ptr = reader_bytes.as_ptr() as usize;

    // We use lossy utf8 here because we don't want the schema inference to fail on utf8.
    // It may later.
    let encoding = CsvEncoding::LossyUtf8;

    let bytes = skip_line_ending(skip_bom(reader_bytes), parse_options.eol_char);
    if raise_if_empty {
        polars_ensure!(!bytes.is_empty(), NoData: "empty CSV");
    };
    let mut lines = SplitLines::new(
        bytes,
        parse_options.quote_char,
        parse_options.eol_char,
        parse_options.comment_prefix.as_ref(),
    )
    .skip(skip_rows);

    // get or create header names
    // when has_header is false, creates default column names with column_ prefix

    // skip lines that are comments
    let mut first_line = None;

    for (i, line) in (&mut lines).enumerate() {
        if !is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            first_line = Some(line);
            // account for the comment lines we consumed before the header
            skip_rows += i;
            break;
        }
    }

    if first_line.is_none() {
        first_line = lines.next();
    }

    // now that we've found the first non-comment line we parse the headers, or we create a header
    let mut headers: Vec<PlSmallStr> = if let Some(mut header_line) = first_line {
        let len = header_line.len();
        if len > 1 {
            // remove carriage return
            let trailing_byte = header_line[len - 1];
            if trailing_byte == b'\r' {
                header_line = &header_line[..len - 1];
            }
        }

        let byterecord = SplitFields::new(
            header_line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );
        if has_header {
            // decode each header field, stripping surrounding quotes when present
            let headers = byterecord
                .map(|(slice, needs_escaping)| {
                    let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                        &slice[1..(slice.len() - 1)]
                    } else {
                        slice
                    };
                    let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                    Ok(s)
                })
                .collect::<PolarsResult<Vec<_>>>()?;

            let mut final_headers = Vec::with_capacity(headers.len());

            // disambiguate duplicate header names: "name", "name_duplicated_0", ...
            let mut header_names = PlHashMap::with_capacity(headers.len());

            for name in &headers {
                let count = header_names.entry(name.as_ref()).or_insert(0usize);
                if *count != 0 {
                    final_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
                } else {
                    final_headers.push(PlSmallStr::from_str(name))
                }
                *count += 1;
            }
            final_headers
        } else {
            // no header row: synthesize "column_N" names, one per field
            byterecord
                .enumerate()
                .map(|(i, _s)| column_name(i))
                .collect::<Vec<PlSmallStr>>()
        }
    } else if has_header && !bytes.is_empty() && recursion_count == 0 {
        // there was no new line char. So we copy the whole buf and add one
        // this is likely to be cheap as there are no rows.
        let mut buf = Vec::with_capacity(bytes.len() + 2);
        buf.extend_from_slice(bytes);
        buf.push(parse_options.eol_char);

        return infer_file_schema_inner(
            &ReaderBytes::Owned(buf.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    } else if !raise_if_empty {
        // empty input is tolerated: report an empty schema with nothing read
        return Ok((Schema::default(), 0, 0));
    } else {
        polars_bail!(NoData: "empty CSV");
    };
    if !has_header {
        // re-init lines so that the header is included in type inference.
        lines = SplitLines::new(
            bytes,
            parse_options.quote_char,
            parse_options.eol_char,
            parse_options.comment_prefix.as_ref(),
        )
        .skip(skip_rows);
    }

    // keep track of inferred field types
    let mut column_types: Vec<PlHashSet<DataType>> =
        vec![PlHashSet::with_capacity(4); headers.len()];
    // keep track of columns with nulls
    let mut nulls: Vec<bool> = vec![false; headers.len()];

    let mut rows_count = 0;
    let mut fields = Vec::with_capacity(headers.len());

    // needed to prevent ownership going into the iterator loop
    let records_ref = &mut lines;

    let mut end_ptr = start_ptr;
    for mut line in records_ref
        .take(match max_read_rows {
            Some(max_read_rows) => {
                // saturation guard: only add when the sum cannot overflow
                if max_read_rows <= (usize::MAX - skip_rows_after_header) {
                    // read skip_rows_after_header more rows for inferring
                    // the correct schema as the first skip_rows_after_header
                    // rows will be skipped
                    max_read_rows + skip_rows_after_header
                } else {
                    max_read_rows
                }
            },
            None => usize::MAX,
        })
        .skip(skip_rows_after_header)
    {
        rows_count += 1;
        // keep track so that we can determine the amount of bytes read
        end_ptr = line.as_ptr() as usize + line.len();

        if line.is_empty() {
            continue;
        }

        // line is a comment -> skip
        if is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            continue;
        }

        let len = line.len();
        if len > 1 {
            // remove carriage return
            let trailing_byte = line[len - 1];
            if trailing_byte == b'\r' {
                line = &line[..len - 1];
            }
        }

        let record = SplitFields::new(
            line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );

        for (i, (slice, needs_escaping)) in record.enumerate() {
            // Grow the schema when a row has more fields than seen so far.
            // Only allowed without a header; with a header the extra fields
            // are ignored.
            if i >= headers.len() {
                if !has_header {
                    headers.push(column_name(i));
                    column_types.push(Default::default());
                    nulls.push(false);
                } else {
                    break;
                }
            }

            if slice.is_empty() {
                // SAFETY: `i < headers.len()` is established above and the
                // `nulls`/`column_types` vecs are kept the same length.
                unsafe { *nulls.get_unchecked_mut(i) = true };
            } else {
                let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                    &slice[1..(slice.len() - 1)]
                } else {
                    slice
                };
                let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                // A field matching a configured null value contributes no
                // candidate dtype (None); otherwise infer one from the text.
                let dtype = match &parse_options.null_values {
                    None => Some(infer_field_schema(
                        &s,
                        parse_options.try_parse_dates,
                        parse_options.decimal_comma,
                    )),
                    Some(NullValues::AllColumns(names)) => {
                        if !names.iter().any(|nv| nv == s.as_ref()) {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::AllColumnsSingle(name)) => {
                        if s.as_ref() != name.as_str() {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::Named(names)) => {
                        // SAFETY:
                        // we iterate over headers length.
                        let current_name = unsafe { headers.get_unchecked(i) };
                        let null_name = &names.iter().find(|name| name.0 == current_name);

                        if let Some(null_name) = null_name {
                            if null_name.1.as_str() != s.as_ref() {
                                Some(infer_field_schema(
                                    &s,
                                    parse_options.try_parse_dates,
                                    parse_options.decimal_comma,
                                ))
                            } else {
                                None
                            }
                        } else {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        }
                    },
                };
                if let Some(dtype) = dtype {
                    // SAFETY: `i < column_types.len()`, see bounds handling above.
                    unsafe { column_types.get_unchecked_mut(i).insert(dtype) };
                }
            }
        }
    }

    // build schema from inference results
    for i in 0..headers.len() {
        let field_name = &headers[i];

        if let Some(schema_overwrite) = schema_overwrite {
            // exact name match in the overwrite schema takes precedence
            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
                fields.push(Field::new(name.clone(), dtype.clone()));
                continue;
            }

            // column might have been renamed
            // execute only if schema is complete
            if schema_overwrite.len() == headers.len() {
                if let Some((name, dtype)) = schema_overwrite.get_at_index(i) {
                    fields.push(Field::new(name.clone(), dtype.clone()));
                    continue;
                }
            }
        }

        let possibilities = &column_types[i];
        let dtype = finish_infer_field_schema(possibilities);
        fields.push(Field::new(field_name.clone(), dtype));
    }
    // if there is a single line after the header without an eol
    // we copy the bytes add an eol and rerun this function
    // so that the inference is consistent with and without eol char
    if rows_count == 0
        && !reader_bytes.is_empty()
        && reader_bytes[reader_bytes.len() - 1] != parse_options.eol_char
        && recursion_count == 0
    {
        let mut rb = Vec::with_capacity(reader_bytes.len() + 1);
        rb.extend_from_slice(reader_bytes);
        rb.push(parse_options.eol_char);
        return infer_file_schema_inner(
            &ReaderBytes::Owned(rb.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    }

    Ok((Schema::from_iter(fields), rows_count, end_ptr - start_ptr))
}
512
513pub(super) fn check_decimal_comma(decimal_comma: bool, separator: u8) -> PolarsResult<()> {
514    if decimal_comma {
515        polars_ensure!(b',' != separator, InvalidOperation: "'decimal_comma' argument cannot be combined with ',' separator")
516    }
517    Ok(())
518}
519
520/// Infer the schema of a CSV file by reading through the first n rows of the file,
521/// with `max_read_rows` controlling the maximum number of rows to read.
522///
523/// If `max_read_rows` is not set, the whole file is read to infer its schema.
524///
525/// Returns
526///     - inferred schema
527///     - number of rows used for inference.
528///     - bytes read
529#[allow(clippy::too_many_arguments)]
530pub fn infer_file_schema(
531    reader_bytes: &ReaderBytes,
532    parse_options: &CsvParseOptions,
533    max_read_rows: Option<usize>,
534    has_header: bool,
535    schema_overwrite: Option<&Schema>,
536    skip_rows: usize,
537    skip_lines: usize,
538    skip_rows_after_header: usize,
539    raise_if_empty: bool,
540) -> PolarsResult<(Schema, usize, usize)> {
541    check_decimal_comma(parse_options.decimal_comma, parse_options.separator)?;
542
543    if skip_lines > 0 {
544        polars_ensure!(skip_rows == 0, InvalidOperation: "only one of 'skip_rows'/'skip_lines' may be set");
545        let bytes = skip_lines_naive(reader_bytes, parse_options.eol_char, skip_lines);
546        let reader_bytes = ReaderBytes::Borrowed(bytes);
547        infer_file_schema_inner(
548            &reader_bytes,
549            parse_options,
550            max_read_rows,
551            has_header,
552            schema_overwrite,
553            skip_rows,
554            skip_rows_after_header,
555            0,
556            raise_if_empty,
557        )
558    } else {
559        infer_file_schema_inner(
560            reader_bytes,
561            parse_options,
562            max_read_rows,
563            has_header,
564            schema_overwrite,
565            skip_rows,
566            skip_rows_after_header,
567            0,
568            raise_if_empty,
569        )
570    }
571}