1use std::borrow::Cow;
2
3use polars_core::chunked_array::cast::CastOptions;
4use polars_core::prelude::*;
5use polars_core::runtime::RAYON;
6use polars_core::series::arithmetic::coerce_lhs_rhs;
7use polars_core::utils::dtypes_to_supertype;
8use polars_core::with_match_physical_numeric_polars_type;
9use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
10
11fn validate_column_lengths(cs: &[Column]) -> PolarsResult<()> {
12 let mut length = 1;
13 for c in cs {
14 let len = c.len();
15 if len != 1 && len != length {
16 if length == 1 {
17 length = len;
18 } else {
19 polars_bail!(ShapeMismatch: "cannot evaluate two Series of different lengths ({len} and {length})");
20 }
21 }
22 }
23 Ok(())
24}
25
/// Row-wise ("horizontal") minimum and maximum across the columns of a container.
pub trait MinMaxHorizontal {
    /// Element-wise minimum across all columns of `self`.
    ///
    /// In the `DataFrame` implementation this delegates to the free function `min_horizontal`,
    /// which returns `Ok(None)` when there are no columns.
    fn min_horizontal(&self) -> PolarsResult<Option<Column>>;
    /// Element-wise maximum across all columns of `self`.
    ///
    /// In the `DataFrame` implementation this delegates to the free function `max_horizontal`,
    /// which returns `Ok(None)` when there are no columns.
    fn max_horizontal(&self) -> PolarsResult<Option<Column>>;
}
32
33impl MinMaxHorizontal for DataFrame {
34 fn min_horizontal(&self) -> PolarsResult<Option<Column>> {
35 min_horizontal(self.columns())
36 }
37 fn max_horizontal(&self) -> PolarsResult<Option<Column>> {
38 max_horizontal(self.columns())
39 }
40}
41
/// How the horizontal sum/mean helpers treat null values.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the strategy can be used as a key in
/// sets and maps; the enum is fieldless, so these are trivially correct.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum NullStrategy {
    /// Skip nulls: they are filled with zero before summing, and the mean divides by the
    /// number of non-null values in each row.
    Ignore,
    /// Let nulls propagate: a null in a row makes that row's sum (and therefore mean) null.
    Propagate,
}
47
/// Row-wise ("horizontal") sum and mean across the columns of a container.
pub trait SumMeanHorizontal {
    /// Element-wise sum across all columns of `self`, treating nulls per `null_strategy`.
    ///
    /// In the `DataFrame` implementation this delegates to the free function `sum_horizontal`,
    /// which returns `Ok(None)` when there are no columns.
    fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;

