polars_io/csv/read/
schema_inference.rs

1use std::borrow::Cow;
2
3use polars_core::prelude::*;
4#[cfg(feature = "polars-time")]
5use polars_time::chunkedarray::string::infer as date_infer;
6#[cfg(feature = "polars-time")]
7use polars_time::prelude::string::Pattern;
8use polars_utils::format_pl_smallstr;
9use polars_utils::mmap::MemSlice;
10
11use super::splitfields::SplitFields;
12use super::{CsvEncoding, CsvParseOptions, NullValues};
13use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};
14
15/// Low-level CSV schema inference function.
16///
17/// Use `read_until_start_and_infer_schema` instead.
18#[allow(clippy::too_many_arguments)]
19pub(super) fn infer_file_schema_impl(
20    header_line: &Option<MemSlice>,
21    content_lines: &[MemSlice],
22    infer_all_as_str: bool,
23    parse_options: &CsvParseOptions,
24    schema_overwrite: Option<&Schema>,
25) -> PolarsResult<Schema> {
26    let mut headers = header_line
27        .as_ref()
28        .map(|line| infer_headers(line, parse_options))
29        .transpose()?
30        .unwrap_or_else(|| Vec::with_capacity(8));
31
32    let extend_header_with_unknown_column = header_line.is_none();
33
34    let mut column_types = vec![PlHashSet::<DataType>::with_capacity(4); headers.len()];
35    let mut nulls = vec![false; headers.len()];
36
37    for content_line in content_lines {
38        infer_types_from_line(
39            content_line,
40            infer_all_as_str,
41            &mut headers,
42            extend_header_with_unknown_column,
43            parse_options,
44            &mut column_types,
45            &mut nulls,
46        )?;
47    }
48
49    Ok(build_schema(&headers, &column_types, schema_overwrite))
50}
51
/// We use lossy utf8 here because we don't want the schema inference to fail on utf8.
/// It may later.
const INFER_ENCODING: CsvEncoding = CsvEncoding::LossyUtf8;
55
56fn infer_headers(
57    mut header_line: &[u8],
58    parse_options: &CsvParseOptions,
59) -> PolarsResult<Vec<PlSmallStr>> {
60    let len = header_line.len();
61
62    if header_line.last().copied() == Some(b'\r') {
63        header_line = &header_line[..len - 1];
64    }
65
66    let byterecord = SplitFields::new(
67        header_line,
68        parse_options.separator,
69        parse_options.quote_char,
70        parse_options.eol_char,
71    );
72
73    let headers = byterecord
74        .map(|(slice, needs_escaping)| {
75            let slice_escaped = if needs_escaping && (slice.len() >= 2) {
76                &slice[1..(slice.len() - 1)]
77            } else {
78                slice
79            };
80            let s = parse_bytes_with_encoding(slice_escaped, INFER_ENCODING)?;
81            Ok(s)
82        })
83        .collect::<PolarsResult<Vec<_>>>()?;
84
85    let mut deduplicated_headers = Vec::with_capacity(headers.len());
86    let mut header_names = PlHashMap::with_capacity(headers.len());
87
88    for name in &headers {
89        let count = header_names.entry(name.as_ref()).or_insert(0usize);
90        if *count != 0 {
91            deduplicated_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
92        } else {
93            deduplicated_headers.push(PlSmallStr::from_str(name))
94        }
95        *count += 1;
96    }
97
98    Ok(deduplicated_headers)
99}
100
101fn infer_types_from_line(
102    mut line: &[u8],
103    infer_all_as_str: bool,
104    headers: &mut Vec<PlSmallStr>,
105    extend_header_with_unknown_column: bool,
106    parse_options: &CsvParseOptions,
107    column_types: &mut Vec<PlHashSet<DataType>>,
108    nulls: &mut Vec<bool>,
109) -> PolarsResult<()> {
110    let line_len = line.len();
111    if line.last().copied() == Some(b'\r') {
112        line = &line[..line_len - 1];
113    }
114
115    let record = SplitFields::new(
116        line,
117        parse_options.separator,
118        parse_options.quote_char,
119        parse_options.eol_char,
120    );
121
122    for (i, (slice, needs_escaping)) in record.enumerate() {
123        if i >= headers.len() {
124            if extend_header_with_unknown_column {
125                headers.push(column_name(i));
126                column_types.push(Default::default());
127                nulls.push(false);
128            } else {
129                break;
130            }
131        }
132
133        if infer_all_as_str {
134            column_types[i].insert(DataType::String);
135            continue;
136        }
137
138        if slice.is_empty() {
139            nulls[i] = true;
140        } else {
141            let slice_escaped = if needs_escaping && (slice.len() >= 2) {
142                &slice[1..(slice.len() - 1)]
143            } else {
144                slice
145            };
146            let s = parse_bytes_with_encoding(slice_escaped, INFER_ENCODING)?;
147            let dtype = match &parse_options.null_values {
148                None => Some(infer_field_schema(
149                    &s,
150                    parse_options.try_parse_dates,
151                    parse_options.decimal_comma,
152                )),
153                Some(NullValues::AllColumns(names)) => {
154                    if !names.iter().any(|nv| nv == s.as_ref()) {
155                        Some(infer_field_schema(
156                            &s,
157                            parse_options.try_parse_dates,
158                            parse_options.decimal_comma,
159                        ))
160                    } else {
161                        None
162                    }
163                },
164                Some(NullValues::AllColumnsSingle(name)) => {
165                    if s.as_ref() != name.as_str() {
166                        Some(infer_field_schema(
167                            &s,
168                            parse_options.try_parse_dates,
169                            parse_options.decimal_comma,
170                        ))
171                    } else {
172                        None
173                    }
174                },
175                Some(NullValues::Named(names)) => {
176                    let current_name = &headers[i];
177                    let null_name = &names.iter().find(|name| name.0 == current_name);
178
179                    if let Some(null_name) = null_name {
180                        if null_name.1.as_str() != s.as_ref() {
181                            Some(infer_field_schema(
182                                &s,
183                                parse_options.try_parse_dates,
184                                parse_options.decimal_comma,
185                            ))
186                        } else {
187                            None
188                        }
189                    } else {
190                        Some(infer_field_schema(
191                            &s,
192                            parse_options.try_parse_dates,
193                            parse_options.decimal_comma,
194                        ))
195                    }
196                },
197            };
198            if let Some(dtype) = dtype {
199                column_types[i].insert(dtype);
200            }
201        }
202    }
203
204    Ok(())
205}
206
207fn build_schema(
208    headers: &[PlSmallStr],
209    column_types: &[PlHashSet<DataType>],
210    schema_overwrite: Option<&Schema>,
211) -> Schema {
212    assert!(headers.len() == column_types.len());
213
214    let get_schema_overwrite = |field_name| {
215        if let Some(schema_overwrite) = schema_overwrite {
216            // Apply schema_overwrite by column name only. Positional overrides are handled
217            // separately via dtype_overwrite.
218            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
219                return Some((name.clone(), dtype.clone()));
220            }
221        }
222
223        None
224    };
225
226    Schema::from_iter(
227        headers
228            .iter()
229            .zip(column_types)
230            .map(|(field_name, type_possibilities)| {
231                let (name, dtype) = get_schema_overwrite(field_name).unwrap_or_else(|| {
232                    (
233                        field_name.clone(),
234                        finish_infer_field_schema(type_possibilities),
235                    )
236                });
237
238                Field::new(name, dtype)
239            }),
240    )
241}
242
243pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
244    // determine data type based on possible types
245    // if there are incompatible types, use DataType::String
246    match possibilities.len() {
247        1 => possibilities.iter().next().unwrap().clone(),
248        2 if possibilities.contains(&DataType::Int64)
249            && possibilities.contains(&DataType::Float64) =>
250        {
251            // we have an integer and double, fall down to double
252            DataType::Float64
253        },
254        // default to String for conflicting datatypes (e.g bool and int)
255        _ => DataType::String,
256    }
257}
258
259/// Infer the data type of a record
260pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
261    // when quoting is enabled in the reader, these quotes aren't escaped, we default to
262    // String for them
263    let bytes = string.as_bytes();
264    if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
265        if try_parse_dates {
266            #[cfg(feature = "polars-time")]
267            {
268                match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
269                    Some(pattern_with_offset) => match pattern_with_offset {
270                        Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
271                            DataType::Datetime(TimeUnit::Microseconds, None)
272                        },
273                        Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
274                        Pattern::DatetimeYMDZ => {
275                            DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
276                        },
277                        Pattern::Time => DataType::Time,
278                    },
279                    None => DataType::String,
280                }
281            }
282            #[cfg(not(feature = "polars-time"))]
283            {
284                panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
285            }
286        } else {
287            DataType::String
288        }
289    }
290    // match regex in a particular order
291    else if BOOLEAN_RE.is_match(string) {
292        DataType::Boolean
293    } else if !decimal_comma && FLOAT_RE.is_match(string)
294        || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
295    {
296        DataType::Float64
297    } else if INTEGER_RE.is_match(string) {
298        DataType::Int64
299    } else if try_parse_dates {
300        #[cfg(feature = "polars-time")]
301        {
302            match date_infer::infer_pattern_single(string) {
303                Some(pattern_with_offset) => match pattern_with_offset {
304                    Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
305                        DataType::Datetime(TimeUnit::Microseconds, None)
306                    },
307                    Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
308                    Pattern::DatetimeYMDZ => {
309                        DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
310                    },
311                    Pattern::Time => DataType::Time,
312                },
313                None => DataType::String,
314            }
315        }
316        #[cfg(not(feature = "polars-time"))]
317        {
318            panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
319        }
320    } else {
321        DataType::String
322    }
323}
324
325#[inline]
326fn parse_bytes_with_encoding(bytes: &[u8], encoding: CsvEncoding) -> PolarsResult<Cow<'_, str>> {
327    Ok(match encoding {
328        CsvEncoding::Utf8 => simdutf8::basic::from_utf8(bytes)
329            .map_err(|_| polars_err!(ComputeError: "invalid utf-8 sequence"))?
330            .into(),
331        CsvEncoding::LossyUtf8 => String::from_utf8_lossy(bytes),
332    })
333}
334
335fn column_name(i: usize) -> PlSmallStr {
336    format_pl_smallstr!("column_{}", i + 1)
337}