polars_ops/series/ops/
horizontal.rs

1use std::borrow::Cow;
2
3use polars_core::chunked_array::cast::CastOptions;
4use polars_core::prelude::*;
5use polars_core::series::arithmetic::coerce_lhs_rhs;
6use polars_core::utils::dtypes_to_supertype;
7use polars_core::{POOL, with_match_physical_numeric_polars_type};
8use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
9
10fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> {
11    let mut length = 1;
12    for c in cs {
13        let len = c.len();
14        if len != 1 && len != length {
15            if length == 1 {
16                length = len;
17            } else {
18                polars_bail!(ShapeMismatch: "cannot evaluate two Series of different lengths ({len} and {length})");
19            }
20        }
21    }
22    Ok(())
23}
24
25pub trait MinMaxHorizontal {
26    /// Aggregate the column horizontally to their min values.
27    fn min_horizontal(&self) -> PolarsResult<Option<Column>>;
28    /// Aggregate the column horizontally to their max values.
29    fn max_horizontal(&self) -> PolarsResult<Option<Column>>;
30}
31
32impl MinMaxHorizontal for DataFrame {
33    fn min_horizontal(&self) -> PolarsResult<Option<Column>> {
34        min_horizontal(self.get_columns())
35    }
36    fn max_horizontal(&self) -> PolarsResult<Option<Column>> {
37        max_horizontal(self.get_columns())
38    }
39}
40
/// How horizontal aggregations in this module treat null values.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum NullStrategy {
    /// Nulls are skipped; `sum_horizontal` replaces them with zero before adding.
    Ignore,
    /// Nulls are left untouched so that they carry through the arithmetic.
    Propagate,
}
46
47pub trait SumMeanHorizontal {
48    /// Sum all values horizontally across columns.
49    fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;
50
51    /// Compute the mean of all numeric values horizontally across columns.
52    fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;
53}
54
55impl SumMeanHorizontal for DataFrame {
56    fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
57        sum_horizontal(self.get_columns(), null_strategy)
58    }
59    fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
60        mean_horizontal(self.get_columns(), null_strategy)
61    }
62}
63
64fn min_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
65where
66    T: PolarsNumericType,
67    T::Native: PartialOrd,
68{
69    let op = |l: T::Native, r: T::Native| {
70        if l < r { l } else { r }
71    };
72    arity::binary_elementwise_values(left, right, op)
73}
74
75fn max_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
76where
77    T: PolarsNumericType,
78    T::Native: PartialOrd,
79{
80    let op = |l: T::Native, r: T::Native| {
81        if l > r { l } else { r }
82    };
83    arity::binary_elementwise_values(left, right, op)
84}
85
86fn min_max_binary_columns(left: &Column, right: &Column, min: bool) -> PolarsResult<Column> {
87    if left.dtype().to_physical().is_primitive_numeric()
88        && left.null_count() == 0
89        && right.null_count() == 0
90        && left.len() == right.len()
91    {
92        match (left, right) {
93            (Column::Series(left), Column::Series(right)) => {
94                let (lhs, rhs) = coerce_lhs_rhs(left, right)?;
95                let logical = lhs.dtype();
96                let lhs = lhs.to_physical_repr();
97                let rhs = rhs.to_physical_repr();
98
99                with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| {
100                    let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
101                    let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();
102
103                    unsafe {
104                        if min {
105                            min_binary(a, b).into_series().from_physical_unchecked(logical)
106                        } else {
107                            max_binary(a, b).into_series().from_physical_unchecked(logical)
108                        }
109                    }
110                })
111                .map(Column::from)
112            },
113            _ => {
114                let mask = if min {
115                    left.lt(right)?
116                } else {
117                    left.gt(right)?
118                };
119
120                left.zip_with(&mask, right)
121            },
122        }
123    } else {
124        let mask = if min {
125            left.lt(right)? & left.is_not_null() | right.is_null()
126        } else {
127            left.gt(right)? & left.is_not_null() | right.is_null()
128        };
129        left.zip_with(&mask, right)
130    }
131}
132
133pub fn max_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
134    validate_column_lengths(columns)?;
135
136    let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false);
137
138    match columns.len() {
139        0 => Ok(None),
140        1 => Ok(Some(columns[0].clone())),
141        2 => max_fn(&columns[0], &columns[1]).map(Some),
142        _ => {
143            // the try_reduce_with is a bit slower in parallelism,
144            // but I don't think it matters here as we parallelize over columns, not over elements
145            POOL.install(|| {
146                columns
147                    .par_iter()
148                    .map(|s| Ok(Cow::Borrowed(s)))
149                    .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned))
150                    // we can unwrap the option, because we are certain there is a column
151                    // we started this operation on 3 columns
152                    .unwrap()
153                    .map(|cow| Some(cow.into_owned()))
154            })
155        },
156    }
157}
158
159pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
160    validate_column_lengths(columns)?;
161
162    let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true);
163
164    match columns.len() {
165        0 => Ok(None),
166        1 => Ok(Some(columns[0].clone())),
167        2 => min_fn(&columns[0], &columns[1]).map(Some),
168        _ => {
169            // the try_reduce_with is a bit slower in parallelism,
170            // but I don't think it matters here as we parallelize over columns, not over elements
171            POOL.install(|| {
172                columns
173                    .par_iter()
174                    .map(|s| Ok(Cow::Borrowed(s)))
175                    .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned))
176                    // we can unwrap the option, because we are certain there is a column
177                    // we started this operation on 3 columns
178                    .unwrap()
179                    .map(|cow| Some(cow.into_owned()))
180            })
181        },
182    }
183}
184
185pub fn sum_horizontal(
186    columns: &[Column],
187    null_strategy: NullStrategy,
188) -> PolarsResult<Option<Column>> {
189    validate_column_lengths(columns)?;
190    let ignore_nulls = null_strategy == NullStrategy::Ignore;
191
192    let apply_null_strategy = |s: Series| -> PolarsResult<Series> {
193        if ignore_nulls && s.null_count() > 0 {
194            s.fill_null(FillNullStrategy::Zero)
195        } else {
196            Ok(s)
197        }
198    };
199
200    let sum_fn = |acc: Series, s: Series| -> PolarsResult<Series> {
201        let acc: Series = apply_null_strategy(acc)?;
202        let s = apply_null_strategy(s)?;
203        // This will do owned arithmetic and can be mutable
204        std::ops::Add::add(acc, s)
205    };
206
207    // @scalar-opt
208    let non_null_cols = columns
209        .iter()
210        .filter(|x| x.dtype() != &DataType::Null)
211        .map(|c| c.as_materialized_series())
212        .collect::<Vec<_>>();
213
214    // If we have any null columns and null strategy is not `Ignore`, we can return immediately.
215    if !ignore_nulls && non_null_cols.len() < columns.len() {
216        // We must determine the correct return dtype.
217        let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
218            DataType::Boolean => IDX_DTYPE,
219            dt => dt,
220        };
221        return Ok(Some(Column::full_null(
222            columns[0].name().clone(),
223            columns[0].len(),
224            &return_dtype,
225        )));
226    }
227
228    match non_null_cols.len() {
229        0 => {
230            if columns.is_empty() {
231                Ok(None)
232            } else {
233                // all columns are null dtype, so result is null dtype
234                Ok(Some(columns[0].clone()))
235            }
236        },
237        1 => Ok(Some(
238            apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
239                non_null_cols[0].cast(&IDX_DTYPE)?
240            } else {
241                non_null_cols[0].clone()
242            })?
243            .into(),
244        )),
245        2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())
246            .map(Column::from)
247            .map(Some),
248        _ => {
249            // the try_reduce_with is a bit slower in parallelism,
250            // but I don't think it matters here as we parallelize over columns, not over elements
251            let out = POOL.install(|| {
252                non_null_cols
253                    .into_par_iter()
254                    .cloned()
255                    .map(Ok)
256                    .try_reduce_with(sum_fn)
257                    // We can unwrap because we started with at least 3 columns, so we always get a Some
258                    .unwrap()
259            });
260            out.map(Column::from).map(Some)
261        },
262    }
263}
264
265pub fn mean_horizontal(
266    columns: &[Column],
267    null_strategy: NullStrategy,
268) -> PolarsResult<Option<Column>> {
269    validate_column_lengths(columns)?;
270
271    let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| {
272        let dtype = s.dtype();
273        dtype.is_primitive_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()
274    });
275
276    if !non_numeric_columns.is_empty() {
277        let col = non_numeric_columns.first().cloned();
278        polars_bail!(
279            InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",
280            col.unwrap().name(),
281            col.unwrap().dtype(),
282        );
283    }
284    let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
285    let num_rows = columns.len();
286    match num_rows {
287        0 => Ok(None),
288        1 => Ok(Some(match columns[0].dtype() {
289            dt if dt != &DataType::Float32 && !dt.is_decimal() => {
290                columns[0].cast(&DataType::Float64)?
291            },
292            _ => columns[0].clone(),
293        })),
294        _ => {
295            let sum = || sum_horizontal(columns.as_slice(), null_strategy);
296            let null_count = || {
297                columns
298                    .par_iter()
299                    .map(|c| {
300                        c.is_null()
301                            .into_column()
302                            .cast_with_options(&DataType::UInt32, CastOptions::NonStrict)
303                    })
304                    .reduce_with(|l, r| {
305                        let l = l?;
306                        let r = r?;
307                        let result = std::ops::Add::add(&l, &r)?;
308                        PolarsResult::Ok(result)
309                    })
310                    // we can unwrap the option, because we are certain there is a column
311                    // we started this operation on 2 columns
312                    .unwrap()
313            };
314
315            let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count));
316            let sum = sum?;
317            let null_count = null_count?;
318
319            // value lengths: len - null_count
320            let value_length: UInt32Chunked = (Column::new_scalar(
321                PlSmallStr::EMPTY,
322                Scalar::from(num_rows as u32),
323                null_count.len(),
324            ) - null_count)?
325                .u32()
326                .unwrap()
327                .clone();
328
329            // make sure that we do not divide by zero
330            // by replacing with None
331            let dt = if sum
332                .as_ref()
333                .is_some_and(|s| s.dtype() == &DataType::Float32)
334            {
335                &DataType::Float32
336            } else {
337                &DataType::Float64
338            };
339            let value_length = value_length
340                .set(&value_length.equal(0), None)?
341                .into_column()
342                .cast(dt)?;
343
344            sum.map(|sum| std::ops::Div::div(&sum, &value_length))
345                .transpose()
346        },
347    }
348}
349
350pub fn coalesce_columns(s: &[Column]) -> PolarsResult<Column> {
351    // TODO! this can be faster if we have more than two inputs.
352    polars_ensure!(!s.is_empty(), NoData: "cannot coalesce empty list");
353    let mut out = s[0].clone();
354    for s in s {
355        if !out.null_count() == 0 {
356            return Ok(out);
357        } else {
358            let mask = out.is_not_null();
359            out = out
360                .as_materialized_series()
361                .zip_with_same_type(&mask, s.as_materialized_series())?
362                .into();
363        }
364    }
365    Ok(out)
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises mean, sum, min and max horizontally on three integer columns, two of which
    /// contain nulls, using `NullStrategy::Ignore` for mean and sum.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_horizontal_agg() {
        let df = DataFrame::new(vec![
            Column::new("a".into(), [1, 2, 6]),
            Column::new("b".into(), [Some(1), None, None]),
            Column::new("c".into(), [Some(4), None, Some(3)]),
        ])
        .unwrap();

        let mean = df.mean_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(
            Vec::from(mean.f64().unwrap()),
            &[Some(2.0), Some(2.0), Some(4.5)]
        );

        let sum = df.sum_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(
            Vec::from(sum.i32().unwrap()),
            &[Some(6), Some(2), Some(9)]
        );

        let min = df.min_horizontal().unwrap().unwrap();
        assert_eq!(
            Vec::from(min.i32().unwrap()),
            &[Some(1), Some(2), Some(3)]
        );

        let max = df.max_horizontal().unwrap().unwrap();
        assert_eq!(
            Vec::from(max.i32().unwrap()),
            &[Some(4), Some(2), Some(6)]
        );
    }
}