polars_core/series/arithmetic/
list_utils.rs1use arrow::array::{Array, PrimitiveArray};
3use arrow::compute::utils::combine_validities_and;
4use num_traits::Zero;
5use polars_compute::arithmetic::ArithmeticKernel;
6use polars_compute::comparisons::TotalEqKernel;
7use polars_error::PolarsResult;
8use polars_utils::float::IsFloat;
9
10use super::*;
11use crate::series::ChunkedArray;
12use crate::utils::try_get_supertype;
13
14#[derive(Debug, Clone)]
15pub(super) enum NumericOp {
16 Add,
17 Sub,
18 Mul,
19 Div,
20 Rem,
21 FloorDiv,
22}
23
24impl NumericOp {
25 pub(super) fn name(&self) -> &'static str {
26 match self {
27 Self::Add => "add",
28 Self::Sub => "sub",
29 Self::Mul => "mul",
30 Self::Div => "div",
31 Self::Rem => "rem",
32 Self::FloorDiv => "floor_div",
33 }
34 }
35
36 pub(super) fn try_get_leaf_supertype(
37 &self,
38 prim_dtype_lhs: &DataType,
39 prim_dtype_rhs: &DataType,
40 ) -> PolarsResult<DataType> {
41 let dtype = try_get_supertype(prim_dtype_lhs, prim_dtype_rhs)?;
42
43 Ok(if matches!(self, Self::Div) {
44 if dtype.is_float() {
45 dtype
46 } else {
47 DataType::Float64
48 }
49 } else if prim_dtype_lhs == &DataType::Boolean && prim_dtype_rhs == &DataType::Boolean {
50 return Ok(IDX_DTYPE);
51 } else {
52 dtype
53 })
54 }
55
56 pub(super) fn prepare_numeric_op_side_validities<T: PolarsNumericType>(
59 &self,
60 lhs: &mut PrimitiveArray<T::Native>,
61 rhs: &mut PrimitiveArray<T::Native>,
62 swapped: bool,
63 ) where
64 PrimitiveArray<T::Native>: polars_compute::comparisons::TotalEqKernel<Scalar = T::Native>,
65 T::Native: Zero + IsFloat,
66 {
67 if !T::Native::is_float() {
68 match self {
69 Self::Div | Self::Rem | Self::FloorDiv => {
70 let target = if swapped { lhs } else { rhs };
71 let ne_0 = target.tot_ne_kernel_broadcast(&T::Native::zero());
72 let validity = combine_validities_and(target.validity(), Some(&ne_0));
73 target.set_validity(validity);
74 },
75 _ => {},
76 }
77 }
78 }
79
80 pub(super) fn apply_series(&self, lhs: &Series, rhs: &Series) -> Box<dyn Array> {
85 assert_eq!(lhs.len(), rhs.len());
86 debug_assert_eq!(lhs.dtype(), rhs.dtype());
87
88 let lhs = lhs.rechunk();
89 let rhs = rhs.rechunk();
90
91 with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| {
92 let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
93 let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();
94
95 let lhs = lhs.downcast_get(0).unwrap();
96 let rhs = rhs.downcast_get(0).unwrap();
97
98 Box::new(self.apply_arithmetic_kernel::<$T>(lhs.clone(), rhs.clone()))
99 })
100 }
101
102 fn apply_arithmetic_kernel<T: PolarsNumericType>(
103 &self,
104 lhs: PrimitiveArray<T::Native>,
105 rhs: PrimitiveArray<T::Native>,
106 ) -> PrimitiveArray<T::Native> {
107 match self {
108 Self::Add => ArithmeticKernel::wrapping_add(lhs, rhs),
109 Self::Sub => ArithmeticKernel::wrapping_sub(lhs, rhs),
110 Self::Mul => ArithmeticKernel::wrapping_mul(lhs, rhs),
111 Self::Div => ArithmeticKernel::legacy_div(lhs, rhs),
112 Self::Rem => ArithmeticKernel::wrapping_mod(lhs, rhs),
113 Self::FloorDiv => ArithmeticKernel::wrapping_floor_div(lhs, rhs),
114 }
115 }
116
117 pub(super) fn apply_array_to_scalar<T: PolarsNumericType>(
121 &self,
122 arr_lhs: PrimitiveArray<T::Native>,
123 r: T::Native,
124 swapped: bool,
125 ) -> PrimitiveArray<T::Native> {
126 match self {
127 Self::Add => ArithmeticKernel::wrapping_add_scalar(arr_lhs, r),
128 Self::Sub => {
129 if swapped {
130 ArithmeticKernel::wrapping_sub_scalar_lhs(r, arr_lhs)
131 } else {
132 ArithmeticKernel::wrapping_sub_scalar(arr_lhs, r)
133 }
134 },
135 Self::Mul => ArithmeticKernel::wrapping_mul_scalar(arr_lhs, r),
136 Self::Div => {
137 if swapped {
138 ArithmeticKernel::legacy_div_scalar_lhs(r, arr_lhs)
139 } else {
140 ArithmeticKernel::legacy_div_scalar(arr_lhs, r)
141 }
142 },
143 Self::Rem => {
144 if swapped {
145 ArithmeticKernel::wrapping_mod_scalar_lhs(r, arr_lhs)
146 } else {
147 ArithmeticKernel::wrapping_mod_scalar(arr_lhs, r)
148 }
149 },
150 Self::FloorDiv => {
151 if swapped {
152 ArithmeticKernel::wrapping_floor_div_scalar_lhs(r, arr_lhs)
153 } else {
154 ArithmeticKernel::wrapping_floor_div_scalar(arr_lhs, r)
155 }
156 },
157 }
158 }
159}
160
/// Expands `$body` once per [`NumericOp`] variant, with the caller-chosen
/// metavariable bound to the scalar function implementing that operation.
///
/// Invoked as `with_match_pl_num_arith!(op, swapped, |$OP| { ... $OP ... })`.
/// Inside the body, `$OP` is replaced by a parenthesized expression that is
/// either a `PlNumArithmetic` function path or, for the non-commutative
/// operations when `swapped` is true, a closure `|b, a| op(a, b)` that applies
/// the same function with its two arguments exchanged.
macro_rules! with_match_pl_num_arith {
    ($op:expr, $swapped:expr, | $_:tt $OP:tt | $($body:tt)* ) => ({
        // `$_` captures the literal `$` token written by the caller, so this helper
        // macro ends up declaring the caller's metavariable and substituting it
        // throughout the body.
        macro_rules! __with_func__ {( $_ $OP:tt ) => ( $($body)* )}

        match $op {
            // Argument order is not adjusted for these two.
            NumericOp::Add => __with_func__! { (PlNumArithmetic::wrapping_add) },
            NumericOp::Mul => __with_func__! { (PlNumArithmetic::wrapping_mul) },

            NumericOp::Sub if $swapped =>
                __with_func__! { (|b, a| PlNumArithmetic::wrapping_sub(a, b)) },
            NumericOp::Sub => __with_func__! { (PlNumArithmetic::wrapping_sub) },

            NumericOp::Div if $swapped =>
                __with_func__! { (|b, a| PlNumArithmetic::legacy_div(a, b)) },
            NumericOp::Div => __with_func__! { (PlNumArithmetic::legacy_div) },

            NumericOp::Rem if $swapped =>
                __with_func__! { (|b, a| PlNumArithmetic::wrapping_mod(a, b)) },
            NumericOp::Rem => __with_func__! { (PlNumArithmetic::wrapping_mod) },

            NumericOp::FloorDiv if $swapped =>
                __with_func__! { (|b, a| PlNumArithmetic::wrapping_floor_div(a, b)) },
            NumericOp::FloorDiv => __with_func__! { (PlNumArithmetic::wrapping_floor_div) },
        }
    })
}
199
200pub(super) use with_match_pl_num_arith;
201
202#[derive(Debug)]
203pub(super) enum BinaryOpApplyType {
204 ListToList,
205 ListToPrimitive,
206 PrimitiveToList,
207}
208
209#[derive(Debug)]
210pub(super) enum Broadcast {
211 Left,
212 Right,
213 #[allow(clippy::enum_variant_names)]
214 NoBroadcast,
215}