Skip to main content

polars_ops/series/ops/
horizontal.rs

1use std::borrow::Cow;
2
3use polars_core::chunked_array::cast::CastOptions;
4use polars_core::prelude::*;
5use polars_core::runtime::RAYON;
6use polars_core::series::arithmetic::coerce_lhs_rhs;
7use polars_core::utils::dtypes_to_supertype;
8use polars_core::with_match_physical_numeric_polars_type;
9use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
10
11fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> {
12    let mut length = 1;
13    for c in cs {
14        let len = c.len();
15        if len != 1 && len != length {
16            if length == 1 {
17                length = len;
18            } else {
19                polars_bail!(ShapeMismatch: "cannot evaluate two Series of different lengths ({len} and {length})");
20            }
21        }
22    }
23    Ok(())
24}
25
/// Row-wise minimum / maximum aggregation over a collection of columns.
///
/// The `DataFrame` implementation in this file forwards to the free functions
/// [`min_horizontal`] and [`max_horizontal`].
pub trait MinMaxHorizontal {
    /// Aggregate the columns horizontally to their min values.
    ///
    /// The result is an `Option` because the aggregation may have no column
    /// to return (the free function yields `None` for an empty input).
    fn min_horizontal(&self) -> PolarsResult<Option<Column>>;
    /// Aggregate the columns horizontally to their max values.
    ///
    /// The result is an `Option` because the aggregation may have no column
    /// to return (the free function yields `None` for an empty input).
    fn max_horizontal(&self) -> PolarsResult<Option<Column>>;
}
32
33impl MinMaxHorizontal for DataFrame {
34    fn min_horizontal(&self) -> PolarsResult<Option<Column>> {
35        min_horizontal(self.columns())
36    }
37    fn max_horizontal(&self) -> PolarsResult<Option<Column>> {
38        max_horizontal(self.columns())
39    }
40}
41
/// Selects how null values are treated by the horizontal sum and mean
/// aggregations.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum NullStrategy {
    /// Nulls do not contribute: `sum_horizontal` replaces them with zero
    /// before adding.
    Ignore,
    /// Nulls are left in place and take part in the arithmetic unchanged
    /// (presumably making the affected row's result null).
    Propagate,
}
47
/// Row-wise sum / mean aggregation over a collection of columns.
///
/// The `DataFrame` implementation in this file forwards to the free functions
/// [`sum_horizontal`] and [`mean_horizontal`].
pub trait SumMeanHorizontal {
    /// Sum all values horizontally across columns.
    ///
    /// `null_strategy` selects whether nulls are ignored or propagated.
    fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;

