polars_core/utils/
supertype.rs

1use bitflags::bitflags;
2use num_traits::Signed;
3#[cfg(feature = "dtype-decimal")]
4use polars_compute::decimal::{DEC128_MAX_PREC, i128_to_dec128};
5
6use super::*;
7
8/// Given two data types, determine the data type that both types can safely be cast to.
9///
10/// Returns a [`PolarsError::ComputeError`] if no such data type exists.
11pub fn try_get_supertype(l: &DataType, r: &DataType) -> PolarsResult<DataType> {
12    get_supertype(l, r).ok_or_else(
13        || polars_err!(SchemaMismatch: "failed to determine supertype of {} and {}", l, r),
14    )
15}
16
17pub fn try_get_supertype_with_options(
18    l: &DataType,
19    r: &DataType,
20    options: SuperTypeOptions,
21) -> PolarsResult<DataType> {
22    get_supertype_with_options(l, r, options).ok_or_else(
23        || polars_err!(SchemaMismatch: "failed to determine supertype of {} and {}", l, r),
24    )
25}
26
27/// Returns a numeric supertype that `l` and `r` can be safely upcasted to if it exists.
28pub fn get_numeric_upcast_supertype_lossless(l: &DataType, r: &DataType) -> Option<DataType> {
29    use DataType::*;
30
31    if l == r || matches!(l, Unknown(_)) || matches!(r, Unknown(_)) {
32        None
33    } else if l.is_float() && r.is_float() {
34        match (l, r) {
35            (Float64, _) | (_, Float64) => Some(Float64),
36            v => {
37                // Did we add a new float type?
38                if cfg!(debug_assertions) {
39                    panic!("{v:?}")
40                } else {
41                    None
42                }
43            },
44        }
45    } else if l.is_signed_integer() && r.is_signed_integer() {
46        match (l, r) {
47            (Int128, _) | (_, Int128) => Some(Int128),
48            (Int64, _) | (_, Int64) => Some(Int64),
49            (Int32, _) | (_, Int32) => Some(Int32),
50            (Int16, _) | (_, Int16) => Some(Int16),
51            (Int8, _) | (_, Int8) => Some(Int8),
52            v => {
53                if cfg!(debug_assertions) {
54                    panic!("{v:?}")
55                } else {
56                    None
57                }
58            },
59        }
60    } else if l.is_unsigned_integer() && r.is_unsigned_integer() {
61        match (l, r) {
62            (UInt128, _) | (_, UInt128) => Some(UInt128),
63            (UInt64, _) | (_, UInt64) => Some(UInt64),
64            (UInt32, _) | (_, UInt32) => Some(UInt32),
65            (UInt16, _) | (_, UInt16) => Some(UInt16),
66            (UInt8, _) | (_, UInt8) => Some(UInt8),
67            v => {
68                if cfg!(debug_assertions) {
69                    panic!("{v:?}")
70                } else {
71                    None
72                }
73            },
74        }
75    } else if l.is_integer() && r.is_integer() {
76        // One side is signed, the other is unsigned. We just need to upcast the
77        // unsigned side to a signed integer with the next-largest bit width.
78        match (l, r) {
79            (UInt128, _) | (_, UInt128) => None, // No lossless cast possible
80            (UInt64, _) | (_, UInt64) | (Int128, _) | (_, Int128) => Some(Int128),
81            (UInt32, _) | (_, UInt32) | (Int64, _) | (_, Int64) => Some(Int64),
82            (UInt16, _) | (_, UInt16) | (Int32, _) | (_, Int32) => Some(Int32),
83            (UInt8, _) | (_, UInt8) | (Int16, _) | (_, Int16) => Some(Int16),
84            v => {
85                // One side was UInt and we should have already matched against
86                // all the UInt types
87                if cfg!(debug_assertions) {
88                    panic!("{v:?}")
89                } else {
90                    None
91                }
92            },
93        }
94    } else {
95        None
96    }
97}
98
99bitflags! {
100    #[repr(transparent)]
101    #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
102    pub struct SuperTypeFlags: u8 {
103        /// Implode lists to match nesting types.
104        const ALLOW_IMPLODE_LIST = 1 << 0;
105        /// Allow casting of primitive types (numeric, bools) to strings
106        const ALLOW_PRIMITIVE_TO_STRING = 1 << 1;
107    }
108}
109
110impl Default for SuperTypeFlags {
111    fn default() -> Self {
112        SuperTypeFlags::from_bits_truncate(0) | SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING
113    }
114}
115
116#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Default)]
117pub struct SuperTypeOptions {
118    pub flags: SuperTypeFlags,
119}
120
121impl From<SuperTypeFlags> for SuperTypeOptions {
122    fn from(flags: SuperTypeFlags) -> Self {
123        SuperTypeOptions { flags }
124    }
125}
126
127impl SuperTypeOptions {
128    pub fn allow_implode_list(&self) -> bool {
129        self.flags.contains(SuperTypeFlags::ALLOW_IMPLODE_LIST)
130    }
131
132    pub fn allow_primitive_to_string(&self) -> bool {
133        self.flags
134            .contains(SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING)
135    }
136}
137
138pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
139    get_supertype_with_options(l, r, SuperTypeOptions::default())
140}
141
/// Given two data types, determine the data type that both types can reasonably safely be cast to.
///
/// The lookup is symmetric: the rule table in the nested `inner` function is consulted for
/// `(l, r)` first and, only if that yields nothing, again for `(r, l)`. Most rules are
/// therefore written in one direction only.
///
/// `options` gates two coercions: primitive-to-string and wrapping a non-list type in a
/// list ("implode").
///
/// Returns [`None`] if no such data type exists.
pub fn get_supertype_with_options(
    l: &DataType,
    r: &DataType,
    options: SuperTypeOptions,
) -> Option<DataType> {
    // One-directional rule table. Arm order matters: the first matching arm wins.
    //
    // NOTE(review): nested element/field types below are resolved through `get_supertype`,
    // i.e. with the default options rather than the `options` passed in here — confirm
    // this is intended.
    fn inner(l: &DataType, r: &DataType, options: SuperTypeOptions) -> Option<DataType> {
        use DataType::*;
        // Identical types (including every `(X, X)` pair omitted below) are their own supertype.
        if l == r {
            return Some(l.clone());
        }
        match (l, r) {
            #[cfg(feature = "dtype-i8")]
            (Int8, Boolean) => Some(Int8),
            // (Int8, Int8) is covered by the `l == r` early return.
            #[cfg(all(feature = "dtype-i8", feature = "dtype-i16"))]
            (Int8, Int16) => Some(Int16),
            #[cfg(feature = "dtype-i8")]
            (Int8, Int32) => Some(Int32),
            #[cfg(feature = "dtype-i8")]
            (Int8, Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-i8", feature = "dtype-i16"))]
            (Int8, UInt8) => Some(Int16),
            #[cfg(all(feature = "dtype-i8", feature = "dtype-u16"))]
            (Int8, UInt16) => Some(Int32),
            #[cfg(feature = "dtype-i8")]
            (Int8, UInt32) => Some(Int64),
            #[cfg(feature = "dtype-i8")]
            (Int8, UInt64) => Some(Float64), // Follow numpy
            #[cfg(all(feature = "dtype-i8", feature = "dtype-f16"))]
            (Int8, Float16) => Some(Float16),
            #[cfg(feature = "dtype-i8")]
            (Int8, Float32) => Some(Float32),
            #[cfg(feature = "dtype-i8")]
            (Int8, Float64) => Some(Float64),

            #[cfg(feature = "dtype-i16")]
            (Int16, Boolean) => Some(Int16),
            #[cfg(all(feature = "dtype-i16", feature = "dtype-i8"))]
            (Int16, Int8) => Some(Int16),
            // (Int16, Int16) is covered by the `l == r` early return.
            #[cfg(feature = "dtype-i16")]
            (Int16, Int32) => Some(Int32),
            #[cfg(feature = "dtype-i16")]
            (Int16, Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-i16", feature = "dtype-u8"))]
            (Int16, UInt8) => Some(Int16),
            #[cfg(all(feature = "dtype-i16", feature = "dtype-u16"))]
            (Int16, UInt16) => Some(Int32),
            #[cfg(feature = "dtype-i16")]
            (Int16, UInt32) => Some(Int64),
            #[cfg(feature = "dtype-i16")]
            (Int16, UInt64) => Some(Float64), // Follow numpy
            #[cfg(all(feature = "dtype-i16", feature = "dtype-f16"))]
            (Int16, Float16) => Some(Float32),
            #[cfg(feature = "dtype-i16")]
            (Int16, Float32) => Some(Float32),
            #[cfg(feature = "dtype-i16")]
            (Int16, Float64) => Some(Float64),

            (Int32, Boolean) => Some(Int32),
            #[cfg(feature = "dtype-i8")]
            (Int32, Int8) => Some(Int32),
            #[cfg(feature = "dtype-i16")]
            (Int32, Int16) => Some(Int32),
            // (Int32, Int32) is covered by the `l == r` early return.
            (Int32, Int64) => Some(Int64),
            #[cfg(feature = "dtype-u8")]
            (Int32, UInt8) => Some(Int32),
            #[cfg(feature = "dtype-u16")]
            (Int32, UInt16) => Some(Int32),
            (Int32, UInt32) => Some(Int64),
            #[cfg(not(feature = "bigidx"))]
            (Int32, UInt64) => Some(Float64), // Follow numpy
            #[cfg(feature = "bigidx")]
            (Int32, UInt64) => Some(Int64), // Needed for bigidx
            #[cfg(feature = "dtype-f16")]
            (Int32, Float16) => Some(Float64),
            (Int32, Float32) => Some(Float64), // Follow numpy
            (Int32, Float64) => Some(Float64),

            (Int64, Boolean) => Some(Int64),
            #[cfg(feature = "dtype-i8")]
            (Int64, Int8) => Some(Int64),
            #[cfg(feature = "dtype-i16")]
            (Int64, Int16) => Some(Int64),
            (Int64, Int32) => Some(Int64),
            // (Int64, Int64) is covered by the `l == r` early return.
            #[cfg(feature = "dtype-u8")]
            (Int64, UInt8) => Some(Int64),
            #[cfg(feature = "dtype-u16")]
            (Int64, UInt16) => Some(Int64),
            (Int64, UInt32) => Some(Int64),
            #[cfg(not(feature = "bigidx"))]
            (Int64, UInt64) => Some(Float64), // Follow numpy
            #[cfg(feature = "bigidx")]
            (Int64, UInt64) => Some(Int64), // Needed for bigidx
            #[cfg(feature = "dtype-f16")]
            (Int64, Float16) => Some(Float64), // Follow (Int64, Float32) case
            (Int64, Float32) => Some(Float64), // Follow numpy
            (Int64, Float64) => Some(Float64),

            #[cfg(feature = "dtype-i128")]
            (Int128, a) if a.is_integer() | a.is_bool() => Some(Int128),
            #[cfg(feature = "dtype-i128")]
            (Int128, a) if a.is_float() => Some(Float64),

            // Unsigned / unsigned: only the "narrow on the left" direction is listed for most
            // pairs; the reverse direction is found by the swapped second pass.
            #[cfg(feature = "dtype-u8")]
            (UInt8, UInt32) => Some(UInt32),
            #[cfg(feature = "dtype-u8")]
            (UInt8, UInt64) => Some(UInt64),

            #[cfg(all(feature = "dtype-u16", feature = "dtype-u8"))]
            (UInt16, UInt8) => Some(UInt16),
            #[cfg(feature = "dtype-u16")]
            (UInt16, UInt32) => Some(UInt32),
            #[cfg(feature = "dtype-u16")]
            (UInt16, UInt64) => Some(UInt64),

            (UInt32, UInt64) => Some(UInt64),

            #[cfg(feature = "dtype-u128")]
            (UInt128, a) if a.is_unsigned_integer() || a.is_bool() => Some(UInt128),
            #[cfg(feature = "dtype-u128")]
            (UInt128, a) if a.is_signed_integer() => Some(Int128),
            #[cfg(feature = "dtype-u128")]
            (UInt128, a) if a.is_float() => Some(Float64),

            #[cfg(feature = "dtype-u8")]
            (Boolean, UInt8) => Some(UInt8),
            #[cfg(feature = "dtype-u16")]
            (Boolean, UInt16) => Some(UInt16),
            (Boolean, UInt32) => Some(UInt32),
            (Boolean, UInt64) => Some(UInt64),

            #[cfg(all(feature = "dtype-f16", feature = "dtype-u8"))]
            (Float16, UInt8) => Some(Float16),
            #[cfg(all(feature = "dtype-f16", feature = "dtype-u16"))]
            (Float16, UInt16) => Some(Float32),
            #[cfg(feature = "dtype-f16")]
            (Float16, UInt32) => Some(Float64),
            #[cfg(feature = "dtype-f16")]
            (Float16, UInt64) => Some(Float64),

            #[cfg(feature = "dtype-u8")]
            (Float32, UInt8) => Some(Float32),
            #[cfg(feature = "dtype-u16")]
            (Float32, UInt16) => Some(Float32),
            (Float32, UInt32) => Some(Float64),
            (Float32, UInt64) => Some(Float64),

            #[cfg(feature = "dtype-u8")]
            (Float64, UInt8) => Some(Float64),
            #[cfg(feature = "dtype-u16")]
            (Float64, UInt16) => Some(Float64),
            (Float64, UInt32) => Some(Float64),
            (Float64, UInt64) => Some(Float64),

            // Float / float: always the wider of the two.
            #[cfg(feature = "dtype-f16")]
            (Float16, Float32) => Some(Float32),
            #[cfg(feature = "dtype-f16")]
            (Float16, Float64) => Some(Float64),
            (Float32, Float64) => Some(Float64),
            #[cfg(feature = "dtype-f16")]
            (Float32, Float16) => Some(Float32),
            #[cfg(feature = "dtype-f16")]
            (Float64, Float16) => Some(Float64),
            (Float64, Float32) => Some(Float64),

            // Time related dtypes
            #[cfg(feature = "dtype-date")]
            (Date, UInt32) => Some(Int64),
            #[cfg(feature = "dtype-date")]
            (Date, UInt64) => Some(Int64),
            #[cfg(feature = "dtype-date")]
            (Date, Int32) => Some(Int32),
            #[cfg(feature = "dtype-date")]
            (Date, Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-date", feature = "dtype-f16"))]
            (Date, Float16) => Some(Float32),
            #[cfg(feature = "dtype-date")]
            (Date, Float32) => Some(Float32),
            #[cfg(feature = "dtype-date")]
            (Date, Float64) => Some(Float64),
            #[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))]
            (Date, Datetime(tu, tz)) => Some(Datetime(*tu, tz.clone())),

            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), UInt32) => Some(Int64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), UInt64) => Some(Int64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Int32) => Some(Int64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-datetime", feature = "dtype-f16"))]
            (Datetime(_, _), Float16) => Some(Float64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Float32) => Some(Float64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Float64) => Some(Float64),
            #[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))]
            (Datetime(tu, tz), Date) => Some(Datetime(*tu, tz.clone())),

            #[cfg(feature = "dtype-f16")]
            (Boolean, Float16) => Some(Float16),
            (Boolean, Float32) => Some(Float32),
            (Boolean, Float64) => Some(Float64),

            #[cfg(feature = "dtype-duration")]
            (Duration(_), UInt32) => Some(Int64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), UInt64) => Some(Int64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Int32) => Some(Int64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-duration", feature = "dtype-f16"))]
            (Duration(_), Float16) => Some(Float64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Float32) => Some(Float64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Float64) => Some(Float64),

            #[cfg(feature = "dtype-time")]
            (Time, Int32) => Some(Int64),
            #[cfg(feature = "dtype-time")]
            (Time, Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-time", feature = "dtype-f16"))]
            (Time, Float16) => Some(Float64),
            #[cfg(feature = "dtype-time")]
            (Time, Float32) => Some(Float64),
            #[cfg(feature = "dtype-time")]
            (Time, Float64) => Some(Float64),

            // Every known type can be cast to a string except binary
            //
            // `&&` binds tighter than `||`, so the guard reads:
            //   (dt is not Unknown(Any) && dt is not Binary && allow_primitive_to_string)
            //   || dt's physical type is not primitive.
            // NOTE(review): the `||` branch is taken regardless of the option flag and of the
            // `Unknown(Any)` check — confirm this is intended.
            (dt, String) if !matches!(dt, Unknown(UnknownKind::Any)) && dt != &Binary && options.allow_primitive_to_string() || !dt.to_physical().is_primitive() => Some(String),
            (String, Binary) => Some(Binary),
            // Null adopts whatever the other side is.
            (dt, Null) => Some(dt.clone()),

            // Duration combined with a time-zone-aware Datetime: an empty time zone string is
            // normalized to `None`, otherwise the zone is kept.
            #[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))]
            (Duration(lu), Datetime(ru, Some(tz))) | (Datetime(lu, Some(tz)), Duration(ru)) => {
                if tz.is_empty() {
                    Some(Datetime(get_time_units(lu, ru), None))
                } else {
                    Some(Datetime(get_time_units(lu, ru), Some(tz.clone())))
                }
            }
            #[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))]
            (Duration(lu), Datetime(ru, None)) | (Datetime(lu, None), Duration(ru)) => {
                Some(Datetime(get_time_units(lu, ru), None))
            }
            #[cfg(all(feature = "dtype-duration", feature = "dtype-date"))]
            (Duration(_), Date) | (Date, Duration(_)) => Some(Date),
            #[cfg(feature = "dtype-duration")]
            (Duration(lu), Duration(ru)) => Some(Duration(get_time_units(lu, ru))),

            // Datetime / Datetime is only resolved when both sides have no time zone or both
            // carry the same one; the resulting time unit is chosen by `get_time_units`.
            // NOTE(review): the previous comment here described the unit choice in
            // contradictory terms ("from more precision to higher precision ... with
            // occasional loss of precision"); presumably the coarser unit wins — confirm
            // against `get_time_units`.
            #[cfg(feature = "dtype-datetime")]
            (Datetime(tu_l, tz_l), Datetime(tu_r, tz_r)) if
                // both are none
                (tz_l.is_none() && tz_r.is_none())
                // both have the same time zone
                || (tz_l.is_some() && (tz_l == tz_r)) => {
                let tu = get_time_units(tu_l, tu_r);
                Some(Datetime(tu, tz_r.clone()))
            }
            // Nested types: resolve the element types, then rebuild the container.
            (List(inner_left), List(inner_right)) => {
                let st = get_supertype(inner_left, inner_right)?;
                Some(List(Box::new(st)))
            }
            // List combined with Array degrades to List.
            #[cfg(feature = "dtype-array")]
            (List(inner_left), Array(inner_right, _)) | (Array(inner_left, _), List(inner_right)) => {
                let st = get_supertype(inner_left, inner_right)?;
                Some(List(Box::new(st)))
            }
            // Arrays of equal width stay an Array of that width.
            #[cfg(feature = "dtype-array")]
            (Array(inner_left, width_left), Array(inner_right, width_right)) if *width_left == *width_right => {
                let st = get_supertype(inner_left, inner_right)?;
                Some(Array(Box::new(st), *width_left))
            }
            // Only with the implode option: a non-list is treated as a list element.
            (List(inner), other) | (other, List(inner)) if options.allow_implode_list() => {
                let st = get_supertype(inner, other)?;
                Some(List(Box::new(st)))
            }
            // Arrays whose widths differ (equal widths were handled above) degrade to List.
            #[cfg(feature = "dtype-array")]
            (Array(inner_left, _), Array(inner_right, _)) => {
                let st = get_supertype(inner_left, inner_right)?;
                Some(List(Box::new(st)))
            }
            // Struct against a dynamic float/int: resolved against the first field only;
            // a struct without fields has no supertype here.
            #[cfg(feature = "dtype-struct")]
            (Struct(inner), right @ Unknown(UnknownKind::Float | UnknownKind::Int(_))) => {
                match inner.first() {
                    Some(inner) => get_supertype(&inner.dtype, right),
                    None => None
                }
            },
            // Any remaining type against an `Unknown` of some kind.
            (dt, Unknown(kind)) => {
                match kind {
                    // Dynamic number against a string: allowed only with the
                    // primitive-to-string option.
                    UnknownKind::Float | UnknownKind::Int(_) if  dt.is_string() => {
                        if options.allow_primitive_to_string() {
                            Some(dt.clone())
                        } else {
                            None
                        }
                    },
                    // Materialize float to float
                    UnknownKind::Float | UnknownKind::Int(_) if dt.is_float() => Some(dt.clone()),
                    UnknownKind::Float if dt.is_integer() | dt.is_decimal() => Some(Unknown(UnknownKind::Float)),
                    // Materialize str
                    UnknownKind::Str if dt.is_string() | dt.is_enum() => Some(dt.clone()),
                    // Materialize str
                    #[cfg(feature = "dtype-categorical")]
                    UnknownKind::Str if dt.is_categorical() => Some(dt.clone()),
                    // Keep unknown
                    dynam if dt.is_null() => Some(Unknown(*dynam)),
                    // Find integers sizes
                    UnknownKind::Int(v) if dt.is_primitive_numeric() => {
                        // Both dyn int
                        if let Unknown(UnknownKind::Int(v_other)) = dt {
                            // Take the maximum value to ensure we bubble up the required minimal size.
                            Some(Unknown(UnknownKind::Int(std::cmp::max(*v, *v_other))))
                        }
                        // dyn int vs number
                        else {
                            // For an unsigned `dt` and a non-negative value, look for the
                            // smallest unsigned fit; otherwise the smallest signed fit.
                            let smallest_fitting_dtype = if dt.is_unsigned_integer() && !v.is_negative() {
                                materialize_dyn_int_pos(*v).dtype()
                            } else {
                                materialize_smallest_dyn_int(*v).dtype()
                            };
                            match dt {
                                UInt64 if smallest_fitting_dtype.is_signed_integer() => {
                                    // Ensure we don't cast to float when dealing with dynamic literals
                                    Some(Int64)
                                },
                                _ => {
                                    get_supertype(dt, &smallest_fitting_dtype)
                                }
                            }
                        }
                    }
                    // Dynamic int against a decimal: keep the scale, widen to max precision.
                    #[cfg(feature = "dtype-decimal")]
                    UnknownKind::Int(_) if dt.is_decimal() => {
                        let DataType::Decimal(_prec, scale) = dt else { unreachable!() };
                        Some(DataType::Decimal(DEC128_MAX_PREC, *scale))
                    }
                    // Everything else stays fully unknown.
                    _ => Some(Unknown(UnknownKind::Any))
                }
            },
            #[cfg(feature = "dtype-struct")]
            (Struct(fields_a), Struct(fields_b)) => {
                super_type_structs(fields_a, fields_b)
            }
            // Struct against a primitive numeric: every field is resolved against `rhs`;
            // a single failing field fails the whole struct.
            #[cfg(feature = "dtype-struct")]
            (Struct(fields_a), rhs) if rhs.is_primitive_numeric() => {
                let mut new_fields = Vec::with_capacity(fields_a.len());
                for a in fields_a {
                    let st = get_supertype(&a.dtype, rhs)?;
                    new_fields.push(Field::new(a.name.clone(), st))
                }
                Some(Struct(new_fields))
            }
            // Decimal / decimal: the larger precision and the larger scale, taken independently.
            #[cfg(feature = "dtype-decimal")]
            (Decimal(p1, s1), Decimal(p2, s2)) => {
                Some(Decimal((*p1).max(*p2), (*s1).max(*s2)))
            },
            #[cfg(all(feature = "dtype-decimal", feature = "dtype-f16"))]
            (Decimal(_, _), Float16) => Some(Float64),
            #[cfg(feature = "dtype-decimal")]
            (Decimal(_, _), Float32 | Float64) => Some(Float64),
            // Decimal against an integer: keep precision/scale when the integer type's
            // maximum value passes the `i128_to_dec128` check (presumably "representable at
            // this precision and scale" — confirm against that helper); otherwise widen to
            // the maximum precision. 128-bit integers always widen.
            #[cfg(feature = "dtype-decimal")]
            (Decimal(prec, scale), dt) if dt.is_signed_integer() || dt.is_unsigned_integer() => {
                let fits = |v| { i128_to_dec128(v, *prec, *scale).is_some() };
                let fits_orig_prec_scale = match dt {
                    UInt8 => fits(u8::MAX as i128),
                    UInt16 => fits(u16::MAX as i128),
                    UInt32 => fits(u32::MAX as i128),
                    UInt64 => fits(u64::MAX as i128),
                    UInt128 => false,
                    Int8 => fits(i8::MAX as i128),
                    Int16 => fits(i16::MAX as i128),
                    Int32 => fits(i32::MAX as i128),
                    Int64 => fits(i64::MAX as i128),
                    Int128 => false,
                    _ => unreachable!(),
                };
                if fits_orig_prec_scale {
                    Some(Decimal(*prec, *scale))
                } else {
                    Some(Decimal(DEC128_MAX_PREC, *scale))
                }
            }
            _ => None,
        }
    }

    // Try the rules as given, then with the operands swapped.
    inner(l, r, options).or_else(|| inner(r, l, options))
}
544
545/// Given multiple data types, determine the data type that all types can safely be cast to.
546///
547/// Returns [`DataType::Null`] if no data types were passed.
548pub fn dtypes_to_supertype<'a, I>(dtypes: I) -> PolarsResult<DataType>
549where
550    I: IntoIterator<Item = &'a DataType>,
551{
552    dtypes
553        .into_iter()
554        .try_fold(DataType::Null, |supertype, dtype| {
555            try_get_supertype(&supertype, dtype)
556        })
557}
558
/// Builds a struct type containing the union of both field lists, matched by name.
///
/// The longer list (or `fields_a` on a tie) determines the leading field order; names that
/// only occur in the other list are appended. A name present in both lists with different
/// dtypes gets their supertype.
///
/// Returns [`None`] if any such shared field has no supertype.
#[cfg(feature = "dtype-struct")]
fn union_struct_fields(fields_a: &[Field], fields_b: &[Field]) -> Option<DataType> {
    // On equal lengths the lhs stays primary so that it dictates the field order.
    let (primary, secondary) = if fields_b.len() > fields_a.len() {
        (fields_b, fields_a)
    } else {
        (fields_a, fields_b)
    };

    let mut merged: PlIndexMap<_, _> = primary
        .iter()
        .map(|field| (&field.name, field.dtype.clone()))
        .collect();

    for field in secondary {
        // Unknown names are inserted as-is; known names may need widening.
        let slot = merged
            .entry(&field.name)
            .or_insert_with(|| field.dtype.clone());
        if &field.dtype != slot {
            let widened = get_supertype(&field.dtype, slot)?;
            *slot = widened;
        }
    }

    let fields: Vec<_> = merged
        .into_iter()
        .map(|(name, dtype)| Field::new(name.clone(), dtype))
        .collect();
    Some(DataType::Struct(fields))
}
587
/// Determines the supertype of two struct types given their field lists.
///
/// When both lists have the same length and the same names in the same order, fields are
/// resolved pairwise. A length mismatch, or the first position at which the names differ,
/// hands the whole job over to [`union_struct_fields`].
///
/// Returns [`None`] if a pair of fields has no supertype.
#[cfg(feature = "dtype-struct")]
fn super_type_structs(fields_a: &[Field], fields_b: &[Field]) -> Option<DataType> {
    if fields_a.len() != fields_b.len() {
        return union_struct_fields(fields_a, fields_b);
    }

    let mut merged = Vec::with_capacity(fields_a.len());
    for (left, right) in fields_a.iter().zip(fields_b) {
        if left.name != right.name {
            // Same length but different layout: fall back to a name-based union.
            return union_struct_fields(fields_a, fields_b);
        }
        let dtype = get_supertype(&left.dtype, &right.dtype)?;
        merged.push(Field::new(left.name.clone(), dtype));
    }
    Some(DataType::Struct(merged))
}
604
605pub fn materialize_dyn_int(v: i128) -> AnyValue<'static> {
606    // Try to get the "smallest" fitting value.
607    // TODO! next breaking go to true smallest.
608    if let Ok(v) = i32::try_from(v) {
609        return AnyValue::Int32(v);
610    }
611    if let Ok(v) = i64::try_from(v) {
612        return AnyValue::Int64(v);
613    }
614    if let Ok(v) = u64::try_from(v) {
615        return AnyValue::UInt64(v);
616    }
617    #[cfg(feature = "dtype-i128")]
618    {
619        AnyValue::Int128(v)
620    }
621
622    #[cfg(not(feature = "dtype-i128"))]
623    AnyValue::Null
624}
625
626fn materialize_dyn_int_pos(v: i128) -> AnyValue<'static> {
627    // Try to get the "smallest" fitting value.
628    // TODO! next breaking go to true smallest.
629    #[cfg(feature = "dtype-u8")]
630    if let Ok(v) = u8::try_from(v) {
631        return AnyValue::UInt8(v);
632    }
633    #[cfg(feature = "dtype-u16")]
634    if let Ok(v) = u16::try_from(v) {
635        return AnyValue::UInt16(v);
636    }
637    match u32::try_from(v).ok() {
638        Some(v) => AnyValue::UInt32(v),
639        None => match u64::try_from(v).ok() {
640            Some(v) => AnyValue::UInt64(v),
641            None => AnyValue::Null,
642        },
643    }
644}
645
646fn materialize_smallest_dyn_int(v: i128) -> AnyValue<'static> {
647    #[cfg(feature = "dtype-i8")]
648    if let Ok(v) = i8::try_from(v) {
649        return AnyValue::Int8(v);
650    }
651    #[cfg(feature = "dtype-i16")]
652    if let Ok(v) = i16::try_from(v) {
653        return AnyValue::Int16(v);
654    }
655    match i32::try_from(v).ok() {
656        Some(v) => AnyValue::Int32(v),
657        None => match i64::try_from(v).ok() {
658            Some(v) => AnyValue::Int64(v),
659            None => match u64::try_from(v).ok() {
660                Some(v) => AnyValue::UInt64(v),
661                None => AnyValue::Null,
662            },
663        },
664    }
665}
666
667pub fn merge_dtypes_many<I: IntoIterator<Item = D> + Clone, D: AsRef<DataType>>(
668    into_iter: I,
669) -> PolarsResult<DataType> {
670    let mut iter = into_iter.into_iter();
671
672    let mut st = iter
673        .next()
674        .ok_or_else(|| polars_err!(ComputeError: "expect at least 1 dtype"))
675        .map(|d| d.as_ref().clone())?;
676
677    for d in iter {
678        st = try_get_supertype(d.as_ref(), &st)?;
679    }
680
681    Ok(st)
682}