1use std::borrow::Cow;
2
3use polars_core::prelude::*;
4#[cfg(feature = "polars-time")]
5use polars_time::chunkedarray::string::infer as date_infer;
6#[cfg(feature = "polars-time")]
7use polars_time::prelude::string::Pattern;
8use polars_utils::format_pl_smallstr;
9
10use super::parser::{SplitLines, is_comment_line, skip_bom, skip_line_ending};
11use super::splitfields::SplitFields;
12use super::{CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues};
13use crate::csv::read::parser::skip_lines_naive;
14use crate::mmap::ReaderBytes;
15use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};
16
/// Result of running CSV schema inference over a (possibly partial) byte buffer.
#[derive(Clone, Debug, Default)]
pub struct SchemaInferenceResult {
    // Schema inferred from the sampled rows.
    inferred_schema: SchemaRef,
    // Number of rows consumed during inference.
    rows_read: usize,
    // Number of bytes consumed during inference.
    bytes_read: usize,
    // Total byte length of the reader; used to extrapolate the row count.
    bytes_total: usize,
    // Thread-count hint carried over from the read options.
    n_threads: Option<usize>,
}
25
26impl SchemaInferenceResult {
27 pub fn try_from_reader_bytes_and_options(
28 reader_bytes: &ReaderBytes,
29 options: &CsvReadOptions,
30 ) -> PolarsResult<Self> {
31 let parse_options = options.get_parse_options();
32
33 let infer_schema_length = options.infer_schema_length;
34 let has_header = options.has_header;
35 let schema_overwrite_arc = options.schema_overwrite.clone();
36 let schema_overwrite = schema_overwrite_arc.as_ref().map(|x| x.as_ref());
37 let skip_rows = options.skip_rows;
38 let skip_lines = options.skip_lines;
39 let skip_rows_after_header = options.skip_rows_after_header;
40 let raise_if_empty = options.raise_if_empty;
41 let n_threads = options.n_threads;
42
43 let bytes_total = reader_bytes.len();
44
45 let (inferred_schema, rows_read, bytes_read) = infer_file_schema(
46 reader_bytes,
47 &parse_options,
48 infer_schema_length,
49 has_header,
50 schema_overwrite,
51 skip_rows,
52 skip_lines,
53 skip_rows_after_header,
54 raise_if_empty,
55 )?;
56
57 let this = Self {
58 inferred_schema: Arc::new(inferred_schema),
59 rows_read,
60 bytes_read,
61 bytes_total,
62 n_threads,
63 };
64
65 Ok(this)
66 }
67
68 pub fn with_inferred_schema(mut self, inferred_schema: SchemaRef) -> Self {
69 self.inferred_schema = inferred_schema;
70 self
71 }
72
73 pub fn get_inferred_schema(&self) -> SchemaRef {
74 self.inferred_schema.clone()
75 }
76
77 pub fn get_estimated_n_rows(&self) -> usize {
78 (self.rows_read as f64 / self.bytes_read as f64 * self.bytes_total as f64) as usize
79 }
80}
81
impl CsvReadOptions {
    /// Copies settings derived during schema inference back onto the options.
    /// Currently only the thread-count hint is carried over.
    pub fn update_with_inference_result(&mut self, si_result: &SchemaInferenceResult) {
        self.n_threads = si_result.n_threads;
    }
}
88
89pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
90 match possibilities.len() {
93 1 => possibilities.iter().next().unwrap().clone(),
94 2 if possibilities.contains(&DataType::Int64)
95 && possibilities.contains(&DataType::Float64) =>
96 {
97 DataType::Float64
99 },
100 _ => DataType::String,
102 }
103}
104
105pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
107 let bytes = string.as_bytes();
110 if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
111 if try_parse_dates {
112 #[cfg(feature = "polars-time")]
113 {
114 match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
115 Some(pattern_with_offset) => match pattern_with_offset {
116 Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
117 DataType::Datetime(TimeUnit::Microseconds, None)
118 },
119 Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
120 Pattern::DatetimeYMDZ => {
121 DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
122 },
123 Pattern::Time => DataType::Time,
124 },
125 None => DataType::String,
126 }
127 }
128 #[cfg(not(feature = "polars-time"))]
129 {
130 panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
131 }
132 } else {
133 DataType::String
134 }
135 }
136 else if BOOLEAN_RE.is_match(string) {
138 DataType::Boolean
139 } else if !decimal_comma && FLOAT_RE.is_match(string)
140 || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
141 {
142 DataType::Float64
143 } else if INTEGER_RE.is_match(string) {
144 DataType::Int64
145 } else if try_parse_dates {
146 #[cfg(feature = "polars-time")]
147 {
148 match date_infer::infer_pattern_single(string) {
149 Some(pattern_with_offset) => match pattern_with_offset {
150 Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
151 DataType::Datetime(TimeUnit::Microseconds, None)
152 },
153 Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
154 Pattern::DatetimeYMDZ => {
155 DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
156 },
157 Pattern::Time => DataType::Time,
158 },
159 None => DataType::String,
160 }
161 }
162 #[cfg(not(feature = "polars-time"))]
163 {
164 panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
165 }
166 } else {
167 DataType::String
168 }
169}
170
171#[inline]
172fn parse_bytes_with_encoding(bytes: &[u8], encoding: CsvEncoding) -> PolarsResult<Cow<str>> {
173 Ok(match encoding {
174 CsvEncoding::Utf8 => simdutf8::basic::from_utf8(bytes)
175 .map_err(|_| polars_err!(ComputeError: "invalid utf-8 sequence"))?
176 .into(),
177 CsvEncoding::LossyUtf8 => String::from_utf8_lossy(bytes),
178 })
179}
180
181fn column_name(i: usize) -> PlSmallStr {
182 format_pl_smallstr!("column_{}", i + 1)
183}
184
/// Infers a CSV schema by scanning up to `max_read_rows` records and
/// collecting the set of candidate dtypes observed per column.
///
/// Returns `(schema, rows_read, bytes_read)`, where `bytes_read` is the
/// number of input bytes consumed (used upstream to extrapolate row counts).
/// `recursion_count` limits the two self-recursive "append a trailing eol and
/// retry" fallbacks to a single attempt each.
#[allow(clippy::too_many_arguments)]
fn infer_file_schema_inner(
    reader_bytes: &ReaderBytes,
    parse_options: &CsvParseOptions,
    max_read_rows: Option<usize>,
    has_header: bool,
    schema_overwrite: Option<&Schema>,
    mut skip_rows: usize,
    skip_rows_after_header: usize,
    recursion_count: u8,
    raise_if_empty: bool,
) -> PolarsResult<(Schema, usize, usize)> {
    // Base address; `bytes_read` is computed as end pointer minus this.
    let start_ptr = reader_bytes.as_ptr() as usize;

    // Inference always decodes lossily, so invalid utf-8 cannot abort it.
    let encoding = CsvEncoding::LossyUtf8;

    let bytes = skip_line_ending(skip_bom(reader_bytes), parse_options.eol_char);
    if raise_if_empty {
        polars_ensure!(!bytes.is_empty(), NoData: "empty CSV");
    };
    let mut lines = SplitLines::new(
        bytes,
        parse_options.quote_char,
        parse_options.eol_char,
        parse_options.comment_prefix.as_ref(),
    )
    .skip(skip_rows);

    // The first non-comment line: header row if `has_header`, else first data row.
    let mut first_line = None;

    for (i, line) in (&mut lines).enumerate() {
        if !is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            first_line = Some(line);
            // Fold the skipped comment lines into the effective skip count.
            skip_rows += i;
            break;
        }
    }

    if first_line.is_none() {
        first_line = lines.next();
    }

    // Header names: parsed from the first line, or synthesized when headerless.
    let mut headers: Vec<PlSmallStr> = if let Some(mut header_line) = first_line {
        let len = header_line.len();
        if len > 1 {
            // Strip a trailing carriage return (CRLF line endings).
            let trailing_byte = header_line[len - 1];
            if trailing_byte == b'\r' {
                header_line = &header_line[..len - 1];
            }
        }

        let byterecord = SplitFields::new(
            header_line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );
        if has_header {
            let headers = byterecord
                .map(|(slice, needs_escaping)| {
                    // Drop the surrounding quotes of quoted fields.
                    let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                        &slice[1..(slice.len() - 1)]
                    } else {
                        slice
                    };
                    let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                    Ok(s)
                })
                .collect::<PolarsResult<Vec<_>>>()?;

            let mut final_headers = Vec::with_capacity(headers.len());

            // Occurrence counts per name, used to disambiguate duplicates.
            let mut header_names = PlHashMap::with_capacity(headers.len());

            for name in &headers {
                let count = header_names.entry(name.as_ref()).or_insert(0usize);
                if *count != 0 {
                    // Repeats become `{name}_duplicated_{n}` with n 0-based.
                    final_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
                } else {
                    final_headers.push(PlSmallStr::from_str(name))
                }
                *count += 1;
            }
            final_headers
        } else {
            // No header row: synthesize `column_1`, `column_2`, ...
            byterecord
                .enumerate()
                .map(|(i, _s)| column_name(i))
                .collect::<Vec<PlSmallStr>>()
        }
    } else if has_header && !bytes.is_empty() && recursion_count == 0 {
        // No line was produced although bytes exist (e.g. a lone header line
        // without a trailing eol): append the eol char and retry once.
        let mut buf = Vec::with_capacity(bytes.len() + 2);
        buf.extend_from_slice(bytes);
        buf.push(parse_options.eol_char);

        return infer_file_schema_inner(
            &ReaderBytes::Owned(buf.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    } else if !raise_if_empty {
        return Ok((Schema::default(), 0, 0));
    } else {
        polars_bail!(NoData: "empty CSV");
    };
    if !has_header {
        // Headerless: the first line is data, so re-split from the start so
        // that it participates in type inference below.
        lines = SplitLines::new(
            bytes,
            parse_options.quote_char,
            parse_options.eol_char,
            parse_options.comment_prefix.as_ref(),
        )
        .skip(skip_rows);
    }

    // Per-column set of dtype candidates observed so far.
    let mut column_types: Vec<PlHashSet<DataType>> =
        vec![PlHashSet::with_capacity(4); headers.len()];
    // Per-column flag recording that an empty (null) field was seen.
    let mut nulls: Vec<bool> = vec![false; headers.len()];

    let mut rows_count = 0;
    let mut fields = Vec::with_capacity(headers.len());

    let records_ref = &mut lines;

    let mut end_ptr = start_ptr;
    for mut line in records_ref
        .take(match max_read_rows {
            Some(max_read_rows) => {
                // Also pull the rows that get skipped after the header so that
                // `max_read_rows` rows actually reach inference; guard the
                // addition against overflow.
                if max_read_rows <= (usize::MAX - skip_rows_after_header) {
                    max_read_rows + skip_rows_after_header
                } else {
                    max_read_rows
                }
            },
            None => usize::MAX,
        })
        .skip(skip_rows_after_header)
    {
        rows_count += 1;
        // Track how far into the buffer we have read.
        end_ptr = line.as_ptr() as usize + line.len();

        if line.is_empty() {
            continue;
        }

        if is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            continue;
        }

        let len = line.len();
        if len > 1 {
            // Strip a trailing carriage return (CRLF line endings).
            let trailing_byte = line[len - 1];
            if trailing_byte == b'\r' {
                line = &line[..len - 1];
            }
        }

        let record = SplitFields::new(
            line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );

        for (i, (slice, needs_escaping)) in record.enumerate() {
            if i >= headers.len() {
                if !has_header {
                    // Ragged row in a headerless file: grow the schema.
                    headers.push(column_name(i));
                    column_types.push(Default::default());
                    nulls.push(false);
                } else {
                    // Extra fields beyond the header are ignored.
                    break;
                }
            }

            if slice.is_empty() {
                // SAFETY: `i < nulls.len()` is ensured by the branch above;
                // the three per-column vectors are kept the same length.
                unsafe { *nulls.get_unchecked_mut(i) = true };
            } else {
                let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                    &slice[1..(slice.len() - 1)]
                } else {
                    slice
                };
                let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                // `None` means the value matched a configured null marker and
                // contributes no dtype candidate for this column.
                let dtype = match &parse_options.null_values {
                    None => Some(infer_field_schema(
                        &s,
                        parse_options.try_parse_dates,
                        parse_options.decimal_comma,
                    )),
                    Some(NullValues::AllColumns(names)) => {
                        if !names.iter().any(|nv| nv == s.as_ref()) {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::AllColumnsSingle(name)) => {
                        if s.as_ref() != name.as_str() {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::Named(names)) => {
                        // SAFETY: `i < headers.len()` is ensured by the branch above.
                        let current_name = unsafe { headers.get_unchecked(i) };
                        // Null markers are configured per column name here.
                        let null_name = &names.iter().find(|name| name.0 == current_name);

                        if let Some(null_name) = null_name {
                            if null_name.1.as_str() != s.as_ref() {
                                Some(infer_field_schema(
                                    &s,
                                    parse_options.try_parse_dates,
                                    parse_options.decimal_comma,
                                ))
                            } else {
                                None
                            }
                        } else {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        }
                    },
                };
                if let Some(dtype) = dtype {
                    // SAFETY: `i < column_types.len()` is ensured by the branch above.
                    unsafe { column_types.get_unchecked_mut(i).insert(dtype) };
                }
            }
        }
    }

    // Resolve the final dtype of every column, honoring user overrides first.
    for i in 0..headers.len() {
        let field_name = &headers[i];

        if let Some(schema_overwrite) = schema_overwrite {
            // An override matched by name takes precedence.
            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
                fields.push(Field::new(name.clone(), dtype.clone()));
                continue;
            }

            // A full-width override schema may also match positionally.
            if schema_overwrite.len() == headers.len() {
                if let Some((name, dtype)) = schema_overwrite.get_at_index(i) {
                    fields.push(Field::new(name.clone(), dtype.clone()));
                    continue;
                }
            }
        }

        let possibilities = &column_types[i];
        let dtype = finish_infer_field_schema(possibilities);
        fields.push(Field::new(field_name.clone(), dtype));
    }
    // No rows were produced and the input lacks a trailing eol: append the
    // eol char and retry once so the final unterminated line is parsed.
    if rows_count == 0
        && !reader_bytes.is_empty()
        && reader_bytes[reader_bytes.len() - 1] != parse_options.eol_char
        && recursion_count == 0
    {
        let mut rb = Vec::with_capacity(reader_bytes.len() + 1);
        rb.extend_from_slice(reader_bytes);
        rb.push(parse_options.eol_char);
        return infer_file_schema_inner(
            &ReaderBytes::Owned(rb.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    }

    Ok((Schema::from_iter(fields), rows_count, end_ptr - start_ptr))
}
510
511pub(super) fn check_decimal_comma(decimal_comma: bool, separator: u8) -> PolarsResult<()> {
512 if decimal_comma {
513 polars_ensure!(b',' != separator, InvalidOperation: "'decimal_comma' argument cannot be combined with ',' separator")
514 }
515 Ok(())
516}
517
518#[allow(clippy::too_many_arguments)]
528pub fn infer_file_schema(
529 reader_bytes: &ReaderBytes,
530 parse_options: &CsvParseOptions,
531 max_read_rows: Option<usize>,
532 has_header: bool,
533 schema_overwrite: Option<&Schema>,
534 skip_rows: usize,
535 skip_lines: usize,
536 skip_rows_after_header: usize,
537 raise_if_empty: bool,
538) -> PolarsResult<(Schema, usize, usize)> {
539 check_decimal_comma(parse_options.decimal_comma, parse_options.separator)?;
540
541 if skip_lines > 0 {
542 polars_ensure!(skip_rows == 0, InvalidOperation: "only one of 'skip_rows'/'skip_lines' may be set");
543 let bytes = skip_lines_naive(reader_bytes, parse_options.eol_char, skip_lines);
544 let reader_bytes = ReaderBytes::Borrowed(bytes);
545 infer_file_schema_inner(
546 &reader_bytes,
547 parse_options,
548 max_read_rows,
549 has_header,
550 schema_overwrite,
551 skip_rows,
552 skip_rows_after_header,
553 0,
554 raise_if_empty,
555 )
556 } else {
557 infer_file_schema_inner(
558 reader_bytes,
559 parse_options,
560 max_read_rows,
561 has_header,
562 schema_overwrite,
563 skip_rows,
564 skip_rows_after_header,
565 0,
566 raise_if_empty,
567 )
568 }
569}