1use std::borrow::Cow;
2
3use polars_core::chunked_array::cast::CastOptions;
4use polars_core::prelude::*;
5use polars_core::series::arithmetic::coerce_lhs_rhs;
6use polars_core::utils::dtypes_to_supertype;
7use polars_core::{POOL, with_match_physical_numeric_polars_type};
8use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
9
10fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> {
11 let mut length = 1;
12 for c in cs {
13 let len = c.len();
14 if len != 1 && len != length {
15 if length == 1 {
16 length = len;
17 } else {
18 polars_bail!(ShapeMismatch: "cannot evaluate two Series of different lengths ({len} and {length})");
19 }
20 }
21 }
22 Ok(())
23}
24
/// Row-wise (horizontal) minimum and maximum across the columns of a container.
pub trait MinMaxHorizontal {
    /// Element-wise minimum across all columns.
    ///
    /// `Ok(None)` signals there was nothing to reduce; the `DataFrame` implementation in
    /// this module returns it for a frame without columns.
    fn min_horizontal(&self) -> PolarsResult<Option<Column>>;
    /// Element-wise maximum across all columns.
    ///
    /// `Ok(None)` signals there was nothing to reduce; the `DataFrame` implementation in
    /// this module returns it for a frame without columns.
    fn max_horizontal(&self) -> PolarsResult<Option<Column>>;
}
31
32impl MinMaxHorizontal for DataFrame {
33 fn min_horizontal(&self) -> PolarsResult<Option<Column>> {
34 min_horizontal(self.get_columns())
35 }
36 fn max_horizontal(&self) -> PolarsResult<Option<Column>> {
37 max_horizontal(self.get_columns())
38 }
39}
40
/// How horizontal reductions treat null values.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the strategy can be used as a key
/// in hash maps/sets and compared with full equality semantics.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum NullStrategy {
    /// Skip nulls: `sum_horizontal` fills them with zero before adding and leaves
    /// columns of dtype `Null` out of the sum.
    Ignore,
    /// Keep nulls: no filling is done, and `sum_horizontal` returns an all-null column
    /// as soon as any input column has dtype `Null`.
    Propagate,
}
46
/// Row-wise (horizontal) sum and mean across the columns of a container.
pub trait SumMeanHorizontal {
    /// Element-wise sum across all columns, treating nulls according to `null_strategy`.
    ///
    /// `Ok(None)` signals there was nothing to reduce; the `DataFrame` implementation in
    /// this module returns it for a frame without columns.
    fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;

