Skip to main content

polars_core/series/arithmetic/
list_utils.rs

//! Functionality shared between list and array arithmetic implementations.
2use arrow::array::{Array, PrimitiveArray};
3use arrow::compute::utils::combine_validities_and;
4use num_traits::Zero;
5use polars_compute::arithmetic::ArithmeticKernel;
6use polars_compute::comparisons::TotalEqKernel;
7use polars_error::PolarsResult;
8use polars_utils::float::IsFloat;
9
10use super::*;
11use crate::series::ChunkedArray;
12use crate::utils::try_get_supertype;
13
/// Numeric operations that the list/array arithmetic helpers in this module can dispatch.
///
/// Each variant is mapped to the correspondingly named `ArithmeticKernel` function by the
/// `impl NumericOp` block in this file.
#[derive(Debug, Clone)]
pub(super) enum NumericOp {
    /// Addition; dispatched to the `wrapping_add` kernels.
    Add,
    /// Subtraction; dispatched to the `wrapping_sub` kernels.
    Sub,
    /// Multiplication; dispatched to the `wrapping_mul` kernels.
    Mul,
    /// Division; dispatched to the `legacy_div` kernels. `try_get_leaf_supertype` always
    /// resolves this operation to a float dtype.
    Div,
    /// Remainder; dispatched to the `wrapping_mod` kernels.
    Rem,
    /// Floor division; dispatched to the `wrapping_floor_div` kernels.
    FloorDiv,
}
23
24impl NumericOp {
25    pub(super) fn name(&self) -> &'static str {
26        match self {
27            Self::Add => "add",
28            Self::Sub => "sub",
29            Self::Mul => "mul",
30            Self::Div => "div",
31            Self::Rem => "rem",
32            Self::FloorDiv => "floor_div",
33        }
34    }
35
36    pub(super) fn try_get_leaf_supertype(
37        &self,
38        prim_dtype_lhs: &DataType,
39        prim_dtype_rhs: &DataType,
40    ) -> PolarsResult<DataType> {
41        let dtype = try_get_supertype(prim_dtype_lhs, prim_dtype_rhs)?;
42
43        Ok(if matches!(self, Self::Div) {
44            if dtype.is_float() {
45                dtype
46            } else {
47                DataType::Float64
48            }
49        } else if prim_dtype_lhs == &DataType::Boolean && prim_dtype_rhs == &DataType::Boolean {
50            return Ok(IDX_DTYPE);
51        } else {
52            dtype
53        })
54    }
55
56    /// For operations that perform divisions on integers, sets the validity to NULL on rows where
57    /// the denominator is 0.
58    pub(super) fn prepare_numeric_op_side_validities<T: PolarsNumericType>(
59        &self,
60        lhs: &mut PrimitiveArray<T::Native>,
61        rhs: &mut PrimitiveArray<T::Native>,
62        swapped: bool,
63    ) where
64        PrimitiveArray<T::Native>: polars_compute::comparisons::TotalEqKernel<Scalar = T::Native>,
65        T::Native: Zero + IsFloat,
66    {
67        if !T::Native::is_float() {
68            match self {
69                Self::Div | Self::Rem | Self::FloorDiv => {
70                    let target = if swapped { lhs } else { rhs };
71                    let ne_0 = target.tot_ne_kernel_broadcast(&T::Native::zero());
72                    let validity = combine_validities_and(target.validity(), Some(&ne_0));
73                    target.set_validity(validity);
74                },
75                _ => {},
76            }
77        }
78    }
79
80    /// # Panics
81    /// Panics if:
82    /// * lhs.len() != rhs.len()
83    /// * dtype is not numeric.
84    pub(super) fn apply_series(&self, lhs: &Series, rhs: &Series) -> Box<dyn Array> {
85        assert_eq!(lhs.len(), rhs.len());
86        debug_assert_eq!(lhs.dtype(), rhs.dtype());
87
88        let lhs = lhs.rechunk();
89        let rhs = rhs.rechunk();
90
91        with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| {
92            let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
93            let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();
94
95            let lhs = lhs.downcast_get(0).unwrap();
96            let rhs = rhs.downcast_get(0).unwrap();
97
98            Box::new(self.apply_arithmetic_kernel::<$T>(lhs.clone(), rhs.clone()))
99        })
100    }
101
102    fn apply_arithmetic_kernel<T: PolarsNumericType>(
103        &self,
104        lhs: PrimitiveArray<T::Native>,
105        rhs: PrimitiveArray<T::Native>,
106    ) -> PrimitiveArray<T::Native> {
107        match self {
108            Self::Add => ArithmeticKernel::wrapping_add(lhs, rhs),
109            Self::Sub => ArithmeticKernel::wrapping_sub(lhs, rhs),
110            Self::Mul => ArithmeticKernel::wrapping_mul(lhs, rhs),
111            Self::Div => ArithmeticKernel::legacy_div(lhs, rhs),
112            Self::Rem => ArithmeticKernel::wrapping_mod(lhs, rhs),
113            Self::FloorDiv => ArithmeticKernel::wrapping_floor_div(lhs, rhs),
114        }
115    }
116
117    /// For list<->primitive where the primitive is broadcasted, we can dispatch to
118    /// `ArithmeticKernel`, which can have optimized codepaths for when one side is
119    /// a scalar.
120    pub(super) fn apply_array_to_scalar<T: PolarsNumericType>(
121        &self,
122        arr_lhs: PrimitiveArray<T::Native>,
123        r: T::Native,
124        swapped: bool,
125    ) -> PrimitiveArray<T::Native> {
126        match self {
127            Self::Add => ArithmeticKernel::wrapping_add_scalar(arr_lhs, r),
128            Self::Sub => {
129                if swapped {
130                    ArithmeticKernel::wrapping_sub_scalar_lhs(r, arr_lhs)
131                } else {
132                    ArithmeticKernel::wrapping_sub_scalar(arr_lhs, r)
133                }
134            },
135            Self::Mul => ArithmeticKernel::wrapping_mul_scalar(arr_lhs, r),
136            Self::Div => {
137                if swapped {
138                    ArithmeticKernel::legacy_div_scalar_lhs(r, arr_lhs)
139                } else {
140                    ArithmeticKernel::legacy_div_scalar(arr_lhs, r)
141                }
142            },
143            Self::Rem => {
144                if swapped {
145                    ArithmeticKernel::wrapping_mod_scalar_lhs(r, arr_lhs)
146                } else {
147                    ArithmeticKernel::wrapping_mod_scalar(arr_lhs, r)
148                }
149            },
150            Self::FloorDiv => {
151                if swapped {
152                    ArithmeticKernel::wrapping_floor_div_scalar_lhs(r, arr_lhs)
153                } else {
154                    ArithmeticKernel::wrapping_floor_div_scalar(arr_lhs, r)
155                }
156            },
157        }
158    }
159}
160
/// Selects the `PlNumArithmetic` function for a `NumericOp` and expands the given body with
/// that function substituted for a caller-chosen metavariable.
///
/// Invocation shape: `with_match_pl_num_arith!(op, swapped, |$F| { ... $F ... })`.
///
/// * `op` is matched against every `NumericOp` variant; the body is expanded once per arm.
/// * For `Sub`, `Div`, `Rem` and `FloorDiv`, a true `swapped` substitutes a closure that calls
///   the function with its two arguments in reversed order. `Add` and `Mul` never evaluate
///   `swapped`.
macro_rules! with_match_pl_num_arith {
    ($op:expr, $swapped:expr, | $_:tt $FUNC:tt | $($body:tt)* ) => ({
        // Helper macro that re-binds the caller's metavariable: `$_` carries the literal `$`
        // token, so the matcher below expands to e.g. `( $F:tt )`.
        macro_rules! __bind_arith_fn__ {( $_ $FUNC:tt ) => ( $($body)* )}

        match $op {
            NumericOp::Add => __bind_arith_fn__! { (PlNumArithmetic::wrapping_add) },
            NumericOp::Sub => {
                if $swapped {
                    __bind_arith_fn__! { (|x, y| PlNumArithmetic::wrapping_sub(y, x)) }
                } else {
                    __bind_arith_fn__! { (PlNumArithmetic::wrapping_sub) }
                }
            },
            NumericOp::Mul => __bind_arith_fn__! { (PlNumArithmetic::wrapping_mul) },
            NumericOp::Div => {
                if $swapped {
                    __bind_arith_fn__! { (|x, y| PlNumArithmetic::legacy_div(y, x)) }
                } else {
                    __bind_arith_fn__! { (PlNumArithmetic::legacy_div) }
                }
            },
            NumericOp::Rem => {
                if $swapped {
                    __bind_arith_fn__! { (|x, y| PlNumArithmetic::wrapping_mod(y, x)) }
                } else {
                    __bind_arith_fn__! { (PlNumArithmetic::wrapping_mod) }
                }
            },
            NumericOp::FloorDiv => {
                if $swapped {
                    __bind_arith_fn__! { (|x, y| PlNumArithmetic::wrapping_floor_div(y, x)) }
                } else {
                    __bind_arith_fn__! { (PlNumArithmetic::wrapping_floor_div) }
                }
            },
        }
    })
}
199
200pub(super) use with_match_pl_num_arith;
201
/// Operand layout of a binary arithmetic operation; variants read as `<lhs>To<rhs>`.
///
/// NOTE(review): this enum is only declared here and its consumers are outside this file, so
/// the per-variant descriptions are inferred from the variant names — confirm at the call sites.
#[derive(Debug)]
pub(super) enum BinaryOpApplyType {
    /// Presumably both operands are lists.
    ListToList,
    /// Presumably the left operand is a list and the right operand is a primitive.
    ListToPrimitive,
    /// Presumably the left operand is a primitive and the right operand is a list.
    PrimitiveToList,
}
208
/// Which side of a binary operation is broadcast, if any.
///
/// NOTE(review): this enum is only declared here and its consumers are outside this file, so
/// the per-variant descriptions are inferred from the variant names — confirm at the call sites.
#[derive(Debug)]
pub(super) enum Broadcast {
    /// Presumably the left operand is broadcast.
    Left,
    /// Presumably the right operand is broadcast.
    Right,
    /// Presumably neither operand is broadcast.
    // The variant name ends with the enum's own name, which `clippy::enum_variant_names`
    // reports; the lint is silenced for this variant only.
    #[allow(clippy::enum_variant_names)]
    NoBroadcast,
}