1use std::borrow::Cow;
2
3use polars_core::chunked_array::cast::CastOptions;
4use polars_core::prelude::*;
5use polars_core::series::arithmetic::coerce_lhs_rhs;
6use polars_core::utils::dtypes_to_supertype;
7use polars_core::{POOL, with_match_physical_numeric_polars_type};
8use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
9
10fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> {
11 let mut length = 1;
12 for c in cs {
13 let len = c.len();
14 if len != 1 && len != length {
15 if length == 1 {
16 length = len;
17 } else {
18 polars_bail!(ShapeMismatch: "cannot evaluate two Series of different lengths ({len} and {length})");
19 }
20 }
21 }
22 Ok(())
23}
24
/// Row-wise ("horizontal") minimum and maximum over a collection of columns.
pub trait MinMaxHorizontal {
    /// Computes the minimum across the columns.
    ///
    /// In the `DataFrame` implementation in this file the result is `None`
    /// when there are no columns.
    fn min_horizontal(&self) -> PolarsResult<Option<Column>>;
    /// Computes the maximum across the columns.
    ///
    /// In the `DataFrame` implementation in this file the result is `None`
    /// when there are no columns.
    fn max_horizontal(&self) -> PolarsResult<Option<Column>>;
}
31
32impl MinMaxHorizontal for DataFrame {
33 fn min_horizontal(&self) -> PolarsResult<Option<Column>> {
34 min_horizontal(self.get_columns())
35 }
36 fn max_horizontal(&self) -> PolarsResult<Option<Column>> {
37 max_horizontal(self.get_columns())
38 }
39}
40
/// How null values are treated by the horizontal sum and mean.
///
/// `Eq` is derived alongside `PartialEq`: the enum has no fields, so equality
/// is total and the type can be used wherever an `Eq` bound is required.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum NullStrategy {
    /// Nulls are skipped; `sum_horizontal` fills them with zero before adding.
    Ignore,
    /// Nulls are left untouched so they flow into the arithmetic; in
    /// `sum_horizontal` any `Null`-dtype column makes the whole result null.
    Propagate,
}
46
/// Row-wise ("horizontal") sum and mean over a collection of columns.
pub trait SumMeanHorizontal {
    /// Computes the sum across the columns, treating nulls according to
    /// `null_strategy`.
    fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;

