polars_ops/series/ops/
horizontal.rs

1use std::borrow::Cow;
2
3use polars_core::chunked_array::cast::CastOptions;
4use polars_core::prelude::*;
5use polars_core::series::arithmetic::coerce_lhs_rhs;
6use polars_core::utils::dtypes_to_supertype;
7use polars_core::{POOL, with_match_physical_numeric_polars_type};
8use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
9
10fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> {
11    let mut length = 1;
12    for c in cs {
13        let len = c.len();
14        if len != 1 && len != length {
15            if length == 1 {
16                length = len;
17            } else {
18                polars_bail!(ShapeMismatch: "cannot evaluate two Series of different lengths ({len} and {length})");
19            }
20        }
21    }
22    Ok(())
23}
24
25pub trait MinMaxHorizontal {
26    /// Aggregate the column horizontally to their min values.
27    fn min_horizontal(&self) -> PolarsResult<Option<Column>>;
28    /// Aggregate the column horizontally to their max values.
29    fn max_horizontal(&self) -> PolarsResult<Option<Column>>;
30}
31
32impl MinMaxHorizontal for DataFrame {
33    fn min_horizontal(&self) -> PolarsResult<Option<Column>> {
34        min_horizontal(self.get_columns())
35    }
36    fn max_horizontal(&self) -> PolarsResult<Option<Column>> {
37        max_horizontal(self.get_columns())
38    }
39}
40
/// How null values are treated by the horizontal sum/mean aggregations.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum NullStrategy {
    /// Nulls do not poison the result: when summing, they are replaced by zero
    /// before the values are added.
    Ignore,
    /// Nulls are left in place, so they flow through the arithmetic into the result.
    Propagate,
}
46
47pub trait SumMeanHorizontal {
48    /// Sum all values horizontally across columns.
49    fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;
50
51    /// Compute the mean of all numeric values horizontally across columns.
52    fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;
53}
54
55impl SumMeanHorizontal for DataFrame {
56    fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
57        sum_horizontal(self.get_columns(), null_strategy)
58    }
59    fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
60        mean_horizontal(self.get_columns(), null_strategy)
61    }
62}
63
64fn min_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
65where
66    T: PolarsNumericType,
67    T::Native: PartialOrd,
68{
69    let op = |l: T::Native, r: T::Native| {
70        if l < r { l } else { r }
71    };
72    arity::binary_elementwise_values(left, right, op)
73}
74
75fn max_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
76where
77    T: PolarsNumericType,
78    T::Native: PartialOrd,
79{
80    let op = |l: T::Native, r: T::Native| {
81        if l > r { l } else { r }
82    };
83    arity::binary_elementwise_values(left, right, op)
84}
85
86fn min_max_binary_columns(left: &Column, right: &Column, min: bool) -> PolarsResult<Column> {
87    if left.dtype().to_physical().is_primitive_numeric()
88        && right.dtype().to_physical().is_primitive_numeric()
89        && left.null_count() == 0
90        && right.null_count() == 0
91        && left.len() == right.len()
92    {
93        match (left, right) {
94            (Column::Series(left), Column::Series(right)) => {
95                let (lhs, rhs) = coerce_lhs_rhs(left, right)?;
96                let logical = lhs.dtype();
97                let lhs = lhs.to_physical_repr();
98                let rhs = rhs.to_physical_repr();
99
100                with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| {
101                    let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
102                    let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();
103
104                    unsafe {
105                        if min {
106                            min_binary(a, b).into_series().from_physical_unchecked(logical)
107                        } else {
108                            max_binary(a, b).into_series().from_physical_unchecked(logical)
109                        }
110                    }
111                })
112                .map(Column::from)
113            },
114            _ => {
115                let mask = if min {
116                    left.lt(right)?
117                } else {
118                    left.gt(right)?
119                };
120
121                left.zip_with(&mask, right)
122            },
123        }
124    } else {
125        let mask = if min {
126            left.lt(right)? & left.is_not_null() | right.is_null()
127        } else {
128            left.gt(right)? & left.is_not_null() | right.is_null()
129        };
130        left.zip_with(&mask, right)
131    }
132}
133
134pub fn max_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
135    validate_column_lengths(columns)?;
136
137    let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false);
138
139    match columns.len() {
140        0 => Ok(None),
141        1 => Ok(Some(columns[0].clone())),
142        2 => max_fn(&columns[0], &columns[1]).map(Some),
143        _ => {
144            // the try_reduce_with is a bit slower in parallelism,
145            // but I don't think it matters here as we parallelize over columns, not over elements
146            POOL.install(|| {
147                columns
148                    .par_iter()
149                    .map(|s| Ok(Cow::Borrowed(s)))
150                    .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned))
151                    // we can unwrap the option, because we are certain there is a column
152                    // we started this operation on 3 columns
153                    .unwrap()
154                    .map(|cow| Some(cow.into_owned()))
155            })
156        },
157    }
158}
159
160pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
161    validate_column_lengths(columns)?;
162
163    let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true);
164
165    match columns.len() {
166        0 => Ok(None),
167        1 => Ok(Some(columns[0].clone())),
168        2 => min_fn(&columns[0], &columns[1]).map(Some),
169        _ => {
170            // the try_reduce_with is a bit slower in parallelism,
171            // but I don't think it matters here as we parallelize over columns, not over elements
172            POOL.install(|| {
173                columns
174                    .par_iter()
175                    .map(|s| Ok(Cow::Borrowed(s)))
176                    .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned))
177                    // we can unwrap the option, because we are certain there is a column
178                    // we started this operation on 3 columns
179                    .unwrap()
180                    .map(|cow| Some(cow.into_owned()))
181            })
182        },
183    }
184}
185
186pub fn sum_horizontal(
187    columns: &[Column],
188    null_strategy: NullStrategy,
189) -> PolarsResult<Option<Column>> {
190    validate_column_lengths(columns)?;
191    let ignore_nulls = null_strategy == NullStrategy::Ignore;
192
193    let apply_null_strategy = |s: Series| -> PolarsResult<Series> {
194        if ignore_nulls && s.null_count() > 0 {
195            s.fill_null(FillNullStrategy::Zero)
196        } else {
197            Ok(s)
198        }
199    };
200
201    let sum_fn = |acc: Series, s: Series| -> PolarsResult<Series> {
202        let acc: Series = apply_null_strategy(acc)?;
203        let s = apply_null_strategy(s)?;
204        // This will do owned arithmetic and can be mutable
205        std::ops::Add::add(acc, s)
206    };
207
208    // @scalar-opt
209    let non_null_cols = columns
210        .iter()
211        .filter(|x| x.dtype() != &DataType::Null)
212        .map(|c| c.as_materialized_series())
213        .collect::<Vec<_>>();
214
215    // If we have any null columns and null strategy is not `Ignore`, we can return immediately.
216    if !ignore_nulls && non_null_cols.len() < columns.len() {
217        // We must determine the correct return dtype.
218        let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
219            DataType::Boolean => IDX_DTYPE,
220            dt => dt,
221        };
222        return Ok(Some(Column::full_null(
223            columns[0].name().clone(),
224            columns[0].len(),
225            &return_dtype,
226        )));
227    }
228
229    match non_null_cols.len() {
230        0 => {
231            if columns.is_empty() {
232                Ok(None)
233            } else {
234                // all columns are null dtype, so result is null dtype
235                Ok(Some(columns[0].clone()))
236            }
237        },
238        1 => Ok(Some(
239            apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
240                non_null_cols[0].cast(&IDX_DTYPE)?
241            } else {
242                non_null_cols[0].clone()
243            })?
244            .into(),
245        )),
246        2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())
247            .map(Column::from)
248            .map(Some),
249        _ => {
250            // the try_reduce_with is a bit slower in parallelism,
251            // but I don't think it matters here as we parallelize over columns, not over elements
252            let out = POOL.install(|| {
253                non_null_cols
254                    .into_par_iter()
255                    .cloned()
256                    .map(Ok)
257                    .try_reduce_with(sum_fn)
258                    // We can unwrap because we started with at least 3 columns, so we always get a Some
259                    .unwrap()
260            });
261            out.map(Column::from).map(Some)
262        },
263    }
264}
265
266pub fn mean_horizontal(
267    columns: &[Column],
268    null_strategy: NullStrategy,
269) -> PolarsResult<Option<Column>> {
270    validate_column_lengths(columns)?;
271
272    let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| {
273        let dtype = s.dtype();
274        dtype.is_primitive_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()
275    });
276
277    if !non_numeric_columns.is_empty() {
278        let col = non_numeric_columns.first().cloned();
279        polars_bail!(
280            InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",
281            col.unwrap().name(),
282            col.unwrap().dtype(),
283        );
284    }
285    let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
286    let num_rows = columns.len();
287    match num_rows {
288        0 => Ok(None),
289        1 => Ok(Some(match columns[0].dtype() {
290            dt if dt != &DataType::Float32 && !dt.is_decimal() => {
291                columns[0].cast(&DataType::Float64)?
292            },
293            _ => columns[0].clone(),
294        })),
295        _ => {
296            let sum = || sum_horizontal(columns.as_slice(), null_strategy);
297            let null_count = || {
298                columns
299                    .par_iter()
300                    .map(|c| {
301                        c.is_null()
302                            .into_column()
303                            .cast_with_options(&DataType::UInt32, CastOptions::NonStrict)
304                    })
305                    .reduce_with(|l, r| {
306                        let l = l?;
307                        let r = r?;
308                        let result = std::ops::Add::add(&l, &r)?;
309                        PolarsResult::Ok(result)
310                    })
311                    // we can unwrap the option, because we are certain there is a column
312                    // we started this operation on 2 columns
313                    .unwrap()
314            };
315
316            let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count));
317            let sum = sum?;
318            let null_count = null_count?;
319
320            // value lengths: len - null_count
321            let value_length: UInt32Chunked = (Column::new_scalar(
322                PlSmallStr::EMPTY,
323                Scalar::from(num_rows as u32),
324                null_count.len(),
325            ) - null_count)?
326                .u32()
327                .unwrap()
328                .clone();
329
330            // make sure that we do not divide by zero
331            // by replacing with None
332            let dt = if sum
333                .as_ref()
334                .is_some_and(|s| s.dtype() == &DataType::Float32)
335            {
336                &DataType::Float32
337            } else {
338                &DataType::Float64
339            };
340            let value_length = value_length
341                .set(&value_length.equal(0), None)?
342                .into_column()
343                .cast(dt)?;
344
345            sum.map(|sum| std::ops::Div::div(&sum, &value_length))
346                .transpose()
347        },
348    }
349}
350
351pub fn coalesce_columns(s: &[Column]) -> PolarsResult<Column> {
352    // TODO! this can be faster if we have more than two inputs.
353    polars_ensure!(!s.is_empty(), NoData: "cannot coalesce empty list");
354    let mut out = s[0].clone();
355    for s in s {
356        if !out.null_count() == 0 {
357            return Ok(out);
358        } else {
359            let mask = out.is_not_null();
360            out = out
361                .as_materialized_series()
362                .zip_with_same_type(&mask, s.as_materialized_series())?
363                .into();
364        }
365    }
366    Ok(out)
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises mean/sum/min/max on a three-column frame that contains nulls.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_horizontal_agg() {
        let df = DataFrame::new(vec![
            Column::new("a".into(), [1, 2, 6]),
            Column::new("b".into(), [Some(1), None, None]),
            Column::new("c".into(), [Some(4), None, Some(3)]),
        ])
        .unwrap();

        let mean = df.mean_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(
            Vec::from(mean.f64().unwrap()),
            &[Some(2.0), Some(2.0), Some(4.5)]
        );

        let sum = df.sum_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(Vec::from(sum.i32().unwrap()), &[Some(6), Some(2), Some(9)]);

        let min = df.min_horizontal().unwrap().unwrap();
        assert_eq!(Vec::from(min.i32().unwrap()), &[Some(1), Some(2), Some(3)]);

        let max = df.max_horizontal().unwrap().unwrap();
        assert_eq!(Vec::from(max.i32().unwrap()), &[Some(4), Some(2), Some(6)]);
    }
}