    /// Compute the mean of all numeric values horizontally across columns.
    ///
    /// `null_strategy` selects whether nulls are ignored or propagated.
    fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;
}
55
56impl SumMeanHorizontal for DataFrame {
57    fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
58        sum_horizontal(self.columns(), null_strategy)
59    }
60    fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
61        mean_horizontal(self.columns(), null_strategy)
62    }
63}
64
65fn min_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
66where
67    T: PolarsNumericType,
68    T::Native: PartialOrd,
69{
70    let op = |l: T::Native, r: T::Native| {
71        if l < r { l } else { r }
72    };
73    arity::binary_elementwise_values(left, right, op)
74}
75
76fn max_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
77where
78    T: PolarsNumericType,
79    T::Native: PartialOrd,
80{
81    let op = |l: T::Native, r: T::Native| {
82        if l > r { l } else { r }
83    };
84    arity::binary_elementwise_values(left, right, op)
85}
86
87fn min_max_binary_columns(left: &Column, right: &Column, min: bool) -> PolarsResult<Column> {
88    if left.dtype().to_physical().is_primitive_numeric()
89        && right.dtype().to_physical().is_primitive_numeric()
90        && left.null_count() == 0
91        && right.null_count() == 0
92        && left.len() == right.len()
93    {
94        match (left, right) {
95            (Column::Series(left), Column::Series(right)) => {
96                let (lhs, rhs) = coerce_lhs_rhs(left, right)?;
97                let logical = lhs.dtype();
98                let lhs = lhs.to_physical_repr();
99                let rhs = rhs.to_physical_repr();
100
101                with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| {
102                    let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
103                    let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();
104
105                    unsafe {
106                        if min {
107                            min_binary(a, b).into_series().from_physical_unchecked(logical)
108                        } else {
109                            max_binary(a, b).into_series().from_physical_unchecked(logical)
110                        }
111                    }
112                })
113                .map(Column::from)
114            },
115            _ => {
116                let mask = if min {
117                    left.lt(right)?
118                } else {
119                    left.gt(right)?
120                };
121
122                left.zip_with(&mask, right)
123            },
124        }
125    } else {
126        let mask = if min {
127            left.lt(right)? & left.is_not_null() | right.is_null()
128        } else {
129            left.gt(right)? & left.is_not_null() | right.is_null()
130        };
131        left.zip_with(&mask, right)
132    }
133}
134
135pub fn max_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
136    validate_column_lengths(columns)?;
137
138    let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false);
139
140    match columns.len() {
141        0 => Ok(None),
142        1 => Ok(Some(columns[0].clone())),
143        2 => max_fn(&columns[0], &columns[1]).map(Some),
144        _ => {
145            // the try_reduce_with is a bit slower in parallelism,
146            // but I don't think it matters here as we parallelize over columns, not over elements
147            RAYON.install(|| {
148                columns
149                    .par_iter()
150                    .map(|s| Ok(Cow::Borrowed(s)))
151                    .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned))
152                    // we can unwrap the option, because we are certain there is a column
153                    // we started this operation on 3 columns
154                    .unwrap()
155                    .map(|cow| Some(cow.into_owned()))
156            })
157        },
158    }
159}
160
161pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
162    validate_column_lengths(columns)?;
163
164    let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true);
165
166    match columns.len() {
167        0 => Ok(None),
168        1 => Ok(Some(columns[0].clone())),
169        2 => min_fn(&columns[0], &columns[1]).map(Some),
170        _ => {
171            // the try_reduce_with is a bit slower in parallelism,
172            // but I don't think it matters here as we parallelize over columns, not over elements
173            RAYON.install(|| {
174                columns
175                    .par_iter()
176                    .map(|s| Ok(Cow::Borrowed(s)))
177                    .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned))
178                    // we can unwrap the option, because we are certain there is a column
179                    // we started this operation on 3 columns
180                    .unwrap()
181                    .map(|cow| Some(cow.into_owned()))
182            })
183        },
184    }
185}
186
187pub fn sum_horizontal(
188    columns: &[Column],
189    null_strategy: NullStrategy,
190) -> PolarsResult<Option<Column>> {
191    validate_column_lengths(columns)?;
192    let ignore_nulls = null_strategy == NullStrategy::Ignore;
193
194    let apply_null_strategy = |s: Series| -> PolarsResult<Series> {
195        if ignore_nulls && s.null_count() > 0 {
196            s.fill_null(FillNullStrategy::Zero)
197        } else {
198            Ok(s)
199        }
200    };
201
202    let sum_fn = |acc: Series, s: Series| -> PolarsResult<Series> {
203        let acc: Series = apply_null_strategy(acc)?;
204        let s = apply_null_strategy(s)?;
205        // This will do owned arithmetic and can be mutable
206        std::ops::Add::add(acc, s)
207    };
208
209    // @scalar-opt
210    let non_null_cols = columns
211        .iter()
212        .filter(|x| x.dtype() != &DataType::Null)
213        .map(|c| c.as_materialized_series())
214        .collect::<Vec<_>>();
215
216    // If we have any null columns and null strategy is not `Ignore`, we can return immediately.
217    if !ignore_nulls && non_null_cols.len() < columns.len() {
218        // We must determine the correct return dtype.
219        let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
220            DataType::Boolean => IDX_DTYPE,
221            dt => dt,
222        };
223        return Ok(Some(Column::full_null(
224            columns[0].name().clone(),
225            columns[0].len(),
226            &return_dtype,
227        )));
228    }
229
230    match non_null_cols.len() {
231        0 => {
232            if columns.is_empty() {
233                Ok(None)
234            } else {
235                // all columns are null dtype, so result is null dtype
236                Ok(Some(columns[0].clone()))
237            }
238        },
239        1 => Ok(Some(
240            apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
241                non_null_cols[0].cast(&IDX_DTYPE)?
242            } else {
243                non_null_cols[0].clone()
244            })?
245            .into(),
246        )),
247        2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())
248            .map(Column::from)
249            .map(Some),
250        _ => {
251            // the try_reduce_with is a bit slower in parallelism,
252            // but I don't think it matters here as we parallelize over columns, not over elements
253            let out = RAYON.install(|| {
254                non_null_cols
255                    .into_par_iter()
256                    .cloned()
257                    .map(Ok)
258                    .try_reduce_with(sum_fn)
259                    // We can unwrap because we started with at least 3 columns, so we always get a Some
260                    .unwrap()
261            });
262            out.map(Column::from).map(Some)
263        },
264    }
265}
266
267pub fn mean_horizontal(
268    columns: &[Column],
269    null_strategy: NullStrategy,
270) -> PolarsResult<Option<Column>> {
271    validate_column_lengths(columns)?;
272
273    let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| {
274        let dtype = s.dtype();
275        dtype.is_primitive_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()
276    });
277
278    if !non_numeric_columns.is_empty() {
279        let col = non_numeric_columns.first().cloned();
280        polars_bail!(
281            InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",
282            col.unwrap().name(),
283            col.unwrap().dtype(),
284        );
285    }
286    let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
287    let num_rows = columns.len();
288    match num_rows {
289        0 => Ok(None),
290        1 => Ok(Some(match columns[0].dtype() {
291            dt if !matches!(dt, DataType::Float16 | DataType::Float32) && !dt.is_decimal() => {
292                columns[0].cast(&DataType::Float64)?
293            },
294            _ => columns[0].clone(),
295        })),
296        _ => {
297            let sum = || sum_horizontal(columns.as_slice(), null_strategy);
298            let null_count = || {
299                columns
300                    .par_iter()
301                    .map(|c| {
302                        c.is_null()
303                            .into_column()
304                            .cast_with_options(&DataType::UInt32, CastOptions::NonStrict)
305                    })
306                    .reduce_with(|l, r| {
307                        let l = l?;
308                        let r = r?;
309                        let result = std::ops::Add::add(&l, &r)?;
310                        PolarsResult::Ok(result)
311                    })
312                    // we can unwrap the option, because we are certain there is a column
313                    // we started this operation on 2 columns
314                    .unwrap()
315            };
316
317            let (sum, null_count) = RAYON.install(|| rayon::join(sum, null_count));
318            let sum = sum?;
319            let null_count = null_count?;
320
321            // value lengths: len - null_count
322            let value_length: UInt32Chunked = (Column::new_scalar(
323                PlSmallStr::EMPTY,
324                Scalar::from(num_rows as u32),
325                null_count.len(),
326            ) - null_count)?
327                .u32()
328                .unwrap()
329                .clone();
330
331            // make sure that we do not divide by zero
332            // by replacing with None
333            let dt = sum
334                .as_ref()
335                .map(Column::dtype)
336                .filter(|dt| matches!(dt, DataType::Float16 | DataType::Float32))
337                .unwrap_or(&DataType::Float64);
338            let value_length = value_length
339                .set(&value_length.equal(0), None)?
340                .into_column()
341                .cast(dt)?;
342
343            sum.map(|sum| std::ops::Div::div(&sum, &value_length))
344                .transpose()
345        },
346    }
347}
348
349pub fn coalesce_columns(s: &[Column]) -> PolarsResult<Column> {
350    // TODO! this can be faster if we have more than two inputs.
351    polars_ensure!(!s.is_empty(), NoData: "cannot coalesce empty list");
352    let mut out = s[0].clone();
353    for s in s {
354        if !out.null_count() == 0 {
355            return Ok(out);
356        } else {
357            let mask = out.is_not_null();
358            out = out
359                .as_materialized_series()
360                .zip_with_same_type(&mask, s.as_materialized_series())?
361                .into();
362        }
363    }
364    Ok(out)
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks mean, sum, min and max over three `i32` columns, two of which
    /// contain nulls.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_horizontal_agg() {
        let a = Column::new("a".into(), [1, 2, 6]);
        let b = Column::new("b".into(), [Some(1), None, None]);
        let c = Column::new("c".into(), [Some(4), None, Some(3)]);

        let df = DataFrame::new_infer_height(vec![a, b, c]).unwrap();

        // Nulls are ignored, so each row divides by its own non-null count.
        let mean = df.mean_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(
            Vec::from(mean.f64().unwrap()),
            &[Some(2.0), Some(2.0), Some(4.5)]
        );

        let sum = df.sum_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(
            Vec::from(sum.i32().unwrap()),
            &[Some(6), Some(2), Some(9)]
        );

        let min = df.min_horizontal().unwrap().unwrap();
        assert_eq!(
            Vec::from(min.i32().unwrap()),
            &[Some(1), Some(2), Some(3)]
        );

        let max = df.max_horizontal().unwrap().unwrap();
        assert_eq!(
            Vec::from(max.i32().unwrap()),
            &[Some(4), Some(2), Some(6)]
        );
    }
}