polars_core/chunked_array/arithmetic/mod.rs

1//! Implementations of arithmetic operations on ChunkedArrays.
2#[cfg(feature = "dtype-decimal")]
3mod decimal;
4mod numeric;
5
6use std::ops::{Add, Div, Mul, Rem, Sub};
7
8use arrow::compute::utils::combine_validities_and;
9use num_traits::{Num, NumCast, ToPrimitive};
10pub use numeric::ArithmeticChunked;
11
12use crate::prelude::arity::unary_elementwise_values;
13use crate::prelude::*;
14
/// Overwrite `buf` with the bytes of `l` immediately followed by the bytes of `r`.
///
/// Any previous contents of `buf` are discarded; its allocation is kept so the
/// same buffer can be reused across many calls.
#[inline]
fn concat_binary_arrs(l: &[u8], r: &[u8], buf: &mut Vec<u8>) {
    buf.clear();
    for part in [l, r] {
        buf.extend_from_slice(part);
    }
}
22
23impl Add for &StringChunked {
24    type Output = StringChunked;
25
26    fn add(self, rhs: Self) -> Self::Output {
27        unsafe { (self.as_binary() + rhs.as_binary()).to_string_unchecked() }
28    }
29}
30
31impl Add for StringChunked {
32    type Output = StringChunked;
33
34    fn add(self, rhs: Self) -> Self::Output {
35        (&self).add(&rhs)
36    }
37}
38
39impl Add<&str> for &StringChunked {
40    type Output = StringChunked;
41
42    fn add(self, rhs: &str) -> Self::Output {
43        unsafe { ((&self.as_binary()) + rhs.as_bytes()).to_string_unchecked() }
44    }
45}
46
47fn concat_binview(a: &BinaryViewArray, b: &BinaryViewArray) -> BinaryViewArray {
48    let validity = combine_validities_and(a.validity(), b.validity());
49
50    let mut mutable = MutableBinaryViewArray::with_capacity(a.len());
51
52    let mut scratch = vec![];
53    for (a, b) in a.values_iter().zip(b.values_iter()) {
54        concat_binary_arrs(a, b, &mut scratch);
55        mutable.push_value(&scratch)
56    }
57
58    mutable.freeze().with_validity(validity)
59}
60
61impl Add for &BinaryChunked {
62    type Output = BinaryChunked;
63
64    fn add(self, rhs: Self) -> Self::Output {
65        // broadcasting path rhs
66        if rhs.len() == 1 {
67            let rhs = rhs.get(0);
68            let mut buf = vec![];
69            return match rhs {
70                Some(rhs) => {
71                    self.apply_mut(|s| {
72                        concat_binary_arrs(s, rhs, &mut buf);
73                        let out = buf.as_slice();
74                        // SAFETY: lifetime is bound to the outer scope and the
75                        // ref is valid for the lifetime of this closure.
76                        unsafe { std::mem::transmute::<_, &'static [u8]>(out) }
77                    })
78                },
79                None => BinaryChunked::full_null(self.name().clone(), self.len()),
80            };
81        }
82        // broadcasting path lhs
83        if self.len() == 1 {
84            let lhs = self.get(0);
85            let mut buf = vec![];
86            return match lhs {
87                Some(lhs) => rhs.apply_mut(|s| {
88                    concat_binary_arrs(lhs, s, &mut buf);
89                    let out = buf.as_slice();
90                    // SAFETY: lifetime is bound to the outer scope and the
91                    // ref is valid for the lifetime of this closure.
92                    unsafe { std::mem::transmute::<_, &'static [u8]>(out) }
93                }),
94                None => BinaryChunked::full_null(self.name().clone(), rhs.len()),
95            };
96        }
97
98        arity::binary(self, rhs, concat_binview)
99    }
100}
101
102impl Add for BinaryChunked {
103    type Output = BinaryChunked;
104
105    fn add(self, rhs: Self) -> Self::Output {
106        (&self).add(&rhs)
107    }
108}
109
110impl Add<&[u8]> for &BinaryChunked {
111    type Output = BinaryChunked;
112
113    fn add(self, rhs: &[u8]) -> Self::Output {
114        let arr = BinaryViewArray::from_slice_values([rhs]);
115        let rhs: BinaryChunked = arr.into();
116        self.add(&rhs)
117    }
118}
119
120fn add_boolean(a: &BooleanArray, b: &BooleanArray) -> PrimitiveArray<IdxSize> {
121    let validity = combine_validities_and(a.validity(), b.validity());
122
123    let values = a
124        .values_iter()
125        .zip(b.values_iter())
126        .map(|(a, b)| a as IdxSize + b as IdxSize)
127        .collect::<Vec<_>>();
128    PrimitiveArray::from_data_default(values.into(), validity)
129}
130
131impl Add for &BooleanChunked {
132    type Output = IdxCa;
133
134    fn add(self, rhs: Self) -> Self::Output {
135        // Broadcasting path rhs.
136        if rhs.len() == 1 {
137            let rhs = rhs.get(0);
138            return match rhs {
139                Some(rhs) => unary_elementwise_values(self, |v| v as IdxSize + rhs as IdxSize),
140                None => IdxCa::full_null(self.name().clone(), self.len()),
141            };
142        }
143        // Broadcasting path lhs.
144        if self.len() == 1 {
145            return rhs.add(self);
146        }
147        arity::binary(self, rhs, add_boolean)
148    }
149}
150
151impl Add for BooleanChunked {
152    type Output = IdxCa;
153
154    fn add(self, rhs: Self) -> Self::Output {
155        (&self).add(&rhs)
156    }
157}
158
#[cfg(test)]
pub(crate) mod test {
    use crate::prelude::*;

    /// Build two `Int32Chunked`s holding the same values `[1, 2, 3, 4, 5, 6]`
    /// with different chunk layouts. The first is made of two appended chunks
    /// (`[1, 2, 3]` and `[4, 5, 6]`); the second is built from a single slice.
    pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) {
        let mut a1 = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]);
        let a2 = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 5, 6]);
        let a3 = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3, 4, 5, 6]);
        a1.append(&a2).unwrap();
        (a1, a3)
    }

    /// Materialize `ca` as optional values so results can be compared with `assert_eq!`.
    fn collect_values(ca: &Int32Chunked) -> Vec<Option<i32>> {
        ca.into_iter().collect()
    }

    /// Wrap every value in `Some`, i.e. the expected output of a null-free operation.
    fn all_valid(values: &[i32]) -> Vec<Option<i32>> {
        values.iter().copied().map(Some).collect()
    }

    #[test]
    // `&a1 - &a1` and friends deliberately use the same operand on both sides to
    // exercise the equal-chunk-layout path.
    #[allow(clippy::eq_op)]
    fn test_chunk_mismatch() {
        let (a1, a2) = create_two_chunked();

        // With different chunks.
        assert_eq!(collect_values(&(&a1 + &a2)), all_valid(&[2, 4, 6, 8, 10, 12]));
        assert_eq!(collect_values(&(&a1 - &a2)), all_valid(&[0; 6]));
        assert_eq!(collect_values(&(&a1 / &a2)), all_valid(&[1; 6]));
        assert_eq!(collect_values(&(&a1 * &a2)), all_valid(&[1, 4, 9, 16, 25, 36]));

        // With same chunks.
        assert_eq!(collect_values(&(&a1 + &a1)), all_valid(&[2, 4, 6, 8, 10, 12]));
        assert_eq!(collect_values(&(&a1 - &a1)), all_valid(&[0; 6]));
        assert_eq!(collect_values(&(&a1 / &a1)), all_valid(&[1; 6]));
        assert_eq!(collect_values(&(&a1 * &a1)), all_valid(&[1, 4, 9, 16, 25, 36]));
    }
}