    /// Element-wise mean across all columns of `self`, treating nulls per `null_strategy`.
    ///
    /// In the `DataFrame` implementation this delegates to the free function `mean_horizontal`,
    /// which returns `Ok(None)` when there are no columns.
    fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>>;
}
55
56impl SumMeanHorizontal for DataFrame {
57 fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
58 sum_horizontal(self.columns(), null_strategy)
59 }
60 fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Column>> {
61 mean_horizontal(self.columns(), null_strategy)
62 }
63}
64
65fn min_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
66where
67 T: PolarsNumericType,
68 T::Native: PartialOrd,
69{
70 let op = |l: T::Native, r: T::Native| {
71 if l < r { l } else { r }
72 };
73 arity::binary_elementwise_values(left, right, op)
74}
75
76fn max_binary<T>(left: &ChunkedArray<T>, right: &ChunkedArray<T>) -> ChunkedArray<T>
77where
78 T: PolarsNumericType,
79 T::Native: PartialOrd,
80{
81 let op = |l: T::Native, r: T::Native| {
82 if l > r { l } else { r }
83 };
84 arity::binary_elementwise_values(left, right, op)
85}
86
/// Element-wise minimum (`min == true`) or maximum (`min == false`) of two columns.
///
/// Fast path: both columns have a primitive numeric physical type, contain no nulls and have
/// equal length. Slow path: everything else, using a comparison mask plus `zip_with`, where
/// `zip_with` takes `left` where the mask is true and `right` otherwise.
fn min_max_binary_columns(left: &Column, right: &Column, min: bool) -> PolarsResult<Column> {
    if left.dtype().to_physical().is_primitive_numeric()
        && right.dtype().to_physical().is_primitive_numeric()
        && left.null_count() == 0
        && right.null_count() == 0
        && left.len() == right.len()
    {
        match (left, right) {
            (Column::Series(left), Column::Series(right)) => {
                // Bring both sides to a common dtype, operate on the physical values, then
                // restore the coerced logical dtype on the result.
                let (lhs, rhs) = coerce_lhs_rhs(left, right)?;
                let logical = lhs.dtype();
                let lhs = lhs.to_physical_repr();
                let rhs = rhs.to_physical_repr();

                with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| {
                    let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
                    let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();

                    // SAFETY: `min_binary`/`max_binary` copy every output value verbatim from
                    // `a` or `b`, which are physical representations of `logical`, so
                    // re-attaching `logical` is presumed valid.
                    unsafe {
                        if min {
                            min_binary(a, b).into_series().from_physical_unchecked(logical)
                        } else {
                            max_binary(a, b).into_series().from_physical_unchecked(logical)
                        }
                    }
                })
                .map(Column::from)
            },
            _ => {
                // Other `Column` representations: no nulls here, so a plain strict comparison
                // decides which side to keep (ties keep `right`).
                let mask = if min {
                    left.lt(right)?
                } else {
                    left.gt(right)?
                };

                left.zip_with(&mask, right)
            },
        }
    } else {
        // `&` binds tighter than `|`, so this is `(cmp & left_not_null) | right_is_null`:
        // keep `left` when it strictly wins and is non-null, or whenever `right` is null.
        // NOTE(review): intent appears to be that a null never beats a value; the exact mask
        // value for rows where only `left` is null depends on polars' boolean null semantics.
        let mask = if min {
            left.lt(right)? & left.is_not_null() | right.is_null()
        } else {
            left.gt(right)? & left.is_not_null() | right.is_null()
        };
        left.zip_with(&mask, right)
    }
}
134
135pub fn max_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
136 validate_column_lengths(columns)?;
137
138 let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false);
139
140 match columns.len() {
141 0 => Ok(None),
142 1 => Ok(Some(columns[0].clone())),
143 2 => max_fn(&columns[0], &columns[1]).map(Some),
144 _ => {
145 RAYON.install(|| {
148 columns
149 .par_iter()
150 .map(|s| Ok(Cow::Borrowed(s)))
151 .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned))
152 .unwrap()
155 .map(|cow| Some(cow.into_owned()))
156 })
157 },
158 }
159}
160
161pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
162 validate_column_lengths(columns)?;
163
164 let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true);
165
166 match columns.len() {
167 0 => Ok(None),
168 1 => Ok(Some(columns[0].clone())),
169 2 => min_fn(&columns[0], &columns[1]).map(Some),
170 _ => {
171 RAYON.install(|| {
174 columns
175 .par_iter()
176 .map(|s| Ok(Cow::Borrowed(s)))
177 .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned))
178 .unwrap()
181 .map(|cow| Some(cow.into_owned()))
182 })
183 },
184 }
185}
186
187pub fn sum_horizontal(
188 columns: &[Column],
189 null_strategy: NullStrategy,
190) -> PolarsResult<Option<Column>> {
191 validate_column_lengths(columns)?;
192 let ignore_nulls = null_strategy == NullStrategy::Ignore;
193
194 let apply_null_strategy = |s: Series| -> PolarsResult<Series> {
195 if ignore_nulls && s.null_count() > 0 {
196 s.fill_null(FillNullStrategy::Zero)
197 } else {
198 Ok(s)
199 }
200 };
201
202 let sum_fn = |acc: Series, s: Series| -> PolarsResult<Series> {
203 let acc: Series = apply_null_strategy(acc)?;
204 let s = apply_null_strategy(s)?;
205 std::ops::Add::add(acc, s)
207 };
208
209 let non_null_cols = columns
211 .iter()
212 .filter(|x| x.dtype() != &DataType::Null)
213 .map(|c| c.as_materialized_series())
214 .collect::<Vec<_>>();
215
216 if !ignore_nulls && non_null_cols.len() < columns.len() {
218 let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
220 DataType::Boolean => IDX_DTYPE,
221 dt => dt,
222 };
223 return Ok(Some(Column::full_null(
224 columns[0].name().clone(),
225 columns[0].len(),
226 &return_dtype,
227 )));
228 }
229
230 match non_null_cols.len() {
231 0 => {
232 if columns.is_empty() {
233 Ok(None)
234 } else {
235 Ok(Some(columns[0].clone()))
237 }
238 },
239 1 => Ok(Some(
240 apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
241 non_null_cols[0].cast(&IDX_DTYPE)?
242 } else {
243 non_null_cols[0].clone()
244 })?
245 .into(),
246 )),
247 2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())
248 .map(Column::from)
249 .map(Some),
250 _ => {
251 let out = RAYON.install(|| {
254 non_null_cols
255 .into_par_iter()
256 .cloned()
257 .map(Ok)
258 .try_reduce_with(sum_fn)
259 .unwrap()
261 });
262 out.map(Column::from).map(Some)
263 },
264 }
265}
266
267pub fn mean_horizontal(
268 columns: &[Column],
269 null_strategy: NullStrategy,
270) -> PolarsResult<Option<Column>> {
271 validate_column_lengths(columns)?;
272
273 let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| {
274 let dtype = s.dtype();
275 dtype.is_primitive_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()
276 });
277
278 if !non_numeric_columns.is_empty() {
279 let col = non_numeric_columns.first().cloned();
280 polars_bail!(
281 InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",
282 col.unwrap().name(),
283 col.unwrap().dtype(),
284 );
285 }
286 let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
287 let num_rows = columns.len();
288 match num_rows {
289 0 => Ok(None),
290 1 => Ok(Some(match columns[0].dtype() {
291 dt if !matches!(dt, DataType::Float16 | DataType::Float32) && !dt.is_decimal() => {
292 columns[0].cast(&DataType::Float64)?
293 },
294 _ => columns[0].clone(),
295 })),
296 _ => {
297 let sum = || sum_horizontal(columns.as_slice(), null_strategy);
298 let null_count = || {
299 columns
300 .par_iter()
301 .map(|c| {
302 c.is_null()
303 .into_column()
304 .cast_with_options(&DataType::UInt32, CastOptions::NonStrict)
305 })
306 .reduce_with(|l, r| {
307 let l = l?;
308 let r = r?;
309 let result = std::ops::Add::add(&l, &r)?;
310 PolarsResult::Ok(result)
311 })
312 .unwrap()
315 };
316
317 let (sum, null_count) = RAYON.install(|| rayon::join(sum, null_count));
318 let sum = sum?;
319 let null_count = null_count?;
320
321 let value_length: UInt32Chunked = (Column::new_scalar(
323 PlSmallStr::EMPTY,
324 Scalar::from(num_rows as u32),
325 null_count.len(),
326 ) - null_count)?
327 .u32()
328 .unwrap()
329 .clone();
330
331 let dt = sum
334 .as_ref()
335 .map(Column::dtype)
336 .filter(|dt| matches!(dt, DataType::Float16 | DataType::Float32))
337 .unwrap_or(&DataType::Float64);
338 let value_length = value_length
339 .set(&value_length.equal(0), None)?
340 .into_column()
341 .cast(dt)?;
342
343 sum.map(|sum| std::ops::Div::div(&sum, &value_length))
344 .transpose()
345 },
346 }
347}
348
349pub fn coalesce_columns(s: &[Column]) -> PolarsResult<Column> {
350 polars_ensure!(!s.is_empty(), NoData: "cannot coalesce empty list");
352 let mut out = s[0].clone();
353 for s in s {
354 if !out.null_count() == 0 {
355 return Ok(out);
356 } else {
357 let mask = out.is_not_null();
358 out = out
359 .as_materialized_series()
360 .zip_with_same_type(&mask, s.as_materialized_series())?
361 .into();
362 }
363 }
364 Ok(out)
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_horizontal_agg() {
        let a = Column::new("a".into(), [1, 2, 6]);
        let b = Column::new("b".into(), [Some(1), None, None]);
        let c = Column::new("c".into(), [Some(4), None, Some(3)]);

        let df = DataFrame::new_infer_height(vec![a, b, c]).unwrap();
        assert_eq!(
            Vec::from(
                df.mean_horizontal(NullStrategy::Ignore)
                    .unwrap()
                    .unwrap()
                    .f64()
                    .unwrap()
            ),
            &[Some(2.0), Some(2.0), Some(4.5)]
        );
        assert_eq!(
            Vec::from(
                df.sum_horizontal(NullStrategy::Ignore)
                    .unwrap()
                    .unwrap()
                    .i32()
                    .unwrap()
            ),
            &[Some(6), Some(2), Some(9)]
        );
        assert_eq!(
            Vec::from(df.min_horizontal().unwrap().unwrap().i32().unwrap()),
            &[Some(1), Some(2), Some(3)]
        );
        assert_eq!(
            Vec::from(df.max_horizontal().unwrap().unwrap().i32().unwrap()),
            &[Some(4), Some(2), Some(6)]
        );
    }

    /// With `Propagate`, any null in a row makes the row's sum and mean null.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_horizontal_agg_propagate_nulls() {
        let a = Column::new("a".into(), [1, 2, 6]);
        let b = Column::new("b".into(), [Some(1), None, None]);
        let c = Column::new("c".into(), [Some(4), None, Some(3)]);

        let df = DataFrame::new_infer_height(vec![a, b, c]).unwrap();
        assert_eq!(
            Vec::from(
                df.sum_horizontal(NullStrategy::Propagate)
                    .unwrap()
                    .unwrap()
                    .i32()
                    .unwrap()
            ),
            &[Some(6), None, None]
        );
        assert_eq!(
            Vec::from(
                df.mean_horizontal(NullStrategy::Propagate)
                    .unwrap()
                    .unwrap()
                    .f64()
                    .unwrap()
            ),
            &[Some(2.0), None, None]
        );
    }

    /// Each row takes the first non-null value in column order; empty input is an error.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_coalesce_columns() {
        let a = Column::new("a".into(), [1, 2, 6]);
        let b = Column::new("b".into(), [Some(1), None, None]);
        let c = Column::new("c".into(), [Some(4), None, Some(3)]);

        let out = coalesce_columns(&[b, c, a]).unwrap();
        assert_eq!(Vec::from(out.i32().unwrap()), &[Some(1), Some(2), Some(3)]);

        assert!(coalesce_columns(&[]).is_err());
    }
}