// polars_core/series/arithmetic/owned.rs

1use super::*;
2#[cfg(feature = "performant")]
3use crate::utils::align_chunks_binary_owned_series;
4
/// Casts both owned operands to their common supertype, consuming them.
///
/// An operand whose dtype already equals the supertype is passed through untouched; the
/// other one is converted with `Series::cast`. The left operand is processed first, so a
/// failing left cast is reported before the right operand is looked at.
///
/// # Errors
///
/// Returns an error if `try_get_supertype` finds no common supertype for the two dtypes,
/// or if casting either operand to that supertype fails.
#[cfg(feature = "performant")]
pub fn coerce_lhs_rhs_owned(lhs: Series, rhs: Series) -> PolarsResult<(Series, Series)> {
    let supertype = try_get_supertype(lhs.dtype(), rhs.dtype())?;
    // Only call `cast` when the dtype differs, so a matching series is moved through as-is.
    let to_supertype = |s: Series| -> PolarsResult<Series> {
        if s.dtype() == &supertype {
            Ok(s)
        } else {
            s.cast(&supertype)
        }
    };
    Ok((to_supertype(lhs)?, to_supertype(rhs)?))
}
20
21fn is_eligible(lhs: &DataType, rhs: &DataType) -> bool {
22    !lhs.is_logical()
23        && lhs.to_physical().is_primitive_numeric()
24        && rhs.to_physical().is_primitive_numeric()
25}
26
/// Applies `op` to the `ChunkedArray<T>` values moved out of two owned series.
///
/// Each series' inner array is extracted with `std::mem::take`, which leaves a default
/// `ChunkedArray<T>` behind in that series (dropped when this function returns). `op`
/// therefore receives both arrays by value, and its result is wrapped back into a `Series`.
///
/// NOTE(review): assumes both series already hold the physical dtype matching `T` (the
/// caller dispatches on dtype before choosing `T`); behavior of `as_mut` on a mismatch
/// is defined elsewhere — confirm.
#[cfg(feature = "performant")]
fn apply_operation_mut<T, F>(mut lhs: Series, mut rhs: Series, op: F) -> Series
where
    T: PolarsNumericType,
    F: Fn(ChunkedArray<T>, ChunkedArray<T>) -> ChunkedArray<T> + Copy,
{
    let left_ca = std::mem::take::<ChunkedArray<T>>(lhs._get_inner_mut().as_mut());
    let right_ca = std::mem::take::<ChunkedArray<T>>(rhs._get_inner_mut().as_mut());
    op(left_ca, right_ca).into_series()
}
41
// Implements an owned binary operator (`Series <op> Series`) for `Series`.
//
// Arguments: `$operation` is the `std::ops` trait, `$method` its method name, and
// `$function` a closure applied to the two operands' `ChunkedArray`s by value.
//
// With the `performant` feature, operands accepted by `is_eligible` are coerced to their
// supertype, chunk-aligned, and handed by value to `$function` through
// `apply_operation_mut`. Every other case, and every build without `performant`, uses
// the borrowed implementation (`&Series <op> &Series`).
macro_rules! impl_operation {
    ($operation:ident, $method:ident, $function:expr) => {
        impl $operation for Series {
            type Output = PolarsResult<Series>;

            fn $method(self, rhs: Self) -> Self::Output {
                #[cfg(feature = "performant")]
                {
                    // only physical numeric values take the mutable path
                    if is_eligible(self.dtype(), rhs.dtype()) {
                        // Propagate supertype/cast failures: the operator already returns
                        // a `PolarsResult`, so there is no reason to panic here.
                        let (lhs, rhs) = coerce_lhs_rhs_owned(self, rhs)?;
                        let (lhs, rhs) = align_chunks_binary_owned_series(lhs, rhs);
                        use DataType::*;
                        Ok(match lhs.dtype() {
                            #[cfg(feature = "dtype-i8")]
                            Int8 => apply_operation_mut::<Int8Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-i16")]
                            Int16 => apply_operation_mut::<Int16Type, _>(lhs, rhs, $function),
                            Int32 => apply_operation_mut::<Int32Type, _>(lhs, rhs, $function),
                            Int64 => apply_operation_mut::<Int64Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-i128")]
                            Int128 => apply_operation_mut::<Int128Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-u8")]
                            UInt8 => apply_operation_mut::<UInt8Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-u16")]
                            UInt16 => apply_operation_mut::<UInt16Type, _>(lhs, rhs, $function),
                            UInt32 => apply_operation_mut::<UInt32Type, _>(lhs, rhs, $function),
                            UInt64 => apply_operation_mut::<UInt64Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-u128")]
                            UInt128 => apply_operation_mut::<UInt128Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-f16")]
                            Float16 => apply_operation_mut::<Float16Type, _>(lhs, rhs, $function),
                            Float32 => apply_operation_mut::<Float32Type, _>(lhs, rhs, $function),
                            Float64 => apply_operation_mut::<Float64Type, _>(lhs, rhs, $function),
                            // A supertype with no fast-path arm (e.g. its `dtype-*` feature is
                            // disabled) falls back to the borrowed kernel on the already
                            // coerced operands instead of panicking.
                            _ => (&lhs).$method(&rhs)?,
                        })
                    } else {
                        (&self).$method(&rhs)
                    }
                }
                #[cfg(not(feature = "performant"))]
                {
                    (&self).$method(&rhs)
                }
            }
        }
    };
}
90
91impl_operation!(Add, add, |a, b| a.add(b));
92impl_operation!(Sub, sub, |a, b| a.sub(b));
93impl_operation!(Mul, mul, |a, b| a.mul(b));
94impl_operation!(Div, div, |a, b| a.div(b));
95
96impl Series {
97    pub fn try_add_owned(self, other: Self) -> PolarsResult<Self> {
98        if is_eligible(self.dtype(), other.dtype()) {
99            self + other
100        } else {
101            std::ops::Add::add(&self, &other)
102        }
103    }
104
105    pub fn try_sub_owned(self, other: Self) -> PolarsResult<Self> {
106        if is_eligible(self.dtype(), other.dtype()) {
107            self - other
108        } else {
109            std::ops::Sub::sub(&self, &other)
110        }
111    }
112
113    pub fn try_mul_owned(self, other: Self) -> PolarsResult<Self> {
114        if is_eligible(self.dtype(), other.dtype()) {
115            self * other
116        } else {
117            std::ops::Mul::mul(&self, &other)
118        }
119    }
120}