1use std::borrow::Cow;
2
3use polars_core::config::verbose;
4use polars_core::prelude::*;
5#[cfg(feature = "polars-time")]
6use polars_time::chunkedarray::string::infer as date_infer;
7#[cfg(feature = "polars-time")]
8use polars_time::prelude::string::Pattern;
9use polars_utils::format_pl_smallstr;
10
11use super::parser::{SplitLines, is_comment_line, skip_bom, skip_line_ending};
12use super::splitfields::SplitFields;
13use super::{CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues};
14use crate::csv::read::parser::skip_lines_naive;
15use crate::mmap::ReaderBytes;
16use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};
17
/// Outcome of CSV schema inference over a sample of the input.
///
/// Besides the inferred schema, it records how much of the input was sampled. This is what
/// `get_estimated_n_rows` uses to extrapolate a total row count.
#[derive(Clone, Debug, Default)]
pub struct SchemaInferenceResult {
    /// Schema inferred from the sampled lines.
    inferred_schema: SchemaRef,
    /// Number of lines taken from the sample after `skip_rows_after_header`
    /// (empty and comment lines included).
    rows_read: usize,
    /// Bytes from the start of the buffer inference ran on to the end of the last sampled line.
    bytes_read: usize,
    /// Total length in bytes of the input handed to inference.
    bytes_total: usize,
    /// Thread count to read with. Inference may force this to `Some(1)`.
    n_threads: Option<usize>,
}
26
27impl SchemaInferenceResult {
28 pub fn try_from_reader_bytes_and_options(
29 reader_bytes: &ReaderBytes,
30 options: &CsvReadOptions,
31 ) -> PolarsResult<Self> {
32 let parse_options = options.get_parse_options();
33
34 let infer_schema_length = options.infer_schema_length;
35 let has_header = options.has_header;
36 let schema_overwrite_arc = options.schema_overwrite.clone();
37 let schema_overwrite = schema_overwrite_arc.as_ref().map(|x| x.as_ref());
38 let skip_rows = options.skip_rows;
39 let skip_lines = options.skip_lines;
40 let skip_rows_after_header = options.skip_rows_after_header;
41 let raise_if_empty = options.raise_if_empty;
42 let mut n_threads = options.n_threads;
43
44 let bytes_total = reader_bytes.len();
45
46 let (inferred_schema, rows_read, bytes_read) = infer_file_schema(
47 reader_bytes,
48 &parse_options,
49 infer_schema_length,
50 has_header,
51 schema_overwrite,
52 skip_rows,
53 skip_lines,
54 skip_rows_after_header,
55 raise_if_empty,
56 &mut n_threads,
57 )?;
58
59 let this = Self {
60 inferred_schema: Arc::new(inferred_schema),
61 rows_read,
62 bytes_read,
63 bytes_total,
64 n_threads,
65 };
66
67 Ok(this)
68 }
69
70 pub fn with_inferred_schema(mut self, inferred_schema: SchemaRef) -> Self {
71 self.inferred_schema = inferred_schema;
72 self
73 }
74
75 pub fn get_inferred_schema(&self) -> SchemaRef {
76 self.inferred_schema.clone()
77 }
78
79 pub fn get_estimated_n_rows(&self) -> usize {
80 (self.rows_read as f64 / self.bytes_read as f64 * self.bytes_total as f64) as usize
81 }
82}
83
84impl CsvReadOptions {
85 pub fn update_with_inference_result(&mut self, si_result: &SchemaInferenceResult) {
87 self.n_threads = si_result.n_threads;
88 }
89}
90
91pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
92 match possibilities.len() {
95 1 => possibilities.iter().next().unwrap().clone(),
96 2 if possibilities.contains(&DataType::Int64)
97 && possibilities.contains(&DataType::Float64) =>
98 {
99 DataType::Float64
101 },
102 _ => DataType::String,
104 }
105}
106
107pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
109 if string.starts_with('"') {
112 if try_parse_dates {
113 #[cfg(feature = "polars-time")]
114 {
115 match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
116 Some(pattern_with_offset) => match pattern_with_offset {
117 Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
118 DataType::Datetime(TimeUnit::Microseconds, None)
119 },
120 Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
121 Pattern::DatetimeYMDZ => DataType::Datetime(
122 TimeUnit::Microseconds,
123 Some(PlSmallStr::from_static("UTC")),
124 ),
125 Pattern::Time => DataType::Time,
126 },
127 None => DataType::String,
128 }
129 }
130 #[cfg(not(feature = "polars-time"))]
131 {
132 panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
133 }
134 } else {
135 DataType::String
136 }
137 }
138 else if BOOLEAN_RE.is_match(string) {
140 DataType::Boolean
141 } else if !decimal_comma && FLOAT_RE.is_match(string)
142 || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
143 {
144 DataType::Float64
145 } else if INTEGER_RE.is_match(string) {
146 DataType::Int64
147 } else if try_parse_dates {
148 #[cfg(feature = "polars-time")]
149 {
150 match date_infer::infer_pattern_single(string) {
151 Some(pattern_with_offset) => match pattern_with_offset {
152 Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
153 DataType::Datetime(TimeUnit::Microseconds, None)
154 },
155 Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
156 Pattern::DatetimeYMDZ => DataType::Datetime(
157 TimeUnit::Microseconds,
158 Some(PlSmallStr::from_static("UTC")),
159 ),
160 Pattern::Time => DataType::Time,
161 },
162 None => DataType::String,
163 }
164 }
165 #[cfg(not(feature = "polars-time"))]
166 {
167 panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
168 }
169 } else {
170 DataType::String
171 }
172}
173
174#[inline]
175fn parse_bytes_with_encoding(bytes: &[u8], encoding: CsvEncoding) -> PolarsResult<Cow<str>> {
176 Ok(match encoding {
177 CsvEncoding::Utf8 => simdutf8::basic::from_utf8(bytes)
178 .map_err(|_| polars_err!(ComputeError: "invalid utf-8 sequence"))?
179 .into(),
180 CsvEncoding::LossyUtf8 => String::from_utf8_lossy(bytes),
181 })
182}
183
184#[allow(clippy::too_many_arguments)]
185fn infer_file_schema_inner(
186 reader_bytes: &ReaderBytes,
187 parse_options: &CsvParseOptions,
188 max_read_rows: Option<usize>,
189 has_header: bool,
190 schema_overwrite: Option<&Schema>,
191 mut skip_rows: usize,
194 skip_rows_after_header: usize,
195 recursion_count: u8,
196 raise_if_empty: bool,
197 n_threads: &mut Option<usize>,
198) -> PolarsResult<(Schema, usize, usize)> {
199 let start_ptr = reader_bytes.as_ptr() as usize;
201
202 let encoding = CsvEncoding::LossyUtf8;
205
206 let bytes = skip_line_ending(skip_bom(reader_bytes), parse_options.eol_char);
207 if raise_if_empty {
208 polars_ensure!(!bytes.is_empty(), NoData: "empty CSV");
209 };
210 let mut lines = SplitLines::new(
211 bytes,
212 parse_options.quote_char,
213 parse_options.eol_char,
214 parse_options.comment_prefix.as_ref(),
215 )
216 .skip(skip_rows);
217
218 let mut first_line = None;
223
224 for (i, line) in (&mut lines).enumerate() {
225 if !is_comment_line(line, parse_options.comment_prefix.as_ref()) {
226 first_line = Some(line);
227 skip_rows += i;
228 break;
229 }
230 }
231
232 if first_line.is_none() {
233 first_line = lines.next();
234 }
235
236 let headers: Vec<PlSmallStr> = if let Some(mut header_line) = first_line {
238 let len = header_line.len();
239 if len > 1 {
240 let trailing_byte = header_line[len - 1];
242 if trailing_byte == b'\r' {
243 header_line = &header_line[..len - 1];
244 }
245 }
246
247 let byterecord = SplitFields::new(
248 header_line,
249 parse_options.separator,
250 parse_options.quote_char,
251 parse_options.eol_char,
252 );
253 if has_header {
254 let headers = byterecord
255 .map(|(slice, needs_escaping)| {
256 let slice_escaped = if needs_escaping && (slice.len() >= 2) {
257 &slice[1..(slice.len() - 1)]
258 } else {
259 slice
260 };
261 let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
262 Ok(s)
263 })
264 .collect::<PolarsResult<Vec<_>>>()?;
265
266 let mut final_headers = Vec::with_capacity(headers.len());
267
268 let mut header_names = PlHashMap::with_capacity(headers.len());
269
270 for name in &headers {
271 let count = header_names.entry(name.as_ref()).or_insert(0usize);
272 if *count != 0 {
273 final_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
274 } else {
275 final_headers.push(PlSmallStr::from_str(name))
276 }
277 *count += 1;
278 }
279 final_headers
280 } else {
281 byterecord
282 .enumerate()
283 .map(|(i, _s)| format_pl_smallstr!("column_{}", i + 1))
284 .collect::<Vec<PlSmallStr>>()
285 }
286 } else if has_header && !bytes.is_empty() && recursion_count == 0 {
287 let mut buf = Vec::with_capacity(bytes.len() + 2);
290 buf.extend_from_slice(bytes);
291 buf.push(parse_options.eol_char);
292
293 return infer_file_schema_inner(
294 &ReaderBytes::Owned(buf.into()),
295 parse_options,
296 max_read_rows,
297 has_header,
298 schema_overwrite,
299 skip_rows,
300 skip_rows_after_header,
301 recursion_count + 1,
302 raise_if_empty,
303 n_threads,
304 );
305 } else if !raise_if_empty {
306 return Ok((Schema::default(), 0, 0));
307 } else {
308 polars_bail!(NoData: "empty CSV");
309 };
310 if !has_header {
311 lines = SplitLines::new(
313 bytes,
314 parse_options.quote_char,
315 parse_options.eol_char,
316 parse_options.comment_prefix.as_ref(),
317 )
318 .skip(skip_rows);
319 }
320
321 let header_length = headers.len();
322 let mut column_types: Vec<PlHashSet<DataType>> =
324 vec![PlHashSet::with_capacity(4); header_length];
325 let mut nulls: Vec<bool> = vec![false; header_length];
327
328 let mut rows_count = 0;
329 let mut fields = Vec::with_capacity(header_length);
330
331 let records_ref = &mut lines;
333
334 let mut end_ptr = start_ptr;
335 for mut line in records_ref
336 .take(match max_read_rows {
337 Some(max_read_rows) => {
338 if max_read_rows <= (usize::MAX - skip_rows_after_header) {
339 max_read_rows + skip_rows_after_header
343 } else {
344 max_read_rows
345 }
346 },
347 None => usize::MAX,
348 })
349 .skip(skip_rows_after_header)
350 {
351 rows_count += 1;
352 end_ptr = line.as_ptr() as usize + line.len();
354
355 if line.is_empty() {
356 continue;
357 }
358
359 if is_comment_line(line, parse_options.comment_prefix.as_ref()) {
361 continue;
362 }
363
364 let len = line.len();
365 if len > 1 {
366 let trailing_byte = line[len - 1];
368 if trailing_byte == b'\r' {
369 line = &line[..len - 1];
370 }
371 }
372
373 let mut record = SplitFields::new(
374 line,
375 parse_options.separator,
376 parse_options.quote_char,
377 parse_options.eol_char,
378 );
379
380 for i in 0..header_length {
381 if let Some((slice, needs_escaping)) = record.next() {
382 if slice.is_empty() {
383 unsafe { *nulls.get_unchecked_mut(i) = true };
384 } else {
385 let slice_escaped = if needs_escaping && (slice.len() >= 2) {
386 &slice[1..(slice.len() - 1)]
387 } else {
388 slice
389 };
390 let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
391 let dtype = match &parse_options.null_values {
392 None => Some(infer_field_schema(
393 &s,
394 parse_options.try_parse_dates,
395 parse_options.decimal_comma,
396 )),
397 Some(NullValues::AllColumns(names)) => {
398 if !names.iter().any(|nv| nv == s.as_ref()) {
399 Some(infer_field_schema(
400 &s,
401 parse_options.try_parse_dates,
402 parse_options.decimal_comma,
403 ))
404 } else {
405 None
406 }
407 },
408 Some(NullValues::AllColumnsSingle(name)) => {
409 if s.as_ref() != name.as_str() {
410 Some(infer_field_schema(
411 &s,
412 parse_options.try_parse_dates,
413 parse_options.decimal_comma,
414 ))
415 } else {
416 None
417 }
418 },
419 Some(NullValues::Named(names)) => {
420 let current_name = unsafe { headers.get_unchecked(i) };
423 let null_name = &names.iter().find(|name| name.0 == current_name);
424
425 if let Some(null_name) = null_name {
426 if null_name.1.as_str() != s.as_ref() {
427 Some(infer_field_schema(
428 &s,
429 parse_options.try_parse_dates,
430 parse_options.decimal_comma,
431 ))
432 } else {
433 None
434 }
435 } else {
436 Some(infer_field_schema(
437 &s,
438 parse_options.try_parse_dates,
439 parse_options.decimal_comma,
440 ))
441 }
442 },
443 };
444 if let Some(dtype) = dtype {
445 if matches!(&dtype, DataType::String)
446 && needs_escaping
447 && n_threads.unwrap_or(2) > 1
448 {
449 if slice
455 .iter()
456 .filter(|b| **b == parse_options.eol_char)
457 .count()
458 > 8
459 {
460 if verbose() {
461 eprintln!(
462 "falling back to single core reading because of many escaped new line chars."
463 )
464 }
465 *n_threads = Some(1);
466 }
467 }
468 unsafe { column_types.get_unchecked_mut(i).insert(dtype) };
469 }
470 }
471 }
472 }
473 }
474
475 for i in 0..header_length {
477 let field_name = &headers[i];
478
479 if let Some(schema_overwrite) = schema_overwrite {
480 if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
481 fields.push(Field::new(name.clone(), dtype.clone()));
482 continue;
483 }
484
485 if schema_overwrite.len() == header_length {
488 if let Some((name, dtype)) = schema_overwrite.get_at_index(i) {
489 fields.push(Field::new(name.clone(), dtype.clone()));
490 continue;
491 }
492 }
493 }
494
495 let possibilities = &column_types[i];
496 let dtype = finish_infer_field_schema(possibilities);
497 fields.push(Field::new(field_name.clone(), dtype));
498 }
499 if rows_count == 0
503 && !reader_bytes.is_empty()
504 && reader_bytes[reader_bytes.len() - 1] != parse_options.eol_char
505 && recursion_count == 0
506 {
507 let mut rb = Vec::with_capacity(reader_bytes.len() + 1);
508 rb.extend_from_slice(reader_bytes);
509 rb.push(parse_options.eol_char);
510 return infer_file_schema_inner(
511 &ReaderBytes::Owned(rb.into()),
512 parse_options,
513 max_read_rows,
514 has_header,
515 schema_overwrite,
516 skip_rows,
517 skip_rows_after_header,
518 recursion_count + 1,
519 raise_if_empty,
520 n_threads,
521 );
522 }
523
524 Ok((Schema::from_iter(fields), rows_count, end_ptr - start_ptr))
525}
526
527pub(super) fn check_decimal_comma(decimal_comma: bool, separator: u8) -> PolarsResult<()> {
528 if decimal_comma {
529 polars_ensure!(b',' != separator, InvalidOperation: "'decimal_comma' argument cannot be combined with ',' separator")
530 }
531 Ok(())
532}
533
534#[allow(clippy::too_many_arguments)]
544pub fn infer_file_schema(
545 reader_bytes: &ReaderBytes,
546 parse_options: &CsvParseOptions,
547 max_read_rows: Option<usize>,
548 has_header: bool,
549 schema_overwrite: Option<&Schema>,
550 skip_rows: usize,
551 skip_lines: usize,
552 skip_rows_after_header: usize,
553 raise_if_empty: bool,
554 n_threads: &mut Option<usize>,
555) -> PolarsResult<(Schema, usize, usize)> {
556 check_decimal_comma(parse_options.decimal_comma, parse_options.separator)?;
557
558 if skip_lines > 0 {
559 polars_ensure!(skip_rows == 0, InvalidOperation: "only one of 'skip_rows'/'skip_lines' may be set");
560 let bytes = skip_lines_naive(reader_bytes, parse_options.eol_char, skip_lines);
561 let reader_bytes = ReaderBytes::Borrowed(bytes);
562 infer_file_schema_inner(
563 &reader_bytes,
564 parse_options,
565 max_read_rows,
566 has_header,
567 schema_overwrite,
568 skip_rows,
569 skip_rows_after_header,
570 0,
571 raise_if_empty,
572 n_threads,
573 )
574 } else {
575 infer_file_schema_inner(
576 reader_bytes,
577 parse_options,
578 max_read_rows,
579 has_header,
580 schema_overwrite,
581 skip_rows,
582 skip_rows_after_header,
583 0,
584 raise_if_empty,
585 n_threads,
586 )
587 }
588}