    /// Computes the mean across the columns, treating nulls according to
    /// `null_strategy`.
    fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;
}
54
55impl SumMeanHorizontal for DataFrame {
56 fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
57 sum_horizontal(self.get_columns(), null_strategy)
58 }
59 fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
60 mean_horizontal(self.get_columns(), null_strategy)
61 }
62}
63
64fn min_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
65where
66 T: PolarsNumericType,
67 T::Native: PartialOrd,
68{
69 let op = |l: T::Native, r: T::Native| {
70 if l < r { l } else { r }
71 };
72 arity::binary_elementwise_values(left, right, op)
73}
74
75fn max_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
76where
77 T: PolarsNumericType,
78 T::Native: PartialOrd,
79{
80 let op = |l: T::Native, r: T::Native| {
81 if l > r { l } else { r }
82 };
83 arity::binary_elementwise_values(left, right, op)
84}
85
86fn min_max_binary_columns(left: &Column, right: &Column, min: bool) -> PolarsResult<Column> {
87 if left.dtype().to_physical().is_primitive_numeric()
88 && right.dtype().to_physical().is_primitive_numeric()
89 && left.null_count() == 0
90 && right.null_count() == 0
91 && left.len() == right.len()
92 {
93 match (left, right) {
94 (Column::Series(left), Column::Series(right)) => {
95 let (lhs, rhs) = coerce_lhs_rhs(left, right)?;
96 let logical = lhs.dtype();
97 let lhs = lhs.to_physical_repr();
98 let rhs = rhs.to_physical_repr();
99
100 with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| {
101 let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
102 let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();
103
104 unsafe {
105 if min {
106 min_binary(a, b).into_series().from_physical_unchecked(logical)
107 } else {
108 max_binary(a, b).into_series().from_physical_unchecked(logical)
109 }
110 }
111 })
112 .map(Column::from)
113 },
114 _ => {
115 let mask = if min {
116 left.lt(right)?
117 } else {
118 left.gt(right)?
119 };
120
121 left.zip_with(&mask, right)
122 },
123 }
124 } else {
125 let mask = if min {
126 left.lt(right)? & left.is_not_null() | right.is_null()
127 } else {
128 left.gt(right)? & left.is_not_null() | right.is_null()
129 };
130 left.zip_with(&mask, right)
131 }
132}
133
134pub fn max_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
135 validate_column_lengths(columns)?;
136
137 let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false);
138
139 match columns.len() {
140 0 => Ok(None),
141 1 => Ok(Some(columns[0].clone())),
142 2 => max_fn(&columns[0], &columns[1]).map(Some),
143 _ => {
144 POOL.install(|| {
147 columns
148 .par_iter()
149 .map(|s| Ok(Cow::Borrowed(s)))
150 .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned))
151 .unwrap()
154 .map(|cow| Some(cow.into_owned()))
155 })
156 },
157 }
158}
159
160pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
161 validate_column_lengths(columns)?;
162
163 let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true);
164
165 match columns.len() {
166 0 => Ok(None),
167 1 => Ok(Some(columns[0].clone())),
168 2 => min_fn(&columns[0], &columns[1]).map(Some),
169 _ => {
170 POOL.install(|| {
173 columns
174 .par_iter()
175 .map(|s| Ok(Cow::Borrowed(s)))
176 .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned))
177 .unwrap()
180 .map(|cow| Some(cow.into_owned()))
181 })
182 },
183 }
184}
185
186pub fn sum_horizontal(
187 columns: &[Column],
188 null_strategy: NullStrategy,
189) -> PolarsResult<Option<Column>> {
190 validate_column_lengths(columns)?;
191 let ignore_nulls = null_strategy == NullStrategy::Ignore;
192
193 let apply_null_strategy = |s: Series| -> PolarsResult<Series> {
194 if ignore_nulls && s.null_count() > 0 {
195 s.fill_null(FillNullStrategy::Zero)
196 } else {
197 Ok(s)
198 }
199 };
200
201 let sum_fn = |acc: Series, s: Series| -> PolarsResult<Series> {
202 let acc: Series = apply_null_strategy(acc)?;
203 let s = apply_null_strategy(s)?;
204 std::ops::Add::add(acc, s)
206 };
207
208 let non_null_cols = columns
210 .iter()
211 .filter(|x| x.dtype() != &DataType::Null)
212 .map(|c| c.as_materialized_series())
213 .collect::<Vec<_>>();
214
215 if !ignore_nulls && non_null_cols.len() < columns.len() {
217 let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
219 DataType::Boolean => IDX_DTYPE,
220 dt => dt,
221 };
222 return Ok(Some(Column::full_null(
223 columns[0].name().clone(),
224 columns[0].len(),
225 &return_dtype,
226 )));
227 }
228
229 match non_null_cols.len() {
230 0 => {
231 if columns.is_empty() {
232 Ok(None)
233 } else {
234 Ok(Some(columns[0].clone()))
236 }
237 },
238 1 => Ok(Some(
239 apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
240 non_null_cols[0].cast(&IDX_DTYPE)?
241 } else {
242 non_null_cols[0].clone()
243 })?
244 .into(),
245 )),
246 2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())
247 .map(Column::from)
248 .map(Some),
249 _ => {
250 let out = POOL.install(|| {
253 non_null_cols
254 .into_par_iter()
255 .cloned()
256 .map(Ok)
257 .try_reduce_with(sum_fn)
258 .unwrap()
260 });
261 out.map(Column::from).map(Some)
262 },
263 }
264}
265
266pub fn mean_horizontal(
267 columns: &[Column],
268 null_strategy: NullStrategy,
269) -> PolarsResult<Option<Column>> {
270 validate_column_lengths(columns)?;
271
272 let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| {
273 let dtype = s.dtype();
274 dtype.is_primitive_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()
275 });
276
277 if !non_numeric_columns.is_empty() {
278 let col = non_numeric_columns.first().cloned();
279 polars_bail!(
280 InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",
281 col.unwrap().name(),
282 col.unwrap().dtype(),
283 );
284 }
285 let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
286 let num_rows = columns.len();
287 match num_rows {
288 0 => Ok(None),
289 1 => Ok(Some(match columns[0].dtype() {
290 dt if dt != &DataType::Float32 && !dt.is_decimal() => {
291 columns[0].cast(&DataType::Float64)?
292 },
293 _ => columns[0].clone(),
294 })),
295 _ => {
296 let sum = || sum_horizontal(columns.as_slice(), null_strategy);
297 let null_count = || {
298 columns
299 .par_iter()
300 .map(|c| {
301 c.is_null()
302 .into_column()
303 .cast_with_options(&DataType::UInt32, CastOptions::NonStrict)
304 })
305 .reduce_with(|l, r| {
306 let l = l?;
307 let r = r?;
308 let result = std::ops::Add::add(&l, &r)?;
309 PolarsResult::Ok(result)
310 })
311 .unwrap()
314 };
315
316 let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count));
317 let sum = sum?;
318 let null_count = null_count?;
319
320 let value_length: UInt32Chunked = (Column::new_scalar(
322 PlSmallStr::EMPTY,
323 Scalar::from(num_rows as u32),
324 null_count.len(),
325 ) - null_count)?
326 .u32()
327 .unwrap()
328 .clone();
329
330 let dt = if sum
333 .as_ref()
334 .is_some_and(|s| s.dtype() == &DataType::Float32)
335 {
336 &DataType::Float32
337 } else {
338 &DataType::Float64
339 };
340 let value_length = value_length
341 .set(&value_length.equal(0), None)?
342 .into_column()
343 .cast(dt)?;
344
345 sum.map(|sum| std::ops::Div::div(&sum, &value_length))
346 .transpose()
347 },
348 }
349}
350
351pub fn coalesce_columns(s: &[Column]) -> PolarsResult<Column> {
352 polars_ensure!(!s.is_empty(), NoData: "cannot coalesce empty list");
354 let mut out = s[0].clone();
355 for s in s {
356 if !out.null_count() == 0 {
357 return Ok(out);
358 } else {
359 let mask = out.is_not_null();
360 out = out
361 .as_materialized_series()
362 .zip_with_same_type(&mask, s.as_materialized_series())?
363 .into();
364 }
365 }
366 Ok(out)
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises mean, sum, min and max on a three-column frame containing nulls.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_horizontal_agg() {
        let df = DataFrame::new(vec![
            Column::new("a".into(), [1, 2, 6]),
            Column::new("b".into(), [Some(1), None, None]),
            Column::new("c".into(), [Some(4), None, Some(3)]),
        ])
        .unwrap();

        let mean = df.mean_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(
            Vec::from(mean.f64().unwrap()),
            &[Some(2.0), Some(2.0), Some(4.5)]
        );

        let sum = df.sum_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(Vec::from(sum.i32().unwrap()), &[Some(6), Some(2), Some(9)]);

        let min = df.min_horizontal().unwrap().unwrap();
        assert_eq!(Vec::from(min.i32().unwrap()), &[Some(1), Some(2), Some(3)]);

        let max = df.max_horizontal().unwrap().unwrap();
        assert_eq!(Vec::from(max.i32().unwrap()), &[Some(4), Some(2), Some(6)]);
    }
}