Skip to main content

polars_io/csv/read/
schema_inference.rs

1use polars_buffer::Buffer;
2use polars_core::prelude::*;
3#[cfg(feature = "polars-time")]
4use polars_time::chunkedarray::string::infer as date_infer;
5#[cfg(feature = "polars-time")]
6use polars_time::prelude::string::Pattern;
7use polars_utils::format_pl_smallstr;
8
9use super::splitfields::SplitFields;
10use super::{CsvParseOptions, NullValues};
11use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};
12
13/// Low-level CSV schema inference function.
14///
15/// Use `read_until_start_and_infer_schema` instead.
16#[allow(clippy::too_many_arguments)]
17pub(super) fn infer_file_schema_impl(
18    header_line: &Option<Buffer<u8>>,
19    content_lines: &[Buffer<u8>],
20    infer_all_as_str: bool,
21    parse_options: &CsvParseOptions,
22    schema_overwrite: Option<&Schema>,
23) -> Schema {
24    let mut headers = header_line
25        .as_ref()
26        .map(|line| infer_headers(line, parse_options))
27        .unwrap_or_else(|| Vec::with_capacity(8));
28
29    let extend_header_with_unknown_column = header_line.is_none();
30
31    let mut column_types = vec![PlHashSet::<DataType>::with_capacity(4); headers.len()];
32    let mut nulls = vec![false; headers.len()];
33
34    for content_line in content_lines {
35        infer_types_from_line(
36            content_line,
37            infer_all_as_str,
38            &mut headers,
39            extend_header_with_unknown_column,
40            parse_options,
41            &mut column_types,
42            &mut nulls,
43        );
44    }
45
46    build_schema(&headers, &column_types, schema_overwrite)
47}
48
49fn infer_headers(mut header_line: &[u8], parse_options: &CsvParseOptions) -> Vec<PlSmallStr> {
50    let len = header_line.len();
51
52    if header_line.last().copied() == Some(b'\r') {
53        header_line = &header_line[..len - 1];
54    }
55
56    let byterecord = SplitFields::new(
57        header_line,
58        parse_options.separator,
59        parse_options.quote_char,
60        parse_options.eol_char,
61    );
62
63    let headers = byterecord
64        .map(|(slice, needs_escaping)| {
65            let slice_escaped = if needs_escaping && (slice.len() >= 2) {
66                &slice[1..(slice.len() - 1)]
67            } else {
68                slice
69            };
70            String::from_utf8_lossy(slice_escaped)
71        })
72        .collect::<Vec<_>>();
73
74    let mut deduplicated_headers = Vec::with_capacity(headers.len());
75    let mut header_names = PlHashMap::with_capacity(headers.len());
76
77    for name in &headers {
78        let count = header_names.entry(name.as_ref()).or_insert(0usize);
79        if *count != 0 {
80            deduplicated_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
81        } else {
82            deduplicated_headers.push(PlSmallStr::from_str(name))
83        }
84        *count += 1;
85    }
86
87    deduplicated_headers
88}
89
90fn infer_types_from_line(
91    mut line: &[u8],
92    infer_all_as_str: bool,
93    headers: &mut Vec<PlSmallStr>,
94    extend_header_with_unknown_column: bool,
95    parse_options: &CsvParseOptions,
96    column_types: &mut Vec<PlHashSet<DataType>>,
97    nulls: &mut Vec<bool>,
98) {
99    let line_len = line.len();
100    if line.last().copied() == Some(b'\r') {
101        line = &line[..line_len - 1];
102    }
103
104    let record = SplitFields::new(
105        line,
106        parse_options.separator,
107        parse_options.quote_char,
108        parse_options.eol_char,
109    );
110
111    for (i, (slice, needs_escaping)) in record.enumerate() {
112        if i >= headers.len() {
113            if extend_header_with_unknown_column {
114                headers.push(column_name(i));
115                column_types.push(Default::default());
116                nulls.push(false);
117            } else {
118                break;
119            }
120        }
121
122        if infer_all_as_str {
123            column_types[i].insert(DataType::String);
124            continue;
125        }
126
127        if slice.is_empty() {
128            nulls[i] = true;
129        } else {
130            let slice_escaped = if needs_escaping && (slice.len() >= 2) {
131                &slice[1..(slice.len() - 1)]
132            } else {
133                slice
134            };
135            let s = String::from_utf8_lossy(slice_escaped);
136            let dtype = match &parse_options.null_values {
137                None => Some(infer_field_schema(
138                    &s,
139                    parse_options.try_parse_dates,
140                    parse_options.decimal_comma,
141                )),
142                Some(NullValues::AllColumns(names)) => {
143                    if !names.iter().any(|nv| nv == s.as_ref()) {
144                        Some(infer_field_schema(
145                            &s,
146                            parse_options.try_parse_dates,
147                            parse_options.decimal_comma,
148                        ))
149                    } else {
150                        None
151                    }
152                },
153                Some(NullValues::AllColumnsSingle(name)) => {
154                    if s.as_ref() != name.as_str() {
155                        Some(infer_field_schema(
156                            &s,
157                            parse_options.try_parse_dates,
158                            parse_options.decimal_comma,
159                        ))
160                    } else {
161                        None
162                    }
163                },
164                Some(NullValues::Named(names)) => {
165                    let current_name = &headers[i];
166                    let null_name = &names.iter().find(|name| name.0 == current_name);
167
168                    if let Some(null_name) = null_name {
169                        if null_name.1.as_str() != s.as_ref() {
170                            Some(infer_field_schema(
171                                &s,
172                                parse_options.try_parse_dates,
173                                parse_options.decimal_comma,
174                            ))
175                        } else {
176                            None
177                        }
178                    } else {
179                        Some(infer_field_schema(
180                            &s,
181                            parse_options.try_parse_dates,
182                            parse_options.decimal_comma,
183                        ))
184                    }
185                },
186            };
187            if let Some(dtype) = dtype {
188                column_types[i].insert(dtype);
189            }
190        }
191    }
192}
193
194fn build_schema(
195    headers: &[PlSmallStr],
196    column_types: &[PlHashSet<DataType>],
197    schema_overwrite: Option<&Schema>,
198) -> Schema {
199    assert!(headers.len() == column_types.len());
200
201    let get_schema_overwrite = |field_name| {
202        if let Some(schema_overwrite) = schema_overwrite {
203            // Apply schema_overwrite by column name only. Positional overrides are handled
204            // separately via dtype_overwrite.
205            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
206                return Some((name.clone(), dtype.clone()));
207            }
208        }
209
210        None
211    };
212
213    Schema::from_iter(
214        headers
215            .iter()
216            .zip(column_types)
217            .map(|(field_name, type_possibilities)| {
218                let (name, dtype) = get_schema_overwrite(field_name).unwrap_or_else(|| {
219                    (
220                        field_name.clone(),
221                        finish_infer_field_schema(type_possibilities),
222                    )
223                });
224
225                Field::new(name, dtype)
226            }),
227    )
228}
229
230pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
231    // determine data type based on possible types
232    // if there are incompatible types, use DataType::String
233    match possibilities.len() {
234        1 => possibilities.iter().next().unwrap().clone(),
235        2 if possibilities.contains(&DataType::Int64)
236            && possibilities.contains(&DataType::Float64) =>
237        {
238            // we have an integer and double, fall down to double
239            DataType::Float64
240        },
241        #[cfg(feature = "dtype-i128")]
242        2 if possibilities.contains(&DataType::Int64)
243            && possibilities.contains(&DataType::Int128) =>
244        {
245            // all values fit within i128
246            DataType::Int128
247        },
248        #[cfg(feature = "dtype-i128")]
249        2 if possibilities.contains(&DataType::Int128)
250            && possibilities.contains(&DataType::Float64) =>
251        {
252            // fall down to double for mixed int128 and float
253            DataType::Float64
254        },
255        // default to String for conflicting datatypes (e.g bool and int)
256        _ => DataType::String,
257    }
258}
259
260/// Infer the data type of a record
261pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
262    // when quoting is enabled in the reader, these quotes aren't escaped, we default to
263    // String for them
264    let bytes = string.as_bytes();
265    if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
266        if try_parse_dates {
267            #[cfg(feature = "polars-time")]
268            {
269                match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
270                    Some(pattern_with_offset) => match pattern_with_offset {
271                        Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
272                            DataType::Datetime(TimeUnit::Microseconds, None)
273                        },
274                        Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
275                        Pattern::DatetimeYMDZ => {
276                            DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
277                        },
278                        Pattern::Time => DataType::Time,
279                    },
280                    None => DataType::String,
281                }
282            }
283            #[cfg(not(feature = "polars-time"))]
284            {
285                panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
286            }
287        } else {
288            DataType::String
289        }
290    }
291    // match regex in a particular order
292    else if BOOLEAN_RE.is_match(string) {
293        DataType::Boolean
294    } else if !decimal_comma && FLOAT_RE.is_match(string)
295        || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
296    {
297        DataType::Float64
298    } else if INTEGER_RE.is_match(string) {
299        if string.parse::<i64>().is_ok() {
300            DataType::Int64
301        } else {
302            #[cfg(feature = "dtype-i128")]
303            {
304                DataType::Int128
305            }
306            #[cfg(not(feature = "dtype-i128"))]
307            {
308                DataType::Int64
309            }
310        }
311    } else if try_parse_dates {
312        #[cfg(feature = "polars-time")]
313        {
314            match date_infer::infer_pattern_single(string) {
315                Some(pattern_with_offset) => match pattern_with_offset {
316                    Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
317                        DataType::Datetime(TimeUnit::Microseconds, None)
318                    },
319                    Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
320                    Pattern::DatetimeYMDZ => {
321                        DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
322                    },
323                    Pattern::Time => DataType::Time,
324                },
325                None => DataType::String,
326            }
327        }
328        #[cfg(not(feature = "polars-time"))]
329        {
330            panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
331        }
332    } else {
333        DataType::String
334    }
335}
336
337fn column_name(i: usize) -> PlSmallStr {
338    format_pl_smallstr!("column_{}", i + 1)
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_infer_field_schema_i64_overflow() {
        // Values within i64 range should infer as Int64.
        assert_eq!(
            infer_field_schema("9223372036854775807", false, false),
            DataType::Int64,
        );

        // Values exceeding i64::MAX infer as Int128 when the `dtype-i128` feature is enabled;
        // without that feature they still infer as Int64.
        let large = "12345678901234567890";
        #[cfg(feature = "dtype-i128")]
        assert_eq!(infer_field_schema(large, false, false), DataType::Int128);
        #[cfg(not(feature = "dtype-i128"))]
        assert_eq!(infer_field_schema(large, false, false), DataType::Int64);
    }

    #[test]
    fn test_infer_field_schema_basic_kinds() {
        assert_eq!(infer_field_schema("42", false, false), DataType::Int64);
        assert_eq!(infer_field_schema("1.5", false, false), DataType::Float64);
        assert_eq!(infer_field_schema("hello", false, false), DataType::String);
        // A value that still carries its quotes is a String when date parsing is disabled.
        assert_eq!(infer_field_schema("\"42\"", false, false), DataType::String);
    }

    #[test]
    fn test_finish_infer_field_schema_basic() {
        let mut possibilities = PlHashSet::new();
        // No observed values at all.
        assert_eq!(finish_infer_field_schema(&possibilities), DataType::String);

        possibilities.insert(DataType::Int64);
        assert_eq!(finish_infer_field_schema(&possibilities), DataType::Int64);

        possibilities.insert(DataType::Float64);
        assert_eq!(finish_infer_field_schema(&possibilities), DataType::Float64);

        // A non-numeric kind mixed in forces String.
        possibilities.insert(DataType::Boolean);
        assert_eq!(finish_infer_field_schema(&possibilities), DataType::String);
    }

    #[test]
    #[cfg(feature = "dtype-i128")]
    fn test_finish_infer_field_schema_i64_and_i128() {
        let mut possibilities = PlHashSet::new();
        possibilities.insert(DataType::Int64);
        possibilities.insert(DataType::Int128);
        assert_eq!(finish_infer_field_schema(&possibilities), DataType::Int128);
    }
}