    /// Element-wise mean across all columns, treating nulls according to `null_strategy`.
    ///
    /// `Ok(None)` signals there was nothing to reduce; the `DataFrame` implementation in
    /// this module returns it for a frame without columns.
    fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;
}
54
55impl SumMeanHorizontal for DataFrame {
56 fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
57 sum_horizontal(self.get_columns(), null_strategy)
58 }
59 fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
60 mean_horizontal(self.get_columns(), null_strategy)
61 }
62}
63
64fn min_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
65where
66 T: PolarsNumericType,
67 T::Native: PartialOrd,
68{
69 let op = |l: T::Native, r: T::Native| {
70 if l < r { l } else { r }
71 };
72 arity::binary_elementwise_values(left, right, op)
73}
74
75fn max_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
76where
77 T: PolarsNumericType,
78 T::Native: PartialOrd,
79{
80 let op = |l: T::Native, r: T::Native| {
81 if l > r { l } else { r }
82 };
83 arity::binary_elementwise_values(left, right, op)
84}
85
/// Element-wise minimum (`min == true`) or maximum (`min == false`) of two columns.
///
/// Fast path: taken when `left`'s physical dtype is a primitive numeric type, neither side
/// has nulls, and both have the same length. Two `Column::Series` are coerced to a common
/// dtype and reduced with `min_binary`/`max_binary` on their physical arrays. Any other
/// column representation builds a comparison mask and selects with `zip_with`.
///
/// Slow path (nulls present, unequal lengths, or a non-numeric dtype): the mask is
/// `(cmp & left.is_not_null()) | right.is_null()`, because `&` binds tighter than `|`.
/// Positions where the mask is set take `left`; all others take `right`.
/// NOTE(review): how `lt`/`gt`, `&`/`|` and `zip_with` treat null mask entries is defined
/// by polars, not visible here.
fn min_max_binary_columns(left: &Column, right: &Column, min: bool) -> PolarsResult<Column> {
    if left.dtype().to_physical().is_primitive_numeric()
        && left.null_count() == 0
        && right.null_count() == 0
        && left.len() == right.len()
    {
        match (left, right) {
            (Column::Series(left), Column::Series(right)) => {
                // Bring both sides to a common dtype, compute on the physical representation,
                // then restore the logical dtype (`logical`) on the result.
                let (lhs, rhs) = coerce_lhs_rhs(left, right)?;
                let logical = lhs.dtype();
                let lhs = lhs.to_physical_repr();
                let rhs = rhs.to_physical_repr();

                with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| {
                    let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
                    let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();

                    // SAFETY: `min_binary`/`max_binary` only ever return a value taken from
                    // `a` or `b`. Both are physical representations of series whose logical
                    // dtype is `logical`, so reinterpreting the result as `logical` is valid.
                    unsafe {
                        if min {
                            min_binary(a, b).into_series().from_physical_unchecked(logical)
                        } else {
                            max_binary(a, b).into_series().from_physical_unchecked(logical)
                        }
                    }
                })
                .map(Column::from)
            },
            _ => {
                // Non-`Series` representations (e.g. scalar columns): select via a mask.
                let mask = if min {
                    left.lt(right)?
                } else {
                    left.gt(right)?
                };

                left.zip_with(&mask, right)
            },
        }
    } else {
        // Take `left` where it wins the comparison and is non-null, or where `right` is null;
        // otherwise take `right`.
        let mask = if min {
            left.lt(right)? & left.is_not_null() | right.is_null()
        } else {
            left.gt(right)? & left.is_not_null() | right.is_null()
        };
        left.zip_with(&mask, right)
    }
}
132
133pub fn max_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
134 validate_column_lengths(columns)?;
135
136 let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false);
137
138 match columns.len() {
139 0 => Ok(None),
140 1 => Ok(Some(columns[0].clone())),
141 2 => max_fn(&columns[0], &columns[1]).map(Some),
142 _ => {
143 POOL.install(|| {
146 columns
147 .par_iter()
148 .map(|s| Ok(Cow::Borrowed(s)))
149 .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned))
150 .unwrap()
153 .map(|cow| Some(cow.into_owned()))
154 })
155 },
156 }
157}
158
159pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
160 validate_column_lengths(columns)?;
161
162 let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true);
163
164 match columns.len() {
165 0 => Ok(None),
166 1 => Ok(Some(columns[0].clone())),
167 2 => min_fn(&columns[0], &columns[1]).map(Some),
168 _ => {
169 POOL.install(|| {
172 columns
173 .par_iter()
174 .map(|s| Ok(Cow::Borrowed(s)))
175 .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned))
176 .unwrap()
179 .map(|cow| Some(cow.into_owned()))
180 })
181 },
182 }
183}
184
/// Row-wise (horizontal) sum of `columns`.
///
/// - Returns `Ok(None)` when `columns` is empty.
/// - Columns of dtype `Null` are excluded from the sum. Under `NullStrategy::Propagate`,
///   the presence of any such column instead short-circuits to an all-null column. That
///   column is named after the first input and has the supertype of the remaining inputs
///   (Boolean mapped to `IDX_DTYPE`).
/// - A single remaining Boolean column is cast to `IDX_DTYPE`.
/// - Under `NullStrategy::Ignore`, nulls are filled with zero before each addition.
///
/// # Errors
/// Fails on incompatible column lengths, when no common supertype exists, or when a cast,
/// null fill or addition fails.
pub fn sum_horizontal(
    columns: &[Column],
    null_strategy: NullStrategy,
) -> PolarsResult<Option<Column>> {
    validate_column_lengths(columns)?;
    let ignore_nulls = null_strategy == NullStrategy::Ignore;

    // Under `Ignore`, nulls are replaced with zero so they do not affect the sum.
    let apply_null_strategy = |s: Series| -> PolarsResult<Series> {
        if ignore_nulls && s.null_count() > 0 {
            s.fill_null(FillNullStrategy::Zero)
        } else {
            Ok(s)
        }
    };

    // Pairwise reduction step; both operands are taken by value.
    let sum_fn = |acc: Series, s: Series| -> PolarsResult<Series> {
        let acc: Series = apply_null_strategy(acc)?;
        let s = apply_null_strategy(s)?;
        std::ops::Add::add(acc, s)
    };

    // Columns of dtype `Null` are left out of the sum.
    let non_null_cols = columns
        .iter()
        .filter(|x| x.dtype() != &DataType::Null)
        .map(|c| c.as_materialized_series())
        .collect::<Vec<_>>();

    // Under `Propagate`, any `Null`-dtype column makes the whole result null; only the
    // output dtype still has to be determined from the remaining columns.
    if !ignore_nulls && non_null_cols.len() < columns.len() {
        let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
            DataType::Boolean => IDX_DTYPE,
            dt => dt,
        };
        return Ok(Some(Column::full_null(
            columns[0].name().clone(),
            columns[0].len(),
            &return_dtype,
        )));
    }

    match non_null_cols.len() {
        0 => {
            if columns.is_empty() {
                Ok(None)
            } else {
                // Only reachable under `Ignore` (the branch above handles `Propagate`):
                // every input has dtype `Null`, so the first one is returned unchanged.
                Ok(Some(columns[0].clone()))
            }
        },
        1 => Ok(Some(
            apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
                non_null_cols[0].cast(&IDX_DTYPE)?
            } else {
                non_null_cols[0].clone()
            })?
            .into(),
        )),
        2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())
            .map(Column::from)
            .map(Some),
        _ => {
            // Parallel pairwise reduction over the columns.
            let out = POOL.install(|| {
                non_null_cols
                    .into_par_iter()
                    .cloned()
                    .map(Ok)
                    .try_reduce_with(sum_fn)
                    // At least three columns reach this arm, so the reduction yields `Some`.
                    .unwrap()
            });
            out.map(Column::from).map(Some)
        },
    }
}
264
265pub fn mean_horizontal(
266 columns: &[Column],
267 null_strategy: NullStrategy,
268) -> PolarsResult<Option<Column>> {
269 validate_column_lengths(columns)?;
270
271 let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| {
272 let dtype = s.dtype();
273 dtype.is_primitive_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()
274 });
275
276 if !non_numeric_columns.is_empty() {
277 let col = non_numeric_columns.first().cloned();
278 polars_bail!(
279 InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",
280 col.unwrap().name(),
281 col.unwrap().dtype(),
282 );
283 }
284 let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
285 let num_rows = columns.len();
286 match num_rows {
287 0 => Ok(None),
288 1 => Ok(Some(match columns[0].dtype() {
289 dt if dt != &DataType::Float32 && !dt.is_decimal() => {
290 columns[0].cast(&DataType::Float64)?
291 },
292 _ => columns[0].clone(),
293 })),
294 _ => {
295 let sum = || sum_horizontal(columns.as_slice(), null_strategy);
296 let null_count = || {
297 columns
298 .par_iter()
299 .map(|c| {
300 c.is_null()
301 .into_column()
302 .cast_with_options(&DataType::UInt32, CastOptions::NonStrict)
303 })
304 .reduce_with(|l, r| {
305 let l = l?;
306 let r = r?;
307 let result = std::ops::Add::add(&l, &r)?;
308 PolarsResult::Ok(result)
309 })
310 .unwrap()
313 };
314
315 let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count));
316 let sum = sum?;
317 let null_count = null_count?;
318
319 let value_length: UInt32Chunked = (Column::new_scalar(
321 PlSmallStr::EMPTY,
322 Scalar::from(num_rows as u32),
323 null_count.len(),
324 ) - null_count)?
325 .u32()
326 .unwrap()
327 .clone();
328
329 let dt = if sum
332 .as_ref()
333 .is_some_and(|s| s.dtype() == &DataType::Float32)
334 {
335 &DataType::Float32
336 } else {
337 &DataType::Float64
338 };
339 let value_length = value_length
340 .set(&value_length.equal(0), None)?
341 .into_column()
342 .cast(dt)?;
343
344 sum.map(|sum| std::ops::Div::div(&sum, &value_length))
345 .transpose()
346 },
347 }
348}
349
350pub fn coalesce_columns(s: &[Column]) -> PolarsResult<Column> {
351 polars_ensure!(!s.is_empty(), NoData: "cannot coalesce empty list");
353 let mut out = s[0].clone();
354 for s in s {
355 if !out.null_count() == 0 {
356 return Ok(out);
357 } else {
358 let mask = out.is_not_null();
359 out = out
360 .as_materialized_series()
361 .zip_with_same_type(&mask, s.as_materialized_series())?
362 .into();
363 }
364 }
365 Ok(out)
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small frame with nulls in `b` and `c`, then checks every horizontal
    /// reduction against hand-computed results.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_horizontal_agg() {
        let df = DataFrame::new(vec![
            Column::new("a".into(), [1, 2, 6]),
            Column::new("b".into(), [Some(1), None, None]),
            Column::new("c".into(), [Some(4), None, Some(3)]),
        ])
        .unwrap();

        let mean = df.mean_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(
            Vec::from(mean.f64().unwrap()),
            &[Some(2.0), Some(2.0), Some(4.5)]
        );

        let sum = df.sum_horizontal(NullStrategy::Ignore).unwrap().unwrap();
        assert_eq!(Vec::from(sum.i32().unwrap()), &[Some(6), Some(2), Some(9)]);

        let min = df.min_horizontal().unwrap().unwrap();
        assert_eq!(Vec::from(min.i32().unwrap()), &[Some(1), Some(2), Some(3)]);

        let max = df.max_horizontal().unwrap().unwrap();
        assert_eq!(Vec::from(max.i32().unwrap()), &[Some(4), Some(2), Some(6)]);
    }
}