Skip to main content

polars_core/utils/
supertype.rs

1use bitflags::bitflags;
2use num_traits::Signed;
3#[cfg(feature = "dtype-decimal")]
4use polars_compute::decimal::{DEC128_MAX_PREC, i128_to_dec128};
5
6use super::*;
7
8/// Given two data types, determine the data type that both types can safely be cast to.
9///
10/// Returns a [`PolarsError::ComputeError`] if no such data type exists.
11pub fn try_get_supertype(l: &DataType, r: &DataType) -> PolarsResult<DataType> {
12    get_supertype(l, r).ok_or_else(
13        || polars_err!(SchemaMismatch: "failed to determine supertype of {} and {}", l, r),
14    )
15}
16
17pub fn try_get_supertype_with_options(
18    l: &DataType,
19    r: &DataType,
20    options: SuperTypeOptions,
21) -> PolarsResult<DataType> {
22    get_supertype_with_options(l, r, options).ok_or_else(
23        || polars_err!(SchemaMismatch: "failed to determine supertype of {} and {}", l, r),
24    )
25}
26
27/// Returns the smallest numeric dtype that `l` and `r` can both be casted to without loss of data.
28/// If no such dtype exists, or `l` and `r` are already matching, returns `None`.
29pub fn get_numeric_upcast_supertype_lossless(l: &DataType, r: &DataType) -> Option<DataType> {
30    use DataType::*;
31
32    if l.is_unknown() || r.is_unknown() || l.is_nested() || r.is_nested() || l == r {
33        None
34    } else if l.is_float() && r.is_float() {
35        match (l, r) {
36            (Float64, _) | (_, Float64) => Some(Float64),
37            (Float32, _) | (_, Float32) => Some(Float32),
38            v => {
39                // Did we add a new float type?
40                if cfg!(debug_assertions) {
41                    panic!("{v:?}")
42                } else {
43                    None
44                }
45            },
46        }
47    } else if l.is_signed_integer() && r.is_signed_integer() {
48        match (l, r) {
49            (Int128, _) | (_, Int128) => Some(Int128),
50            (Int64, _) | (_, Int64) => Some(Int64),
51            (Int32, _) | (_, Int32) => Some(Int32),
52            (Int16, _) | (_, Int16) => Some(Int16),
53            (Int8, _) | (_, Int8) => Some(Int8),
54            v => {
55                if cfg!(debug_assertions) {
56                    panic!("{v:?}")
57                } else {
58                    None
59                }
60            },
61        }
62    } else if l.is_unsigned_integer() && r.is_unsigned_integer() {
63        match (l, r) {
64            (UInt128, _) | (_, UInt128) => Some(UInt128),
65            (UInt64, _) | (_, UInt64) => Some(UInt64),
66            (UInt32, _) | (_, UInt32) => Some(UInt32),
67            (UInt16, _) | (_, UInt16) => Some(UInt16),
68            (UInt8, _) | (_, UInt8) => Some(UInt8),
69            v => {
70                if cfg!(debug_assertions) {
71                    panic!("{v:?}")
72                } else {
73                    None
74                }
75            },
76        }
77    } else if l.is_integer() && r.is_integer() {
78        // One side is signed, the other is unsigned. We just need to upcast the
79        // unsigned side to a signed integer with the next-largest bit width.
80        match (l, r) {
81            (UInt128, _) | (_, UInt128) => None, // No lossless cast possible
82            (UInt64, _) | (_, UInt64) | (Int128, _) | (_, Int128) => Some(Int128),
83            (UInt32, _) | (_, UInt32) | (Int64, _) | (_, Int64) => Some(Int64),
84            (UInt16, _) | (_, UInt16) | (Int32, _) | (_, Int32) => Some(Int32),
85            (UInt8, _) | (_, UInt8) | (Int16, _) | (_, Int16) => Some(Int16),
86            v => {
87                // One side was UInt and we should have already matched against
88                // all the UInt types
89                if cfg!(debug_assertions) {
90                    panic!("{v:?}")
91                } else {
92                    None
93                }
94            },
95        }
96    } else {
97        None
98    }
99}
100
101bitflags! {
102    #[repr(transparent)]
103    #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
104    pub struct SuperTypeFlags: u8 {
105        /// Implode lists to match nesting types.
106        const ALLOW_IMPLODE_LIST = 1 << 0;
107        /// Allow casting of primitive types (numeric, bools) to strings
108        const ALLOW_PRIMITIVE_TO_STRING = 1 << 1;
109    }
110}
111
112impl Default for SuperTypeFlags {
113    fn default() -> Self {
114        SuperTypeFlags::from_bits_truncate(0) | SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING
115    }
116}
117
118#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Default)]
119pub struct SuperTypeOptions {
120    pub flags: SuperTypeFlags,
121}
122
123impl From<SuperTypeFlags> for SuperTypeOptions {
124    fn from(flags: SuperTypeFlags) -> Self {
125        SuperTypeOptions { flags }
126    }
127}
128
129impl SuperTypeOptions {
130    pub fn allow_implode_list(&self) -> bool {
131        self.flags.contains(SuperTypeFlags::ALLOW_IMPLODE_LIST)
132    }
133
134    pub fn allow_primitive_to_string(&self) -> bool {
135        self.flags
136            .contains(SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING)
137    }
138}
139
140pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
141    get_supertype_with_options(l, r, SuperTypeOptions::default())
142}
143
/// Given two data types, determine the data type that both types can reasonably safely be cast to.
///
/// The lookup table (`inner`) is written for one ordering of each pair only: it is consulted
/// with `(l, r)` first and, only if that yields [`None`], again with the operands swapped.
///
/// `options` is consulted directly by the string-cast arms and the list-implode arm. The
/// recursive lookups for nested dtypes (list/array inner dtypes, struct fields) go through
/// [`get_supertype`], i.e. they run with the default options rather than with `options`.
///
/// Returns [`None`] if no such data type exists.
pub fn get_supertype_with_options(
    l: &DataType,
    r: &DataType,
    options: SuperTypeOptions,
) -> Option<DataType> {
    // One-directional lookup table. The order of the arms matters: the first matching arm wins,
    // and several later arms rely on earlier, more specific arms having been tried first.
    fn inner(l: &DataType, r: &DataType, options: SuperTypeOptions) -> Option<DataType> {
        use DataType::*;
        // Identical dtypes are their own supertype.
        if l == r {
            return Some(l.clone());
        }
        match (l, r) {
            // Signed integers on the left. Arms are feature-gated on the dtypes they mention.
            #[cfg(feature = "dtype-i8")]
            (Int8, Boolean) => Some(Int8),
            //(Int8, Int8) => Some(Int8),
            #[cfg(all(feature = "dtype-i8", feature = "dtype-i16"))]
            (Int8, Int16) => Some(Int16),
            #[cfg(feature = "dtype-i8")]
            (Int8, Int32) => Some(Int32),
            #[cfg(feature = "dtype-i8")]
            (Int8, Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-i8", feature = "dtype-i16"))]
            (Int8, UInt8) => Some(Int16),
            #[cfg(all(feature = "dtype-i8", feature = "dtype-u16"))]
            (Int8, UInt16) => Some(Int32),
            #[cfg(feature = "dtype-i8")]
            (Int8, UInt32) => Some(Int64),
            #[cfg(feature = "dtype-i8")]
            (Int8, UInt64) => Some(Float64), // Follow numpy
            #[cfg(all(feature = "dtype-i8", feature = "dtype-f16"))]
            (Int8, Float16) => Some(Float16),
            #[cfg(feature = "dtype-i8")]
            (Int8, Float32) => Some(Float32),
            #[cfg(feature = "dtype-i8")]
            (Int8, Float64) => Some(Float64),

            #[cfg(feature = "dtype-i16")]
            (Int16, Boolean) => Some(Int16),
            #[cfg(all(feature = "dtype-i16", feature = "dtype-i8"))]
            (Int16, Int8) => Some(Int16),
            //(Int16, Int16) => Some(Int16),
            #[cfg(feature = "dtype-i16")]
            (Int16, Int32) => Some(Int32),
            #[cfg(feature = "dtype-i16")]
            (Int16, Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-i16", feature = "dtype-u8"))]
            (Int16, UInt8) => Some(Int16),
            #[cfg(all(feature = "dtype-i16", feature = "dtype-u16"))]
            (Int16, UInt16) => Some(Int32),
            #[cfg(feature = "dtype-i16")]
            (Int16, UInt32) => Some(Int64),
            #[cfg(feature = "dtype-i16")]
            (Int16, UInt64) => Some(Float64), // Follow numpy
            #[cfg(all(feature = "dtype-i16", feature = "dtype-f16"))]
            (Int16, Float16) => Some(Float32),
            #[cfg(feature = "dtype-i16")]
            (Int16, Float32) => Some(Float32),
            #[cfg(feature = "dtype-i16")]
            (Int16, Float64) => Some(Float64),

            (Int32, Boolean) => Some(Int32),
            #[cfg(feature = "dtype-i8")]
            (Int32, Int8) => Some(Int32),
            #[cfg(feature = "dtype-i16")]
            (Int32, Int16) => Some(Int32),
            //(Int32, Int32) => Some(Int32),
            (Int32, Int64) => Some(Int64),
            #[cfg(feature = "dtype-u8")]
            (Int32, UInt8) => Some(Int32),
            #[cfg(feature = "dtype-u16")]
            (Int32, UInt16) => Some(Int32),
            (Int32, UInt32) => Some(Int64),
            #[cfg(not(feature = "bigidx"))]
            (Int32, UInt64) => Some(Float64), // Follow numpy
            #[cfg(feature = "bigidx")]
            (Int32, UInt64) => Some(Int64), // Needed for bigidx
            #[cfg(feature = "dtype-f16")]
            (Int32, Float16) => Some(Float64),
            (Int32, Float32) => Some(Float64), // Follow numpy
            (Int32, Float64) => Some(Float64),

            (Int64, Boolean) => Some(Int64),
            #[cfg(feature = "dtype-i8")]
            (Int64, Int8) => Some(Int64),
            #[cfg(feature = "dtype-i16")]
            (Int64, Int16) => Some(Int64),
            (Int64, Int32) => Some(Int64),
            //(Int64, Int64) => Some(Int64),
            #[cfg(feature = "dtype-u8")]
            (Int64, UInt8) => Some(Int64),
            #[cfg(feature = "dtype-u16")]
            (Int64, UInt16) => Some(Int64),
            (Int64, UInt32) => Some(Int64),
            #[cfg(not(feature = "bigidx"))]
            (Int64, UInt64) => Some(Float64), // Follow numpy
            #[cfg(feature = "bigidx")]
            (Int64, UInt64) => Some(Int64), // Needed for bigidx
            #[cfg(feature = "dtype-f16")]
            (Int64, Float16) => Some(Float64), // Follow (Int64, Float32) case
            (Int64, Float32) => Some(Float64), // Follow numpy
            (Int64, Float64) => Some(Float64),

            // Int128 absorbs every integer and boolean; against any float the result is Float64.
            #[cfg(feature = "dtype-i128")]
            (Int128, a) if a.is_integer() | a.is_bool() => Some(Int128),
            #[cfg(feature = "dtype-i128")]
            (Int128, a) if a.is_float() => Some(Float64),

            // Unsigned against wider unsigned. The reverse orderings are found by the swapped
            // second pass of the outer function.
            #[cfg(feature = "dtype-u8")]
            (UInt8, UInt32) => Some(UInt32),
            #[cfg(feature = "dtype-u8")]
            (UInt8, UInt64) => Some(UInt64),

            #[cfg(all(feature = "dtype-u16", feature = "dtype-u8"))]
            (UInt16, UInt8) => Some(UInt16),
            #[cfg(feature = "dtype-u16")]
            (UInt16, UInt32) => Some(UInt32),
            #[cfg(feature = "dtype-u16")]
            (UInt16, UInt64) => Some(UInt64),

            (UInt32, UInt64) => Some(UInt64),

            // UInt128 absorbs unsigned integers and booleans, becomes Int128 against any signed
            // integer, and Float64 against any float.
            #[cfg(feature = "dtype-u128")]
            (UInt128, a) if a.is_unsigned_integer() || a.is_bool() => Some(UInt128),
            #[cfg(feature = "dtype-u128")]
            (UInt128, a) if a.is_signed_integer() => Some(Int128),
            #[cfg(feature = "dtype-u128")]
            (UInt128, a) if a.is_float() => Some(Float64),

            #[cfg(feature = "dtype-u8")]
            (Boolean, UInt8) => Some(UInt8),
            #[cfg(feature = "dtype-u16")]
            (Boolean, UInt16) => Some(UInt16),
            (Boolean, UInt32) => Some(UInt32),
            (Boolean, UInt64) => Some(UInt64),

            // Floats against unsigned integers.
            #[cfg(all(feature = "dtype-f16", feature = "dtype-u8"))]
            (Float16, UInt8) => Some(Float16),
            #[cfg(all(feature = "dtype-f16", feature = "dtype-u16"))]
            (Float16, UInt16) => Some(Float32),
            #[cfg(feature = "dtype-f16")]
            (Float16, UInt32) => Some(Float64),
            #[cfg(feature = "dtype-f16")]
            (Float16, UInt64) => Some(Float64),

            #[cfg(feature = "dtype-u8")]
            (Float32, UInt8) => Some(Float32),
            #[cfg(feature = "dtype-u16")]
            (Float32, UInt16) => Some(Float32),
            (Float32, UInt32) => Some(Float64),
            (Float32, UInt64) => Some(Float64),

            #[cfg(feature = "dtype-u8")]
            (Float64, UInt8) => Some(Float64),
            #[cfg(feature = "dtype-u16")]
            (Float64, UInt16) => Some(Float64),
            (Float64, UInt32) => Some(Float64),
            (Float64, UInt64) => Some(Float64),

            // Float against float: the wider one wins.
            #[cfg(feature = "dtype-f16")]
            (Float16, Float32) => Some(Float32),
            #[cfg(feature = "dtype-f16")]
            (Float16, Float64) => Some(Float64),
            (Float32, Float64) => Some(Float64),
            #[cfg(feature = "dtype-f16")]
            (Float32, Float16) => Some(Float32),
            #[cfg(feature = "dtype-f16")]
            (Float64, Float16) => Some(Float64),
            (Float64, Float32) => Some(Float64),

            // Time related dtypes
            #[cfg(feature = "dtype-date")]
            (Date, UInt32) => Some(Int64),
            #[cfg(feature = "dtype-date")]
            (Date, UInt64) => Some(Int64),
            #[cfg(feature = "dtype-date")]
            (Date, Int32) => Some(Int32),
            #[cfg(feature = "dtype-date")]
            (Date, Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-date", feature = "dtype-f16"))]
            (Date, Float16) => Some(Float32),
            #[cfg(feature = "dtype-date")]
            (Date, Float32) => Some(Float32),
            #[cfg(feature = "dtype-date")]
            (Date, Float64) => Some(Float64),
            // Date against Datetime keeps the Datetime's time unit and time zone.
            #[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))]
            (Date, Datetime(tu, tz)) => Some(Datetime(*tu, tz.clone())),

            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), UInt32) => Some(Int64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), UInt64) => Some(Int64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Int32) => Some(Int64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-datetime", feature = "dtype-f16"))]
            (Datetime(_, _), Float16) => Some(Float64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Float32) => Some(Float64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Float64) => Some(Float64),
            #[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))]
            (Datetime(tu, tz), Date) => Some(Datetime(*tu, tz.clone())),

            #[cfg(feature = "dtype-f16")]
            (Boolean, Float16) => Some(Float16),
            (Boolean, Float32) => Some(Float32),
            (Boolean, Float64) => Some(Float64),

            #[cfg(feature = "dtype-duration")]
            (Duration(_), UInt32) => Some(Int64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), UInt64) => Some(Int64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Int32) => Some(Int64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-duration", feature = "dtype-f16"))]
            (Duration(_), Float16) => Some(Float64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Float32) => Some(Float64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Float64) => Some(Float64),

            #[cfg(feature = "dtype-time")]
            (Time, Int32) => Some(Int64),
            #[cfg(feature = "dtype-time")]
            (Time, Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-time", feature = "dtype-f16"))]
            (Time, Float16) => Some(Float64),
            #[cfg(feature = "dtype-time")]
            (Time, Float32) => Some(Float64),
            #[cfg(feature = "dtype-time")]
            (Time, Float64) => Some(Float64),

            // Every known type can be cast to a string except binary
            // NOTE(review): `&&` binds tighter than `||`, so this guard parses as
            // `(not Unknown(Any) && not Binary && allow_primitive_to_string) || !physical.is_primitive()`.
            // The last disjunct therefore applies regardless of the option flag and of the
            // Unknown/Binary checks. Confirm that this grouping is intended.
            (dt, String) if !matches!(dt, Unknown(UnknownKind::Any)) && dt != &Binary && options.allow_primitive_to_string() || !dt.to_physical().is_primitive() => Some(String),
            (String, Binary) => Some(Binary),
            // Null adopts whatever the other side is.
            (dt, Null) => Some(dt.clone()),

            // Duration combined with a time-zone-carrying Datetime yields a Datetime; an empty
            // time zone string is normalised to `None`.
            #[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))]
            (Duration(lu), Datetime(ru, Some(tz))) | (Datetime(lu, Some(tz)), Duration(ru)) => {
                if tz.is_empty() {
                    Some(Datetime(get_time_units(lu, ru), None))
                } else {
                    Some(Datetime(get_time_units(lu, ru), Some(tz.clone())))
                }
            }
            #[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))]
            (Duration(lu), Datetime(ru, None)) | (Datetime(lu, None), Duration(ru)) => {
                Some(Datetime(get_time_units(lu, ru), None))
            }
            #[cfg(all(feature = "dtype-duration", feature = "dtype-date"))]
            (Duration(_), Date) | (Date, Duration(_)) => Some(Date),
            #[cfg(feature = "dtype-duration")]
            (Duration(lu), Duration(ru)) => Some(Duration(get_time_units(lu, ru))),

            // both None or both Some("<tz>") timezones
            // we cast from more precision to higher precision as that always fits with occasional loss of precision
            #[cfg(feature = "dtype-datetime")]
            (Datetime(tu_l, tz_l), Datetime(tu_r, tz_r)) if
                // both are none
                (tz_l.is_none() && tz_r.is_none())
                // both have the same time zone
                || (tz_l.is_some() && (tz_l == tz_r)) => {
                let tu = get_time_units(tu_l, tu_r);
                Some(Datetime(tu, tz_r.clone()))
            }
            // Nested dtypes: the inner supertype is resolved via `get_supertype`, which uses the
            // default options (not `options`).
            (List(inner_left), List(inner_right)) => {
                let st = get_supertype(inner_left, inner_right)?;
                Some(List(Box::new(st)))
            }
            // A list paired with an array always yields a list.
            #[cfg(feature = "dtype-array")]
            (List(inner_left), Array(inner_right, _)) | (Array(inner_left, _), List(inner_right)) => {
                let st = get_supertype(inner_left, inner_right)?;
                Some(List(Box::new(st)))
            }
            // Arrays of equal width stay arrays of that width.
            #[cfg(feature = "dtype-array")]
            (Array(inner_left, width_left), Array(inner_right, width_right)) if *width_left == *width_right => {
                let st = get_supertype(inner_left, inner_right)?;
                Some(Array(Box::new(st), *width_left))
            }
            // Opt-in: a list paired with a non-list becomes a list of the supertype of the
            // list's inner dtype and the other dtype.
            (List(inner), other) | (other, List(inner)) if options.allow_implode_list() => {
                let st = get_supertype(inner, other)?;
                Some(List(Box::new(st)))
            }
            // Only reached when the widths differ (the equal-width arm above comes first):
            // arrays of different width fall back to a list.
            #[cfg(feature = "dtype-array")]
            (Array(inner_left, _), Array(inner_right, _)) => {
                let st = get_supertype(inner_left, inner_right)?;
                Some(List(Box::new(st)))
            }
            // Struct against a dynamic float/int: only the struct's FIRST field is consulted,
            // and the result is that field's supertype with the dynamic dtype (not a struct).
            #[cfg(feature = "dtype-struct")]
            (Struct(inner), right @ Unknown(UnknownKind::Float | UnknownKind::Int(_))) => {
                match inner.first() {
                    Some(inner) => get_supertype(&inner.dtype, right),
                    None => None
                }
            },
            // Any dtype against a dynamic (`Unknown`) dtype. The guards below are tried in
            // order; anything not covered ends up as `Unknown(Any)`.
            (dt, Unknown(kind)) => {
                match kind {
                    // Dynamic numbers against a string keep the string dtype, but only when
                    // primitive-to-string casting is allowed.
                    UnknownKind::Float | UnknownKind::Int(_) if  dt.is_string() => {
                        if options.allow_primitive_to_string() {
                            Some(dt.clone())
                        } else {
                            None
                        }
                    },
                    // Materialize float to float
                    UnknownKind::Float | UnknownKind::Int(_) if dt.is_float() => Some(dt.clone()),
                    UnknownKind::Float if dt.is_integer() | dt.is_decimal() => Some(Unknown(UnknownKind::Float)),
                    // Materialize str
                    UnknownKind::Str if dt.is_string() | dt.is_enum() => Some(dt.clone()),
                    // Materialize str
                    #[cfg(feature = "dtype-categorical")]
                    UnknownKind::Str if dt.is_categorical() => Some(dt.clone()),
                    // Keep unknown
                    dynam if dt.is_null() => Some(Unknown(*dynam)),
                    // Find integers sizes
                    UnknownKind::Int(v) if dt.is_primitive_numeric() => {
                        // Both dyn int
                        if let Unknown(UnknownKind::Int(v_other)) = dt {
                            // Take the maximum value to ensure we bubble up the required minimal size.
                            Some(Unknown(UnknownKind::Int(std::cmp::max(*v, *v_other))))
                        }
                        // dyn int vs number
                        else {
                            // For an unsigned `dt` and a non-negative literal, look for the
                            // smallest unsigned dtype; otherwise for the smallest signed one.
                            let smallest_fitting_dtype = if dt.is_unsigned_integer() && !v.is_negative() {
                                materialize_dyn_int_pos(*v).dtype()
                            } else {
                                materialize_smallest_dyn_int(*v).dtype()
                            };
                            match dt {
                                UInt64 if smallest_fitting_dtype.is_signed_integer() => {
                                    // Ensure we don't cast to float when dealing with dynamic literals
                                    Some(Int64)
                                },
                                _ => {
                                    get_supertype(dt, &smallest_fitting_dtype)
                                }
                            }
                        }
                    }
                    UnknownKind::Int(v) if dt.is_bool() => {
                        let int_dtype = materialize_dyn_int(*v).dtype();
                        get_supertype(dt, &int_dtype)
                    },
                    // Decimal against a dynamic int: keep the scale, widen to max precision.
                    #[cfg(feature = "dtype-decimal")]
                    UnknownKind::Int(_) if dt.is_decimal() => {
                        let DataType::Decimal(_prec, scale) = dt else { unreachable!() };
                        Some(DataType::Decimal(DEC128_MAX_PREC, *scale))
                    }
                    _ => Some(Unknown(UnknownKind::Any))
                }
            },
            #[cfg(feature = "dtype-struct")]
            (Struct(fields_a), Struct(fields_b)) => {
                super_type_structs(fields_a, fields_b)
            }
            // Struct against a primitive numeric: every field is upcast against `rhs`; a single
            // field without a supertype makes the whole lookup fail.
            #[cfg(feature = "dtype-struct")]
            (Struct(fields_a), rhs) if rhs.is_primitive_numeric() => {
                let mut new_fields = Vec::with_capacity(fields_a.len());
                for a in fields_a {
                    let st = get_supertype(&a.dtype, rhs)?;
                    new_fields.push(Field::new(a.name.clone(), st))
                }
                Some(Struct(new_fields))
            }
            // Decimal against decimal: precision and scale are each the maximum of both sides.
            #[cfg(feature = "dtype-decimal")]
            (Decimal(p1, s1), Decimal(p2, s2)) => {
                Some(Decimal((*p1).max(*p2), (*s1).max(*s2)))
            },
            #[cfg(all(feature = "dtype-decimal", feature = "dtype-f16"))]
            (Decimal(_, _), Float16) => Some(Float64),
            #[cfg(feature = "dtype-decimal")]
            (Decimal(_, _), Float32 | Float64) => Some(Float64),
            // Decimal against integer: probe with the integer type's MAX value. If
            // `i128_to_dec128` accepts it under the current precision/scale the decimal dtype
            // is kept as is; otherwise the precision is widened to `DEC128_MAX_PREC`. 128-bit
            // integers are never treated as fitting.
            #[cfg(feature = "dtype-decimal")]
            (Decimal(prec, scale), dt) if dt.is_signed_integer() || dt.is_unsigned_integer() => {
                let fits = |v| { i128_to_dec128(v, *prec, *scale).is_some() };
                let fits_orig_prec_scale = match dt {
                    UInt8 => fits(u8::MAX as i128),
                    UInt16 => fits(u16::MAX as i128),
                    UInt32 => fits(u32::MAX as i128),
                    UInt64 => fits(u64::MAX as i128),
                    UInt128 => false,
                    Int8 => fits(i8::MAX as i128),
                    Int16 => fits(i16::MAX as i128),
                    Int32 => fits(i32::MAX as i128),
                    Int64 => fits(i64::MAX as i128),
                    Int128 => false,
                    _ => unreachable!(),
                };
                if fits_orig_prec_scale {
                    Some(Decimal(*prec, *scale))
                } else {
                    Some(Decimal(DEC128_MAX_PREC, *scale))
                }
            }
            _ => None,
        }
    }

    // Try the table as given, then once more with the operands swapped.
    inner(l, r, options).or_else(|| inner(r, l, options))
}
550
551/// Given multiple data types, determine the data type that all types can safely be cast to.
552///
553/// Returns [`DataType::Null`] if no data types were passed.
554pub fn dtypes_to_supertype<'a, I>(dtypes: I) -> PolarsResult<DataType>
555where
556    I: IntoIterator<Item = &'a DataType>,
557{
558    dtypes
559        .into_iter()
560        .try_fold(DataType::Null, |supertype, dtype| {
561            try_get_supertype(&supertype, dtype)
562        })
563}
564
/// Builds a struct dtype whose fields are the union (by name) of `fields_a` and `fields_b`.
///
/// The longer field list seeds the result and thereby fixes the field order; on a tie the
/// left-hand side is used. Fields that exist only in the shorter list are appended, and fields
/// present in both with differing dtypes are replaced by their supertype.
///
/// Returns `None` as soon as two same-named fields have no supertype.
#[cfg(feature = "dtype-struct")]
fn union_struct_fields(fields_a: &[Field], fields_b: &[Field]) -> Option<DataType> {
    let (base, extra) = if fields_b.len() > fields_a.len() {
        (fields_b, fields_a)
    } else {
        // Equal lengths land here too, so that the lhs determines the order of the fields.
        (fields_a, fields_b)
    };

    let mut merged: PlIndexMap<_, _> = base
        .iter()
        .map(|fld| (&fld.name, fld.dtype.clone()))
        .collect();

    for field in extra {
        let slot = merged
            .entry(&field.name)
            .or_insert_with(|| field.dtype.clone());
        if *slot != field.dtype {
            *slot = get_supertype(&field.dtype, slot)?;
        }
    }

    let fields: Vec<Field> = merged
        .into_iter()
        .map(|(name, dtype)| Field::new(name.clone(), dtype))
        .collect();
    Some(DataType::Struct(fields))
}
593
/// Computes the supertype of two struct dtypes given their field lists.
///
/// When both sides have the same number of fields with the same names in the same positions,
/// the fields are combined pairwise. On a length mismatch, or at the first position where the
/// names differ, the work is handed over to `union_struct_fields`.
///
/// Returns `None` if a pair of fields has no supertype.
#[cfg(feature = "dtype-struct")]
fn super_type_structs(fields_a: &[Field], fields_b: &[Field]) -> Option<DataType> {
    if fields_a.len() != fields_b.len() {
        return union_struct_fields(fields_a, fields_b);
    }

    let mut merged = Vec::with_capacity(fields_a.len());
    for (lhs, rhs) in fields_a.iter().zip(fields_b) {
        if lhs.name != rhs.name {
            return union_struct_fields(fields_a, fields_b);
        }
        let dtype = get_supertype(&lhs.dtype, &rhs.dtype)?;
        merged.push(Field::new(lhs.name.clone(), dtype));
    }
    Some(DataType::Struct(merged))
}
610
611pub fn materialize_dyn_int(v: i128) -> AnyValue<'static> {
612    // Try to get the "smallest" fitting value.
613    // TODO! next breaking go to true smallest.
614    if let Ok(v) = i32::try_from(v) {
615        return AnyValue::Int32(v);
616    }
617    if let Ok(v) = i64::try_from(v) {
618        return AnyValue::Int64(v);
619    }
620    if let Ok(v) = u64::try_from(v) {
621        return AnyValue::UInt64(v);
622    }
623    #[cfg(feature = "dtype-i128")]
624    {
625        AnyValue::Int128(v)
626    }
627
628    #[cfg(not(feature = "dtype-i128"))]
629    AnyValue::Null
630}
631
632fn materialize_dyn_int_pos(v: i128) -> AnyValue<'static> {
633    // Try to get the "smallest" fitting value.
634    // TODO! next breaking go to true smallest.
635    #[cfg(feature = "dtype-u8")]
636    if let Ok(v) = u8::try_from(v) {
637        return AnyValue::UInt8(v);
638    }
639    #[cfg(feature = "dtype-u16")]
640    if let Ok(v) = u16::try_from(v) {
641        return AnyValue::UInt16(v);
642    }
643    match u32::try_from(v).ok() {
644        Some(v) => AnyValue::UInt32(v),
645        None => match u64::try_from(v).ok() {
646            Some(v) => AnyValue::UInt64(v),
647            None => AnyValue::Null,
648        },
649    }
650}
651
652fn materialize_smallest_dyn_int(v: i128) -> AnyValue<'static> {
653    #[cfg(feature = "dtype-i8")]
654    if let Ok(v) = i8::try_from(v) {
655        return AnyValue::Int8(v);
656    }
657    #[cfg(feature = "dtype-i16")]
658    if let Ok(v) = i16::try_from(v) {
659        return AnyValue::Int16(v);
660    }
661    match i32::try_from(v).ok() {
662        Some(v) => AnyValue::Int32(v),
663        None => match i64::try_from(v).ok() {
664            Some(v) => AnyValue::Int64(v),
665            None => match u64::try_from(v).ok() {
666                Some(v) => AnyValue::UInt64(v),
667                None => AnyValue::Null,
668            },
669        },
670    }
671}
672
673pub fn merge_dtypes_many<I: IntoIterator<Item = D> + Clone, D: AsRef<DataType>>(
674    into_iter: I,
675) -> PolarsResult<DataType> {
676    let mut iter = into_iter.into_iter();
677
678    let mut st = iter
679        .next()
680        .ok_or_else(|| polars_err!(ComputeError: "expect at least 1 dtype"))
681        .map(|d| d.as_ref().clone())?;
682
683    for d in iter {
684        st = try_get_supertype(d.as_ref(), &st)?;
685    }
686
687    Ok(st)
688}