use std::borrow::Cow;

use polars_core::prelude::*;
#[cfg(feature = "polars-time")]
use polars_time::chunkedarray::string::infer as date_infer;
#[cfg(feature = "polars-time")]
use polars_time::prelude::string::Pattern;
use polars_utils::format_pl_smallstr;
use polars_utils::mmap::MemSlice;

use super::splitfields::SplitFields;
use super::{CsvEncoding, CsvParseOptions, NullValues};
use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};

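/// Infer a [`Schema`] from an optional header line and a sample of content lines.
/// Candidate dtypes are collected per column and resolved in `build_schema`;
/// any dtype given in `schema_overwrite` takes precedence over the inferred one.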
#[allow(clippy::too_many_arguments)]
pub(super) fn infer_file_schema_impl(
    header_line: &Option<MemSlice>,
    content_lines: &[MemSlice],
    infer_all_as_str: bool,
    parse_options: &CsvParseOptions,
    schema_overwrite: Option<&Schema>,
) -> PolarsResult<Schema> {
    let mut headers = header_line
        .as_ref()
        .map(|line| infer_headers(line, parse_options))
        .transpose()?
        .unwrap_or_else(|| Vec::with_capacity(8));

    // Without a header line, column names are generated on the fly while
    // scanning the content lines.
    let extend_header_with_unknown_column = header_line.is_none();

    let mut column_types = vec![PlHashSet::<DataType>::with_capacity(4); headers.len()];
    let mut nulls = vec![false; headers.len()];

    for content_line in content_lines {
        infer_types_from_line(
            content_line,
            infer_all_as_str,
            &mut headers,
            extend_header_with_unknown_column,
            parse_options,
            &mut column_types,
            &mut nulls,
        )?;
    }

    Ok(build_schema(&headers, &column_types, schema_overwrite))
}

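/// Schema inference never fails on encoding: invalid UTF-8 sequences are
/// replaced lossily rather than raising an error.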
const INFER_ENCODING: CsvEncoding = CsvEncoding::LossyUtf8;

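/// Split the header line into column names, unescaping quoted fields and
/// deduplicating repeated names. Repeats get a `_duplicated_<n>` suffix, so
/// `a,a,a` becomes `a`, `a_duplicated_0`, `a_duplicated_1`.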
fn infer_headers(
    mut header_line: &[u8],
    parse_options: &CsvParseOptions,
) -> PolarsResult<Vec<PlSmallStr>> {
    let len = header_line.len();

    // Strip a trailing carriage return from CRLF line endings.
    if header_line.last().copied() == Some(b'\r') {
        header_line = &header_line[..len - 1];
    }

    let byterecord = SplitFields::new(
        header_line,
        parse_options.separator,
        parse_options.quote_char,
        parse_options.eol_char,
    );

    let headers = byterecord
        .map(|(slice, needs_escaping)| {
            // Drop the surrounding quotes of quoted fields.
            let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                &slice[1..(slice.len() - 1)]
            } else {
                slice
            };
            let s = parse_bytes_with_encoding(slice_escaped, INFER_ENCODING)?;
            Ok(s)
        })
        .collect::<PolarsResult<Vec<_>>>()?;

    let mut deduplicated_headers = Vec::with_capacity(headers.len());
    let mut header_names = PlHashMap::with_capacity(headers.len());

    for name in &headers {
        let count = header_names.entry(name.as_ref()).or_insert(0usize);
        if *count != 0 {
            deduplicated_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
        } else {
            deduplicated_headers.push(PlSmallStr::from_str(name))
        }
        *count += 1;
    }

    Ok(deduplicated_headers)
}

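/// Split a single content line into fields and record a candidate dtype for
/// every field in `column_types`. Empty fields are tracked in `nulls`, fields
/// matching the configured null values contribute no dtype, and when the file
/// has no header line extra columns extend `headers` with generated names.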
fn infer_types_from_line(
    mut line: &[u8],
    infer_all_as_str: bool,
    headers: &mut Vec<PlSmallStr>,
    extend_header_with_unknown_column: bool,
    parse_options: &CsvParseOptions,
    column_types: &mut Vec<PlHashSet<DataType>>,
    nulls: &mut Vec<bool>,
) -> PolarsResult<()> {
    let line_len = line.len();
    if line.last().copied() == Some(b'\r') {
        line = &line[..line_len - 1];
    }

    let record = SplitFields::new(
        line,
        parse_options.separator,
        parse_options.quote_char,
        parse_options.eol_char,
    );

    for (i, (slice, needs_escaping)) in record.enumerate() {
        if i >= headers.len() {
            if extend_header_with_unknown_column {
                headers.push(column_name(i));
                column_types.push(Default::default());
                nulls.push(false);
            } else {
                // More fields than headers: ignore the excess fields.
                break;
            }
        }

        if infer_all_as_str {
            column_types[i].insert(DataType::String);
            continue;
        }

        if slice.is_empty() {
            nulls[i] = true;
        } else {
            let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                &slice[1..(slice.len() - 1)]
            } else {
                slice
            };
            let s = parse_bytes_with_encoding(slice_escaped, INFER_ENCODING)?;
            // Fields matching a configured null value yield no dtype candidate.
            let dtype = match &parse_options.null_values {
                None => Some(infer_field_schema(
                    &s,
                    parse_options.try_parse_dates,
                    parse_options.decimal_comma,
                )),
                Some(NullValues::AllColumns(names)) => {
                    if !names.iter().any(|nv| nv == s.as_ref()) {
                        Some(infer_field_schema(
                            &s,
                            parse_options.try_parse_dates,
                            parse_options.decimal_comma,
                        ))
                    } else {
                        None
                    }
                },
                Some(NullValues::AllColumnsSingle(name)) => {
                    if s.as_ref() != name.as_str() {
                        Some(infer_field_schema(
                            &s,
                            parse_options.try_parse_dates,
                            parse_options.decimal_comma,
                        ))
                    } else {
                        None
                    }
                },
                Some(NullValues::Named(names)) => {
                    let current_name = &headers[i];
                    let null_name = &names.iter().find(|name| name.0 == current_name);

                    if let Some(null_name) = null_name {
                        if null_name.1.as_str() != s.as_ref() {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    } else {
                        Some(infer_field_schema(
                            &s,
                            parse_options.try_parse_dates,
                            parse_options.decimal_comma,
                        ))
                    }
                },
            };
            if let Some(dtype) = dtype {
                column_types[i].insert(dtype);
            }
        }
    }

    Ok(())
}

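/// Combine the collected per-column type possibilities into a [`Schema`],
/// preferring the name and dtype from `schema_overwrite` whenever a column
/// name matches.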
fn build_schema(
    headers: &[PlSmallStr],
    column_types: &[PlHashSet<DataType>],
    schema_overwrite: Option<&Schema>,
) -> Schema {
    assert!(headers.len() == column_types.len());

    let get_schema_overwrite = |field_name| {
        if let Some(schema_overwrite) = schema_overwrite {
            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
                return Some((name.clone(), dtype.clone()));
            }
        }

        None
    };

    Schema::from_iter(
        headers
            .iter()
            .zip(column_types)
            .map(|(field_name, type_possibilities)| {
                let (name, dtype) = get_schema_overwrite(field_name).unwrap_or_else(|| {
                    (
                        field_name.clone(),
                        finish_infer_field_schema(type_possibilities),
                    )
                });

                Field::new(name, dtype)
            }),
    )
}

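/// Resolve the set of candidate dtypes collected for a column: a single
/// candidate wins, `Int64` combined with `Float64` widens to `Float64`, and
/// any other combination (including no candidates at all) falls back to `String`.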
pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
    match possibilities.len() {
        1 => possibilities.iter().next().unwrap().clone(),
        2 if possibilities.contains(&DataType::Int64)
            && possibilities.contains(&DataType::Float64) =>
        {
            DataType::Float64
        },
        _ => DataType::String,
    }
}

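/// Infer the dtype of a single field. Quoted values are only probed for
/// temporal patterns and otherwise stay `String`; unquoted values are tried
/// as boolean, float, integer and (optionally) temporal, in that order, with
/// `String` as the fallback. For example, an unquoted `1.5` infers as
/// `Float64` when `decimal_comma` is disabled.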
pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
    let bytes = string.as_bytes();
    // Quoted fields keep their surrounding quotes here; only the inner slice
    // is checked for temporal patterns.
    if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
        if try_parse_dates {
            #[cfg(feature = "polars-time")]
            {
                match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
                    Some(pattern_with_offset) => match pattern_with_offset {
                        Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                            DataType::Datetime(TimeUnit::Microseconds, None)
                        },
                        Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                        Pattern::DatetimeYMDZ => {
                            DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                        },
                        Pattern::Time => DataType::Time,
                    },
                    None => DataType::String,
                }
            }
            #[cfg(not(feature = "polars-time"))]
            {
                panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
            }
        } else {
            DataType::String
        }
    } else if BOOLEAN_RE.is_match(string) {
        DataType::Boolean
    } else if !decimal_comma && FLOAT_RE.is_match(string)
        || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
    {
        DataType::Float64
    } else if INTEGER_RE.is_match(string) {
        DataType::Int64
    } else if try_parse_dates {
        #[cfg(feature = "polars-time")]
        {
            match date_infer::infer_pattern_single(string) {
                Some(pattern_with_offset) => match pattern_with_offset {
                    Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                        DataType::Datetime(TimeUnit::Microseconds, None)
                    },
                    Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                    Pattern::DatetimeYMDZ => {
                        DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                    },
                    Pattern::Time => DataType::Time,
                },
                None => DataType::String,
            }
        }
        #[cfg(not(feature = "polars-time"))]
        {
            panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
        }
    } else {
        DataType::String
    }
}

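/// Decode raw field bytes as UTF-8, either strictly (erroring on invalid
/// sequences) or lossily, depending on `encoding`.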
#[inline]
fn parse_bytes_with_encoding(bytes: &[u8], encoding: CsvEncoding) -> PolarsResult<Cow<'_, str>> {
    Ok(match encoding {
        CsvEncoding::Utf8 => simdutf8::basic::from_utf8(bytes)
            .map_err(|_| polars_err!(ComputeError: "invalid utf-8 sequence"))?
            .into(),
        CsvEncoding::LossyUtf8 => String::from_utf8_lossy(bytes),
    })
}

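/// Generated, 1-based name for a column that has no header entry,
/// e.g. `column_1` for index 0.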
fn column_name(i: usize) -> PlSmallStr {
    format_pl_smallstr!("column_{}", i + 1)
}