1use polars_buffer::Buffer;
2use polars_core::prelude::*;
3#[cfg(feature = "polars-time")]
4use polars_time::chunkedarray::string::infer as date_infer;
5#[cfg(feature = "polars-time")]
6use polars_time::prelude::string::Pattern;
7use polars_utils::format_pl_smallstr;
8
9use super::splitfields::SplitFields;
10use super::{CsvParseOptions, NullValues};
11use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};
12
13#[allow(clippy::too_many_arguments)]
17pub(super) fn infer_file_schema_impl(
18 header_line: &Option<Buffer<u8>>,
19 content_lines: &[Buffer<u8>],
20 infer_all_as_str: bool,
21 parse_options: &CsvParseOptions,
22 schema_overwrite: Option<&Schema>,
23) -> Schema {
24 let mut headers = header_line
25 .as_ref()
26 .map(|line| infer_headers(line, parse_options))
27 .unwrap_or_else(|| Vec::with_capacity(8));
28
29 let extend_header_with_unknown_column = header_line.is_none();
30
31 let mut column_types = vec![PlHashSet::<DataType>::with_capacity(4); headers.len()];
32 let mut nulls = vec![false; headers.len()];
33
34 for content_line in content_lines {
35 infer_types_from_line(
36 content_line,
37 infer_all_as_str,
38 &mut headers,
39 extend_header_with_unknown_column,
40 parse_options,
41 &mut column_types,
42 &mut nulls,
43 );
44 }
45
46 build_schema(&headers, &column_types, schema_overwrite)
47}
48
49fn infer_headers(mut header_line: &[u8], parse_options: &CsvParseOptions) -> Vec<PlSmallStr> {
50 let len = header_line.len();
51
52 if header_line.last().copied() == Some(b'\r') {
53 header_line = &header_line[..len - 1];
54 }
55
56 let byterecord = SplitFields::new(
57 header_line,
58 parse_options.separator,
59 parse_options.quote_char,
60 parse_options.eol_char,
61 );
62
63 let headers = byterecord
64 .map(|(slice, needs_escaping)| {
65 let slice_escaped = if needs_escaping && (slice.len() >= 2) {
66 &slice[1..(slice.len() - 1)]
67 } else {
68 slice
69 };
70 String::from_utf8_lossy(slice_escaped)
71 })
72 .collect::<Vec<_>>();
73
74 let mut deduplicated_headers = Vec::with_capacity(headers.len());
75 let mut header_names = PlHashMap::with_capacity(headers.len());
76
77 for name in &headers {
78 let count = header_names.entry(name.as_ref()).or_insert(0usize);
79 if *count != 0 {
80 deduplicated_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
81 } else {
82 deduplicated_headers.push(PlSmallStr::from_str(name))
83 }
84 *count += 1;
85 }
86
87 deduplicated_headers
88}
89
90fn infer_types_from_line(
91 mut line: &[u8],
92 infer_all_as_str: bool,
93 headers: &mut Vec<PlSmallStr>,
94 extend_header_with_unknown_column: bool,
95 parse_options: &CsvParseOptions,
96 column_types: &mut Vec<PlHashSet<DataType>>,
97 nulls: &mut Vec<bool>,
98) {
99 let line_len = line.len();
100 if line.last().copied() == Some(b'\r') {
101 line = &line[..line_len - 1];
102 }
103
104 let record = SplitFields::new(
105 line,
106 parse_options.separator,
107 parse_options.quote_char,
108 parse_options.eol_char,
109 );
110
111 for (i, (slice, needs_escaping)) in record.enumerate() {
112 if i >= headers.len() {
113 if extend_header_with_unknown_column {
114 headers.push(column_name(i));
115 column_types.push(Default::default());
116 nulls.push(false);
117 } else {
118 break;
119 }
120 }
121
122 if infer_all_as_str {
123 column_types[i].insert(DataType::String);
124 continue;
125 }
126
127 if slice.is_empty() {
128 nulls[i] = true;
129 } else {
130 let slice_escaped = if needs_escaping && (slice.len() >= 2) {
131 &slice[1..(slice.len() - 1)]
132 } else {
133 slice
134 };
135 let s = String::from_utf8_lossy(slice_escaped);
136 let dtype = match &parse_options.null_values {
137 None => Some(infer_field_schema(
138 &s,
139 parse_options.try_parse_dates,
140 parse_options.decimal_comma,
141 )),
142 Some(NullValues::AllColumns(names)) => {
143 if !names.iter().any(|nv| nv == s.as_ref()) {
144 Some(infer_field_schema(
145 &s,
146 parse_options.try_parse_dates,
147 parse_options.decimal_comma,
148 ))
149 } else {
150 None
151 }
152 },
153 Some(NullValues::AllColumnsSingle(name)) => {
154 if s.as_ref() != name.as_str() {
155 Some(infer_field_schema(
156 &s,
157 parse_options.try_parse_dates,
158 parse_options.decimal_comma,
159 ))
160 } else {
161 None
162 }
163 },
164 Some(NullValues::Named(names)) => {
165 let current_name = &headers[i];
166 let null_name = &names.iter().find(|name| name.0 == current_name);
167
168 if let Some(null_name) = null_name {
169 if null_name.1.as_str() != s.as_ref() {
170 Some(infer_field_schema(
171 &s,
172 parse_options.try_parse_dates,
173 parse_options.decimal_comma,
174 ))
175 } else {
176 None
177 }
178 } else {
179 Some(infer_field_schema(
180 &s,
181 parse_options.try_parse_dates,
182 parse_options.decimal_comma,
183 ))
184 }
185 },
186 };
187 if let Some(dtype) = dtype {
188 column_types[i].insert(dtype);
189 }
190 }
191 }
192}
193
194fn build_schema(
195 headers: &[PlSmallStr],
196 column_types: &[PlHashSet<DataType>],
197 schema_overwrite: Option<&Schema>,
198) -> Schema {
199 assert!(headers.len() == column_types.len());
200
201 let get_schema_overwrite = |field_name| {
202 if let Some(schema_overwrite) = schema_overwrite {
203 if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
206 return Some((name.clone(), dtype.clone()));
207 }
208 }
209
210 None
211 };
212
213 Schema::from_iter(
214 headers
215 .iter()
216 .zip(column_types)
217 .map(|(field_name, type_possibilities)| {
218 let (name, dtype) = get_schema_overwrite(field_name).unwrap_or_else(|| {
219 (
220 field_name.clone(),
221 finish_infer_field_schema(type_possibilities),
222 )
223 });
224
225 Field::new(name, dtype)
226 }),
227 )
228}
229
230pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
231 match possibilities.len() {
234 1 => possibilities.iter().next().unwrap().clone(),
235 2 if possibilities.contains(&DataType::Int64)
236 && possibilities.contains(&DataType::Float64) =>
237 {
238 DataType::Float64
240 },
241 _ => DataType::String,
243 }
244}
245
246pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
248 let bytes = string.as_bytes();
251 if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
252 if try_parse_dates {
253 #[cfg(feature = "polars-time")]
254 {
255 match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
256 Some(pattern_with_offset) => match pattern_with_offset {
257 Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
258 DataType::Datetime(TimeUnit::Microseconds, None)
259 },
260 Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
261 Pattern::DatetimeYMDZ => {
262 DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
263 },
264 Pattern::Time => DataType::Time,
265 },
266 None => DataType::String,
267 }
268 }
269 #[cfg(not(feature = "polars-time"))]
270 {
271 panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
272 }
273 } else {
274 DataType::String
275 }
276 }
277 else if BOOLEAN_RE.is_match(string) {
279 DataType::Boolean
280 } else if !decimal_comma && FLOAT_RE.is_match(string)
281 || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
282 {
283 DataType::Float64
284 } else if INTEGER_RE.is_match(string) {
285 DataType::Int64
286 } else if try_parse_dates {
287 #[cfg(feature = "polars-time")]
288 {
289 match date_infer::infer_pattern_single(string) {
290 Some(pattern_with_offset) => match pattern_with_offset {
291 Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
292 DataType::Datetime(TimeUnit::Microseconds, None)
293 },
294 Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
295 Pattern::DatetimeYMDZ => {
296 DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
297 },
298 Pattern::Time => DataType::Time,
299 },
300 None => DataType::String,
301 }
302 }
303 #[cfg(not(feature = "polars-time"))]
304 {
305 panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
306 }
307 } else {
308 DataType::String
309 }
310}
311
312fn column_name(i: usize) -> PlSmallStr {
313 format_pl_smallstr!("column_{}", i + 1)
314}