polars_core/frame/
arithmetic.rs

1use std::ops::{Add, Div, Mul, Rem, Sub};
2
3use rayon::prelude::*;
4
5use crate::POOL;
6use crate::prelude::*;
7use crate::utils::try_get_supertype;
8
9/// Get the supertype that is valid for all columns in the [`DataFrame`].
10/// This reduces casting of the rhs in arithmetic.
11fn get_supertype_all(df: &DataFrame, rhs: &Series) -> PolarsResult<DataType> {
12    df.columns.iter().try_fold(rhs.dtype().clone(), |dt, s| {
13        try_get_supertype(s.dtype(), &dt)
14    })
15}
16
/// Expands to the body of a `DataFrame <op> &Series` operator implementation.
///
/// `$self` is the frame, `$rhs` the series operand and `$operand` the binary operation, which
/// is invoked as `$operand(&column, &rhs)`. The right-hand side and every column are cast to one
/// shared supertype, the operation runs per column inside `POOL.install`, and the results are
/// assembled into a new `DataFrame`. The expansion uses `?`, so it has to be placed in a
/// function that returns a `PolarsResult`.
macro_rules! impl_arithmetic {
    ($self:expr, $rhs:expr, $operand:expr) => {{
        // One dtype for all operands, so the rhs is cast exactly once.
        let supertype = get_supertype_all($self, $rhs)?;
        let rhs_cast = $rhs.cast(&supertype)?;
        let columns = POOL.install(|| {
            $self
                .par_materialized_column_iter()
                .map(|column| {
                    let lhs = column.cast(&supertype)?;
                    $operand(&lhs, &rhs_cast).map(Column::from)
                })
                .collect::<PolarsResult<_>>()
        })?;
        // NOTE(review): `new_no_checks` skips validation; this presumably relies on every
        // result column keeping the frame's height and its (unique) source name. Confirm this
        // holds when a one-row frame is combined with a longer `rhs`.
        Ok(unsafe { DataFrame::new_no_checks($self.height(), columns) })
    }};
}
31
32impl Add<&Series> for &DataFrame {
33    type Output = PolarsResult<DataFrame>;
34
35    fn add(self, rhs: &Series) -> Self::Output {
36        impl_arithmetic!(self, rhs, std::ops::Add::add)
37    }
38}
39
40impl Add<&Series> for DataFrame {
41    type Output = PolarsResult<DataFrame>;
42
43    fn add(self, rhs: &Series) -> Self::Output {
44        (&self).add(rhs)
45    }
46}
47
48impl Sub<&Series> for &DataFrame {
49    type Output = PolarsResult<DataFrame>;
50
51    fn sub(self, rhs: &Series) -> Self::Output {
52        impl_arithmetic!(self, rhs, std::ops::Sub::sub)
53    }
54}
55
56impl Sub<&Series> for DataFrame {
57    type Output = PolarsResult<DataFrame>;
58
59    fn sub(self, rhs: &Series) -> Self::Output {
60        (&self).sub(rhs)
61    }
62}
63
64impl Mul<&Series> for &DataFrame {
65    type Output = PolarsResult<DataFrame>;
66
67    fn mul(self, rhs: &Series) -> Self::Output {
68        impl_arithmetic!(self, rhs, std::ops::Mul::mul)
69    }
70}
71
72impl Mul<&Series> for DataFrame {
73    type Output = PolarsResult<DataFrame>;
74
75    fn mul(self, rhs: &Series) -> Self::Output {
76        (&self).mul(rhs)
77    }
78}
79
80impl Div<&Series> for &DataFrame {
81    type Output = PolarsResult<DataFrame>;
82
83    fn div(self, rhs: &Series) -> Self::Output {
84        impl_arithmetic!(self, rhs, std::ops::Div::div)
85    }
86}
87
88impl Div<&Series> for DataFrame {
89    type Output = PolarsResult<DataFrame>;
90
91    fn div(self, rhs: &Series) -> Self::Output {
92        (&self).div(rhs)
93    }
94}
95
96impl Rem<&Series> for &DataFrame {
97    type Output = PolarsResult<DataFrame>;
98
99    fn rem(self, rhs: &Series) -> Self::Output {
100        impl_arithmetic!(self, rhs, std::ops::Rem::rem)
101    }
102}
103
104impl Rem<&Series> for DataFrame {
105    type Output = PolarsResult<DataFrame>;
106
107    fn rem(self, rhs: &Series) -> Self::Output {
108        (&self).rem(rhs)
109    }
110}
111
112impl DataFrame {
113    fn binary_aligned(
114        &self,
115        other: &DataFrame,
116        f: &(dyn Fn(&Series, &Series) -> PolarsResult<Series> + Sync + Send),
117    ) -> PolarsResult<DataFrame> {
118        let max_len = std::cmp::max(self.height(), other.height());
119        let max_width = std::cmp::max(self.width(), other.width());
120        let cols = self
121            .get_columns()
122            .par_iter()
123            .zip(other.get_columns().par_iter())
124            .map(|(l, r)| {
125                let l = l.as_materialized_series();
126                let r = r.as_materialized_series();
127
128                let diff_l = max_len - l.len();
129                let diff_r = max_len - r.len();
130
131                let st = try_get_supertype(l.dtype(), r.dtype())?;
132                let mut l = l.cast(&st)?;
133                let mut r = r.cast(&st)?;
134
135                if diff_l > 0 {
136                    l = l.extend_constant(AnyValue::Null, diff_l)?;
137                };
138                if diff_r > 0 {
139                    r = r.extend_constant(AnyValue::Null, diff_r)?;
140                };
141
142                f(&l, &r).map(Column::from)
143            });
144        let mut cols = POOL.install(|| cols.collect::<PolarsResult<Vec<_>>>())?;
145
146        let col_len = cols.len();
147        if col_len < max_width {
148            let df = if col_len < self.width() { self } else { other };
149
150            for i in col_len..max_len {
151                let s = &df.get_columns().get(i).ok_or_else(|| polars_err!(InvalidOperation: "cannot do arithmetic on DataFrames with shapes: {:?} and {:?}", self.shape(), other.shape()))?;
152                let name = s.name();
153                let dtype = s.dtype();
154
155                // trick to fill a series with nulls
156                let vals: &[Option<i32>] = &[None];
157                let s = Series::new(name.clone(), vals).cast(dtype)?;
158                cols.push(s.new_from_index(0, max_len).into())
159            }
160        }
161        DataFrame::new(cols)
162    }
163}
164
165impl Add<&DataFrame> for &DataFrame {
166    type Output = PolarsResult<DataFrame>;
167
168    fn add(self, rhs: &DataFrame) -> Self::Output {
169        self.binary_aligned(rhs, &|a, b| a + b)
170    }
171}
172
173impl Sub<&DataFrame> for &DataFrame {
174    type Output = PolarsResult<DataFrame>;
175
176    fn sub(self, rhs: &DataFrame) -> Self::Output {
177        self.binary_aligned(rhs, &|a, b| a - b)
178    }
179}
180
181impl Div<&DataFrame> for &DataFrame {
182    type Output = PolarsResult<DataFrame>;
183
184    fn div(self, rhs: &DataFrame) -> Self::Output {
185        self.binary_aligned(rhs, &|a, b| a / b)
186    }
187}
188
189impl Mul<&DataFrame> for &DataFrame {
190    type Output = PolarsResult<DataFrame>;
191
192    fn mul(self, rhs: &DataFrame) -> Self::Output {
193        self.binary_aligned(rhs, &|a, b| a * b)
194    }
195}
196
197impl Rem<&DataFrame> for &DataFrame {
198    type Output = PolarsResult<DataFrame>;
199
200    fn rem(self, rhs: &DataFrame) -> Self::Output {
201        self.binary_aligned(rhs, &|a, b| a % b)
202    }
203}