1use std::borrow::Cow;
2
3use polars_core::prelude::*;
4#[cfg(feature = "polars-time")]
5use polars_time::chunkedarray::string::infer as date_infer;
6#[cfg(feature = "polars-time")]
7use polars_time::prelude::string::Pattern;
8use polars_utils::format_pl_smallstr;
9
10use super::parser::{SplitLines, is_comment_line, skip_bom, skip_line_ending};
11use super::splitfields::SplitFields;
12use super::{CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues};
13use crate::csv::read::parser::skip_lines_naive;
14use crate::mmap::ReaderBytes;
15use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};
16
/// Result of inferring a CSV schema from a sample of the file, together with
/// the bookkeeping needed to extrapolate an estimated total row count.
#[derive(Clone, Debug, Default)]
pub struct SchemaInferenceResult {
    // Schema inferred from the rows that were sampled.
    inferred_schema: SchemaRef,
    // Number of rows actually inspected during inference.
    rows_read: usize,
    // Number of bytes covered by the inspected rows.
    bytes_read: usize,
    // Total size of the input in bytes.
    bytes_total: usize,
    // Thread count carried over from the originating `CsvReadOptions`.
    n_threads: Option<usize>,
}
25
26impl SchemaInferenceResult {
27 pub fn try_from_reader_bytes_and_options(
28 reader_bytes: &ReaderBytes,
29 options: &CsvReadOptions,
30 ) -> PolarsResult<Self> {
31 let parse_options = options.get_parse_options();
32
33 let infer_schema_length = options.infer_schema_length;
34 let has_header = options.has_header;
35 let schema_overwrite_arc = options.schema_overwrite.clone();
36 let schema_overwrite = schema_overwrite_arc.as_ref().map(|x| x.as_ref());
37 let skip_rows = options.skip_rows;
38 let skip_lines = options.skip_lines;
39 let skip_rows_after_header = options.skip_rows_after_header;
40 let raise_if_empty = options.raise_if_empty;
41 let n_threads = options.n_threads;
42
43 let bytes_total = reader_bytes.len();
44
45 let (inferred_schema, rows_read, bytes_read) = infer_file_schema(
46 reader_bytes,
47 &parse_options,
48 infer_schema_length,
49 has_header,
50 schema_overwrite,
51 skip_rows,
52 skip_lines,
53 skip_rows_after_header,
54 raise_if_empty,
55 )?;
56
57 let this = Self {
58 inferred_schema: Arc::new(inferred_schema),
59 rows_read,
60 bytes_read,
61 bytes_total,
62 n_threads,
63 };
64
65 Ok(this)
66 }
67
68 pub fn with_inferred_schema(mut self, inferred_schema: SchemaRef) -> Self {
69 self.inferred_schema = inferred_schema;
70 self
71 }
72
73 pub fn get_inferred_schema(&self) -> SchemaRef {
74 self.inferred_schema.clone()
75 }
76
77 pub fn get_estimated_n_rows(&self) -> usize {
78 (self.rows_read as f64 / self.bytes_read as f64 * self.bytes_total as f64) as usize
79 }
80}
81
impl CsvReadOptions {
    /// Copy settings derived during schema inference (currently only the
    /// thread count) back into these read options.
    pub fn update_with_inference_result(&mut self, si_result: &SchemaInferenceResult) {
        self.n_threads = si_result.n_threads;
    }
}
88
89pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
90 match possibilities.len() {
93 1 => possibilities.iter().next().unwrap().clone(),
94 2 if possibilities.contains(&DataType::Int64)
95 && possibilities.contains(&DataType::Float64) =>
96 {
97 DataType::Float64
99 },
100 _ => DataType::String,
102 }
103}
104
105pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
107 let bytes = string.as_bytes();
110 if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
111 if try_parse_dates {
112 #[cfg(feature = "polars-time")]
113 {
114 match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
115 Some(pattern_with_offset) => match pattern_with_offset {
116 Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
117 DataType::Datetime(TimeUnit::Microseconds, None)
118 },
119 Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
120 Pattern::DatetimeYMDZ => DataType::Datetime(
121 TimeUnit::Microseconds,
122 Some(PlSmallStr::from_static("UTC")),
123 ),
124 Pattern::Time => DataType::Time,
125 },
126 None => DataType::String,
127 }
128 }
129 #[cfg(not(feature = "polars-time"))]
130 {
131 panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
132 }
133 } else {
134 DataType::String
135 }
136 }
137 else if BOOLEAN_RE.is_match(string) {
139 DataType::Boolean
140 } else if !decimal_comma && FLOAT_RE.is_match(string)
141 || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
142 {
143 DataType::Float64
144 } else if INTEGER_RE.is_match(string) {
145 DataType::Int64
146 } else if try_parse_dates {
147 #[cfg(feature = "polars-time")]
148 {
149 match date_infer::infer_pattern_single(string) {
150 Some(pattern_with_offset) => match pattern_with_offset {
151 Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
152 DataType::Datetime(TimeUnit::Microseconds, None)
153 },
154 Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
155 Pattern::DatetimeYMDZ => DataType::Datetime(
156 TimeUnit::Microseconds,
157 Some(PlSmallStr::from_static("UTC")),
158 ),
159 Pattern::Time => DataType::Time,
160 },
161 None => DataType::String,
162 }
163 }
164 #[cfg(not(feature = "polars-time"))]
165 {
166 panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
167 }
168 } else {
169 DataType::String
170 }
171}
172
173#[inline]
174fn parse_bytes_with_encoding(bytes: &[u8], encoding: CsvEncoding) -> PolarsResult<Cow<str>> {
175 Ok(match encoding {
176 CsvEncoding::Utf8 => simdutf8::basic::from_utf8(bytes)
177 .map_err(|_| polars_err!(ComputeError: "invalid utf-8 sequence"))?
178 .into(),
179 CsvEncoding::LossyUtf8 => String::from_utf8_lossy(bytes),
180 })
181}
182
/// Synthesize a 1-based column name (`column_1`, `column_2`, ...) for files
/// without a header row.
fn column_name(i: usize) -> PlSmallStr {
    format_pl_smallstr!("column_{}", i + 1)
}
186
/// Core schema-inference loop over raw CSV bytes.
///
/// Returns `(schema, rows_read, bytes_read)`, where `bytes_read` is the byte
/// offset of the end of the last consumed record (used by the caller to
/// estimate the file's total row count).
///
/// `recursion_count` guards the two self-recursive retries below (both append
/// a trailing eol byte and re-run) so each can happen at most once.
#[allow(clippy::too_many_arguments)]
fn infer_file_schema_inner(
    reader_bytes: &ReaderBytes,
    parse_options: &CsvParseOptions,
    max_read_rows: Option<usize>,
    has_header: bool,
    schema_overwrite: Option<&Schema>,
    mut skip_rows: usize,
    skip_rows_after_header: usize,
    recursion_count: u8,
    raise_if_empty: bool,
) -> PolarsResult<(Schema, usize, usize)> {
    // Base address of the buffer; `bytes_read` is computed against this.
    let start_ptr = reader_bytes.as_ptr() as usize;

    // Lossy decoding suffices for inference: an invalid utf-8 sequence only
    // needs to yield *some* string so a dtype can be guessed.
    let encoding = CsvEncoding::LossyUtf8;

    let bytes = skip_line_ending(skip_bom(reader_bytes), parse_options.eol_char);
    if raise_if_empty {
        polars_ensure!(!bytes.is_empty(), NoData: "empty CSV");
    };
    let mut lines = SplitLines::new(
        bytes,
        parse_options.quote_char,
        parse_options.eol_char,
        parse_options.comment_prefix.as_ref(),
    )
    .skip(skip_rows);

    // First non-comment line: the header when `has_header`, otherwise the
    // first data record.
    let mut first_line = None;

    for (i, line) in (&mut lines).enumerate() {
        if !is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            first_line = Some(line);
            // Remember how many comment lines were skipped so a restarted
            // iterator (no-header case below) lands on the same line.
            skip_rows += i;
            break;
        }
    }

    if first_line.is_none() {
        first_line = lines.next();
    }

    // Determine column names: parse them from the header line, or synthesize
    // `column_1..column_n` when there is no header.
    let mut headers: Vec<PlSmallStr> = if let Some(mut header_line) = first_line {
        let len = header_line.len();
        if len > 1 {
            // Strip a trailing '\r' left over from CRLF line endings.
            let trailing_byte = header_line[len - 1];
            if trailing_byte == b'\r' {
                header_line = &header_line[..len - 1];
            }
        }

        let byterecord = SplitFields::new(
            header_line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );
        if has_header {
            let headers = byterecord
                .map(|(slice, needs_escaping)| {
                    // Drop the surrounding quotes before decoding the name.
                    let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                        &slice[1..(slice.len() - 1)]
                    } else {
                        slice
                    };
                    let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                    Ok(s)
                })
                .collect::<PolarsResult<Vec<_>>>()?;

            let mut final_headers = Vec::with_capacity(headers.len());

            // Deduplicate repeated names: the 2nd occurrence of "foo" becomes
            // "foo_duplicated_0", the 3rd "foo_duplicated_1", etc.
            let mut header_names = PlHashMap::with_capacity(headers.len());

            for name in &headers {
                let count = header_names.entry(name.as_ref()).or_insert(0usize);
                if *count != 0 {
                    final_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
                } else {
                    final_headers.push(PlSmallStr::from_str(name))
                }
                *count += 1;
            }
            final_headers
        } else {
            // No header: one synthesized name per field in the first record.
            byterecord
                .enumerate()
                .map(|(i, _s)| column_name(i))
                .collect::<Vec<PlSmallStr>>()
        }
    } else if has_header && !bytes.is_empty() && recursion_count == 0 {
        // A header was expected but no line could be split off: the single
        // header line may lack a terminator. Append an eol and retry once.
        let mut buf = Vec::with_capacity(bytes.len() + 2);
        buf.extend_from_slice(bytes);
        buf.push(parse_options.eol_char);

        return infer_file_schema_inner(
            &ReaderBytes::Owned(buf.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    } else if !raise_if_empty {
        // Empty input tolerated: empty schema, nothing read.
        return Ok((Schema::default(), 0, 0));
    } else {
        polars_bail!(NoData: "empty CSV");
    };
    if !has_header {
        // The first line was data, not a header: restart the line iterator so
        // that record also participates in type inference below.
        lines = SplitLines::new(
            bytes,
            parse_options.quote_char,
            parse_options.eol_char,
            parse_options.comment_prefix.as_ref(),
        )
        .skip(skip_rows);
    }

    // Per-column set of candidate dtypes observed so far.
    let mut column_types: Vec<PlHashSet<DataType>> =
        vec![PlHashSet::with_capacity(4); headers.len()];
    // Per-column flag: at least one empty (null) field was seen.
    // NOTE(review): populated below but never read in this function —
    // presumably kept for parity with the main parser; confirm before removing.
    let mut nulls: Vec<bool> = vec![false; headers.len()];

    let mut rows_count = 0;
    let mut fields = Vec::with_capacity(headers.len());

    let records_ref = &mut lines;

    let mut end_ptr = start_ptr;
    for mut line in records_ref
        .take(match max_read_rows {
            Some(max_read_rows) => {
                // Pull `skip_rows_after_header` extra lines so that, after the
                // `.skip(...)` below, `max_read_rows` records are still seen
                // (guarding the addition against overflow).
                if max_read_rows <= (usize::MAX - skip_rows_after_header) {
                    max_read_rows + skip_rows_after_header
                } else {
                    max_read_rows
                }
            },
            None => usize::MAX,
        })
        .skip(skip_rows_after_header)
    {
        rows_count += 1;
        // Track how far into the buffer we have read, for `bytes_read`.
        end_ptr = line.as_ptr() as usize + line.len();

        if line.is_empty() {
            continue;
        }

        if is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            continue;
        }

        let len = line.len();
        if len > 1 {
            // Strip a trailing '\r' left over from CRLF line endings.
            let trailing_byte = line[len - 1];
            if trailing_byte == b'\r' {
                line = &line[..len - 1];
            }
        }

        let record = SplitFields::new(
            line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );

        for (i, (slice, needs_escaping)) in record.enumerate() {
            if i >= headers.len() {
                // Without a header, a longer record widens the schema with
                // synthesized columns; with a header, extra fields are ignored.
                if !has_header {
                    headers.push(column_name(i));
                    column_types.push(Default::default());
                    nulls.push(false);
                } else {
                    break;
                }
            }

            if slice.is_empty() {
                // SAFETY: `i < nulls.len()` is guaranteed by the bounds check
                // (and push) above; the three vectors stay the same length.
                unsafe { *nulls.get_unchecked_mut(i) = true };
            } else {
                // Drop surrounding quotes before decoding the field.
                let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                    &slice[1..(slice.len() - 1)]
                } else {
                    slice
                };
                let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                // A field matching a configured null value contributes no
                // dtype candidate (None).
                let dtype = match &parse_options.null_values {
                    None => Some(infer_field_schema(
                        &s,
                        parse_options.try_parse_dates,
                        parse_options.decimal_comma,
                    )),
                    Some(NullValues::AllColumns(names)) => {
                        if !names.iter().any(|nv| nv == s.as_ref()) {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::AllColumnsSingle(name)) => {
                        if s.as_ref() != name.as_str() {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::Named(names)) => {
                        // SAFETY: `i < headers.len()` is guaranteed by the
                        // bounds check (and push) above.
                        let current_name = unsafe { headers.get_unchecked(i) };
                        // Per-column null value: only applies when this
                        // column's name is listed.
                        let null_name = &names.iter().find(|name| name.0 == current_name);

                        if let Some(null_name) = null_name {
                            if null_name.1.as_str() != s.as_ref() {
                                Some(infer_field_schema(
                                    &s,
                                    parse_options.try_parse_dates,
                                    parse_options.decimal_comma,
                                ))
                            } else {
                                None
                            }
                        } else {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        }
                    },
                };
                if let Some(dtype) = dtype {
                    // SAFETY: `i < column_types.len()` is guaranteed by the
                    // bounds check (and push) above.
                    unsafe { column_types.get_unchecked_mut(i).insert(dtype) };
                }
            }
        }
    }

    // Resolve the final dtype per column, honoring user-supplied overrides:
    // matched by name first, then by position when the override schema has
    // exactly one entry per detected column.
    for i in 0..headers.len() {
        let field_name = &headers[i];

        if let Some(schema_overwrite) = schema_overwrite {
            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
                fields.push(Field::new(name.clone(), dtype.clone()));
                continue;
            }

            // Positional fallback.
            if schema_overwrite.len() == headers.len() {
                if let Some((name, dtype)) = schema_overwrite.get_at_index(i) {
                    fields.push(Field::new(name.clone(), dtype.clone()));
                    continue;
                }
            }
        }

        let possibilities = &column_types[i];
        let dtype = finish_infer_field_schema(possibilities);
        fields.push(Field::new(field_name.clone(), dtype));
    }
    // No data rows were read and the buffer lacks a trailing eol: the last
    // (only) record may be unterminated. Append an eol and retry once.
    if rows_count == 0
        && !reader_bytes.is_empty()
        && reader_bytes[reader_bytes.len() - 1] != parse_options.eol_char
        && recursion_count == 0
    {
        let mut rb = Vec::with_capacity(reader_bytes.len() + 1);
        rb.extend_from_slice(reader_bytes);
        rb.push(parse_options.eol_char);
        return infer_file_schema_inner(
            &ReaderBytes::Owned(rb.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    }

    Ok((Schema::from_iter(fields), rows_count, end_ptr - start_ptr))
}
512
513pub(super) fn check_decimal_comma(decimal_comma: bool, separator: u8) -> PolarsResult<()> {
514 if decimal_comma {
515 polars_ensure!(b',' != separator, InvalidOperation: "'decimal_comma' argument cannot be combined with ',' separator")
516 }
517 Ok(())
518}
519
520#[allow(clippy::too_many_arguments)]
530pub fn infer_file_schema(
531 reader_bytes: &ReaderBytes,
532 parse_options: &CsvParseOptions,
533 max_read_rows: Option<usize>,
534 has_header: bool,
535 schema_overwrite: Option<&Schema>,
536 skip_rows: usize,
537 skip_lines: usize,
538 skip_rows_after_header: usize,
539 raise_if_empty: bool,
540) -> PolarsResult<(Schema, usize, usize)> {
541 check_decimal_comma(parse_options.decimal_comma, parse_options.separator)?;
542
543 if skip_lines > 0 {
544 polars_ensure!(skip_rows == 0, InvalidOperation: "only one of 'skip_rows'/'skip_lines' may be set");
545 let bytes = skip_lines_naive(reader_bytes, parse_options.eol_char, skip_lines);
546 let reader_bytes = ReaderBytes::Borrowed(bytes);
547 infer_file_schema_inner(
548 &reader_bytes,
549 parse_options,
550 max_read_rows,
551 has_header,
552 schema_overwrite,
553 skip_rows,
554 skip_rows_after_header,
555 0,
556 raise_if_empty,
557 )
558 } else {
559 infer_file_schema_inner(
560 reader_bytes,
561 parse_options,
562 max_read_rows,
563 has_header,
564 schema_overwrite,
565 skip_rows,
566 skip_rows_after_header,
567 0,
568 raise_if_empty,
569 )
570 }
571}