polars_core/chunked_array/arithmetic/
mod.rs

1//! Implementations of arithmetic operations on ChunkedArrays.
2#[cfg(feature = "dtype-decimal")]
3mod decimal;
4mod numeric;
5
6use std::ops::{Add, Div, Mul, Rem, Sub};
7
8use arrow::compute::utils::combine_validities_and;
9#[cfg(feature = "dtype-decimal")]
10pub use decimal::{_get_decimal_scale_add_sub, _get_decimal_scale_div, _get_decimal_scale_mul};
11use num_traits::{Num, NumCast, ToPrimitive};
12pub use numeric::ArithmeticChunked;
13
14use crate::prelude::arity::unary_elementwise_values;
15use crate::prelude::*;
16
/// Overwrite `buf` with the bytes of `l` immediately followed by the bytes of `r`.
///
/// `buf` is cleared first, so any previous contents are discarded while its
/// allocation is kept, letting callers reuse one scratch buffer across rows.
#[inline]
fn concat_binary_arrs(l: &[u8], r: &[u8], buf: &mut Vec<u8>) {
    buf.clear();
    for part in [l, r] {
        buf.extend_from_slice(part);
    }
}
24
25impl Add for &StringChunked {
26    type Output = StringChunked;
27
28    fn add(self, rhs: Self) -> Self::Output {
29        unsafe { (self.as_binary() + rhs.as_binary()).to_string_unchecked() }
30    }
31}
32
33impl Add for StringChunked {
34    type Output = StringChunked;
35
36    fn add(self, rhs: Self) -> Self::Output {
37        (&self).add(&rhs)
38    }
39}
40
41impl Add<&str> for &StringChunked {
42    type Output = StringChunked;
43
44    fn add(self, rhs: &str) -> Self::Output {
45        unsafe { ((&self.as_binary()) + rhs.as_bytes()).to_string_unchecked() }
46    }
47}
48
49fn concat_binview(a: &BinaryViewArray, b: &BinaryViewArray) -> BinaryViewArray {
50    let validity = combine_validities_and(a.validity(), b.validity());
51
52    let mut mutable = MutableBinaryViewArray::with_capacity(a.len());
53
54    let mut scratch = vec![];
55    for (a, b) in a.values_iter().zip(b.values_iter()) {
56        concat_binary_arrs(a, b, &mut scratch);
57        mutable.push_value(&scratch)
58    }
59
60    mutable.freeze().with_validity(validity)
61}
62
63impl Add for &BinaryChunked {
64    type Output = BinaryChunked;
65
66    fn add(self, rhs: Self) -> Self::Output {
67        // broadcasting path rhs
68        if rhs.len() == 1 {
69            let rhs = rhs.get(0);
70            let mut buf = vec![];
71            return match rhs {
72                Some(rhs) => {
73                    self.apply_mut(|s| {
74                        concat_binary_arrs(s, rhs, &mut buf);
75                        let out = buf.as_slice();
76                        // SAFETY: lifetime is bound to the outer scope and the
77                        // ref is valid for the lifetime of this closure.
78                        unsafe { std::mem::transmute::<_, &'static [u8]>(out) }
79                    })
80                },
81                None => BinaryChunked::full_null(self.name().clone(), self.len()),
82            };
83        }
84        // broadcasting path lhs
85        if self.len() == 1 {
86            let lhs = self.get(0);
87            let mut buf = vec![];
88            return match lhs {
89                Some(lhs) => rhs.apply_mut(|s| {
90                    concat_binary_arrs(lhs, s, &mut buf);
91                    let out = buf.as_slice();
92                    // SAFETY: lifetime is bound to the outer scope and the
93                    // ref is valid for the lifetime of this closure.
94                    unsafe { std::mem::transmute::<_, &'static [u8]>(out) }
95                }),
96                None => BinaryChunked::full_null(self.name().clone(), rhs.len()),
97            };
98        }
99
100        arity::binary(self, rhs, concat_binview)
101    }
102}
103
104impl Add for BinaryChunked {
105    type Output = BinaryChunked;
106
107    fn add(self, rhs: Self) -> Self::Output {
108        (&self).add(&rhs)
109    }
110}
111
112impl Add<&[u8]> for &BinaryChunked {
113    type Output = BinaryChunked;
114
115    fn add(self, rhs: &[u8]) -> Self::Output {
116        let arr = BinaryViewArray::from_slice_values([rhs]);
117        let rhs: BinaryChunked = arr.into();
118        self.add(&rhs)
119    }
120}
121
122fn add_boolean(a: &BooleanArray, b: &BooleanArray) -> PrimitiveArray<IdxSize> {
123    let validity = combine_validities_and(a.validity(), b.validity());
124
125    let values = a
126        .values_iter()
127        .zip(b.values_iter())
128        .map(|(a, b)| a as IdxSize + b as IdxSize)
129        .collect::<Vec<_>>();
130    PrimitiveArray::from_data_default(values.into(), validity)
131}
132
133impl Add for &BooleanChunked {
134    type Output = IdxCa;
135
136    fn add(self, rhs: Self) -> Self::Output {
137        // Broadcasting path rhs.
138        if rhs.len() == 1 {
139            let rhs = rhs.get(0);
140            return match rhs {
141                Some(rhs) => unary_elementwise_values(self, |v| v as IdxSize + rhs as IdxSize),
142                None => IdxCa::full_null(self.name().clone(), self.len()),
143            };
144        }
145        // Broadcasting path lhs.
146        if self.len() == 1 {
147            return rhs.add(self);
148        }
149        arity::binary(self, rhs, add_boolean)
150    }
151}
152
153impl Add for BooleanChunked {
154    type Output = IdxCa;
155
156    fn add(self, rhs: Self) -> Self::Output {
157        (&self).add(&rhs)
158    }
159}
160
#[cfg(test)]
pub(crate) mod test {
    use crate::prelude::*;

    /// Build two `Int32Chunked` columns named `"a"` that both hold `1..=6`.
    ///
    /// The first is built by appending `[4, 5, 6]` to `[1, 2, 3]`, so it is split
    /// across chunks. The second is created from a single slice. This gives
    /// operands with mismatched chunk layouts.
    pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) {
        let name = PlSmallStr::from_static("a");
        let mut appended = Int32Chunked::new(name.clone(), &[1, 2, 3]);
        let tail = Int32Chunked::new(name.clone(), &[4, 5, 6]);
        let contiguous = Int32Chunked::new(name, &[1, 2, 3, 4, 5, 6]);
        appended.append(&tail).unwrap();
        (appended, contiguous)
    }

    // `&x op &x` is intentional: it exercises arithmetic on two operands that
    // share an identical chunk layout, which clippy's `eq_op` would flag.
    #[test]
    #[allow(clippy::eq_op)]
    fn test_chunk_mismatch() {
        let (appended, contiguous) = create_two_chunked();

        // Operands whose chunk layouts differ.
        let _ = &appended + &contiguous;
        let _ = &appended - &contiguous;
        let _ = &appended / &contiguous;
        let _ = &appended * &contiguous;

        // Operands whose chunk layouts are identical.
        let _ = &appended + &appended;
        let _ = &appended - &appended;
        let _ = &appended / &appended;
        let _ = &appended * &appended;
    }
}