Skip to main content

polars_core/chunked_array/ops/aggregate/
mod.rs

1//! Implementations of the ChunkAgg trait.
2mod quantile;
3mod var;
4
5use arrow::types::NativeType;
6use num_traits::{AsPrimitive, Float, One, ToPrimitive, Zero};
7#[cfg(feature = "dtype-decimal")]
8use polars_compute::decimal::DEC128_MAX_PREC;
9use polars_compute::float_sum;
10use polars_compute::min_max::MinMaxKernel;
11use polars_compute::rolling::QuantileMethod;
12use polars_compute::sum::{WrappingAdd, WrappingSum, wrapping_sum_arr, wrapping_sum_arr_upcast};
13use polars_utils::float::IsFloat;
14use polars_utils::float16::pf16;
15use polars_utils::min_max::MinMax;
16pub use quantile::*;
17pub use var::*;
18
19use super::float_sorted_arg_max::{
20    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
21};
22use crate::chunked_array::{ChunkedArray, arg_max_binary, arg_min_binary};
23use crate::datatypes::{BooleanChunked, PolarsNumericType};
24use crate::prelude::*;
25use crate::series::IsSorted;
26
/// Associates a native value type with the type used to hold its sum.
///
/// `Sum` must be a [`NumericNative`] that can be constructed from `Self` via
/// [`From`]. The implementations generated in this module either keep the type
/// unchanged or map it to a different accumulator type.
pub trait SumCast: Sized {
    /// Type in which sums of `Self` values are accumulated and reported.
    type Sum: NumericNative + From<Self>;
}
30
// Generates `SumCast` implementations. Two invocation forms are accepted:
//   impl_sum_cast!(A, B, ...)            -- every listed type sums into itself
//   impl_sum_cast!(A as X, B as Y, ...)  -- every `A` sums into the given `X`
macro_rules! impl_sum_cast {
    // Identity form: the sum type is the type itself.
    ($($x:ty),*) => {
        $(impl SumCast for $x { type Sum = $x; })*
    };
    // Explicit form: the sum type is the one named after `as`.
    ($($from:ty as $to:ty),*) => {
        $(impl SumCast for $from { type Sum = $to; })*
    };
}
39
// Booleans sum into `IdxSize`; 8- and 16-bit integers sum into `i64`.
// These pairs mirror the dtype mapping in `sum_output_dtype`.
impl_sum_cast!(
    bool as IdxSize,
    u8 as i64,
    u16 as i64,
    i8 as i64,
    i16 as i64
);
// The remaining numeric types use themselves as their sum type.
impl_sum_cast!(u32, u64, i32, i64, f32, f64);
#[cfg(feature = "dtype-f16")]
impl_sum_cast!(pf16);
#[cfg(feature = "dtype-i128")]
impl_sum_cast!(i128);
#[cfg(feature = "dtype-u128")]
impl_sum_cast!(u128);
54
55pub fn sum_output_dtype(in_dtype: &DataType) -> DataType {
56    use DataType::*;
57    match in_dtype {
58        Boolean => IDX_DTYPE,
59        Int8 | UInt8 | Int16 | UInt16 => Int64,
60        #[cfg(feature = "dtype-decimal")]
61        Decimal(_, scale) => Decimal(DEC128_MAX_PREC, *scale),
62        dt => dt.clone(),
63    }
64}
65
/// Aggregations that reduce a [`ChunkedArray`] to a single [`Scalar`].
///
/// Every method has a default body that panics with `unimplemented!()`, so an
/// implementor only overrides the reductions it actually supports.
pub trait ChunkAggSeries {
    /// Get the sum of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics.
    fn sum_reduce(&self) -> Scalar {
        unimplemented!()
    }
    /// Get the max of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics.
    fn max_reduce(&self) -> Scalar {
        unimplemented!()
    }
    /// Get the min of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics.
    fn min_reduce(&self) -> Scalar {
        unimplemented!()
    }
    /// Get the product of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics.
    fn prod_reduce(&self) -> Scalar {
        unimplemented!()
    }
}
85
86fn sum<T>(array: &PrimitiveArray<T>) -> T
87where
88    T: NumericNative + NativeType + WrappingSum,
89{
90    if array.null_count() == array.len() {
91        return T::default();
92    }
93
94    if T::is_float() {
95        unsafe {
96            if T::is_f16() {
97                let f16_arr =
98                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<pf16>>(array);
99                // We do not trust the numerical accuracy of f16 summation
100                let sum: pf16 = float_sum::sum_arr_as_f32(f16_arr).as_();
101                std::mem::transmute_copy::<pf16, T>(&sum)
102            } else if T::is_f32() {
103                let f32_arr =
104                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f32>>(array);
105                let sum = float_sum::sum_arr_as_f32(f32_arr);
106                std::mem::transmute_copy::<f32, T>(&sum)
107            } else if T::is_f64() {
108                let f64_arr =
109                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f64>>(array);
110                let sum = float_sum::sum_arr_as_f64(f64_arr);
111                std::mem::transmute_copy::<f64, T>(&sum)
112            } else {
113                unreachable!("only supported float types are f16, f32 and f64");
114            }
115        }
116    } else {
117        wrapping_sum_arr(array)
118    }
119}
120
121impl<T> ChunkAgg<T::Native> for ChunkedArray<T>
122where
123    T: PolarsNumericType,
124    T::Native: WrappingSum,
125    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
126{
127    fn sum(&self) -> Option<T::Native> {
128        Some(
129            self.downcast_iter()
130                .map(sum)
131                .fold(T::Native::zero(), |acc, v| acc + v),
132        )
133    }
134
135    fn _sum_as_f64(&self) -> f64 {
136        self.downcast_iter().map(float_sum::sum_arr_as_f64).sum()
137    }
138
139    fn min(&self) -> Option<T::Native> {
140        if self.null_count() == self.len() {
141            return None;
142        }
143
144        // There is at least one non-null value.
145
146        match self.is_sorted_flag() {
147            IsSorted::Ascending => {
148                let idx = self.first_non_null().unwrap();
149                unsafe { self.get_unchecked(idx) }
150            },
151            IsSorted::Descending => {
152                let idx = self.last_non_null().unwrap();
153                unsafe { self.get_unchecked(idx) }
154            },
155            IsSorted::Not => self
156                .downcast_iter()
157                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
158                .reduce(MinMax::min_ignore_nan),
159        }
160    }
161
162    fn max(&self) -> Option<T::Native> {
163        if self.null_count() == self.len() {
164            return None;
165        }
166        // There is at least one non-null value.
167
168        match self.is_sorted_flag() {
169            IsSorted::Ascending => {
170                let idx = if T::get_static_dtype().is_float() {
171                    float_arg_max_sorted_ascending(self)
172                } else {
173                    self.last_non_null().unwrap()
174                };
175
176                unsafe { self.get_unchecked(idx) }
177            },
178            IsSorted::Descending => {
179                let idx = if T::get_static_dtype().is_float() {
180                    float_arg_max_sorted_descending(self)
181                } else {
182                    self.first_non_null().unwrap()
183                };
184
185                unsafe { self.get_unchecked(idx) }
186            },
187            IsSorted::Not => self
188                .downcast_iter()
189                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
190                .reduce(MinMax::max_ignore_nan),
191        }
192    }
193
194    fn min_max(&self) -> Option<(T::Native, T::Native)> {
195        if self.null_count() == self.len() {
196            return None;
197        }
198        // There is at least one non-null value.
199
200        match self.is_sorted_flag() {
201            IsSorted::Ascending => {
202                let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) };
203                let max = {
204                    let idx = if T::get_static_dtype().is_float() {
205                        float_arg_max_sorted_ascending(self)
206                    } else {
207                        self.last_non_null().unwrap()
208                    };
209
210                    unsafe { self.get_unchecked(idx) }
211                };
212                min.zip(max)
213            },
214            IsSorted::Descending => {
215                let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) };
216                let max = {
217                    let idx = if T::get_static_dtype().is_float() {
218                        float_arg_max_sorted_descending(self)
219                    } else {
220                        self.first_non_null().unwrap()
221                    };
222
223                    unsafe { self.get_unchecked(idx) }
224                };
225
226                min.zip(max)
227            },
228            IsSorted::Not => self
229                .downcast_iter()
230                .filter_map(MinMaxKernel::min_max_ignore_nan_kernel)
231                .reduce(|(min1, max1), (min2, max2)| {
232                    (
233                        MinMax::min_ignore_nan(min1, min2),
234                        MinMax::max_ignore_nan(max1, max2),
235                    )
236                }),
237        }
238    }
239
240    fn mean(&self) -> Option<f64> {
241        let count = self.len() - self.null_count();
242        if count == 0 {
243            return None;
244        }
245        Some(self._sum_as_f64() / count as f64)
246    }
247}
248
249/// Booleans are cast to 1 or 0.
250impl BooleanChunked {
251    pub fn sum(&self) -> Option<IdxSize> {
252        Some(if self.is_empty() {
253            0
254        } else {
255            self.downcast_iter()
256                .map(|arr| match arr.validity() {
257                    Some(validity) => {
258                        (arr.len() - (validity & arr.values()).unset_bits()) as IdxSize
259                    },
260                    None => (arr.len() - arr.values().unset_bits()) as IdxSize,
261                })
262                .sum()
263        })
264    }
265
266    pub fn min(&self) -> Option<bool> {
267        let nc = self.null_count();
268        let len = self.len();
269        if self.is_empty() || nc == len {
270            return None;
271        }
272        if nc == 0 {
273            if self.all() { Some(true) } else { Some(false) }
274        } else {
275            // we can unwrap as we already checked empty and all null above
276            if (self.sum().unwrap() + nc as IdxSize) == len as IdxSize {
277                Some(true)
278            } else {
279                Some(false)
280            }
281        }
282    }
283
284    pub fn max(&self) -> Option<bool> {
285        if self.is_empty() || self.null_count() == self.len() {
286            return None;
287        }
288        if self.any() { Some(true) } else { Some(false) }
289    }
290    pub fn mean(&self) -> Option<f64> {
291        if self.is_empty() || self.null_count() == self.len() {
292            return None;
293        }
294        self.sum()
295            .map(|sum| sum as f64 / (self.len() - self.null_count()) as f64)
296    }
297}
298
299// Needs the same trait bounds as the implementation of ChunkedArray<T> of dyn Series.
300impl<T> ChunkAggSeries for ChunkedArray<T>
301where
302    T: PolarsNumericType,
303    T::Native: WrappingSum + SumCast,
304    <T::Native as SumCast>::Sum: WrappingAdd,
305    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
306{
307    fn sum_reduce(&self) -> Scalar {
308        let v: <T::Native as SumCast>::Sum = if T::Native::is_float() {
309            self.sum().map(Into::into).unwrap_or_else(Zero::zero)
310        } else {
311            self.downcast_iter()
312                .map(wrapping_sum_arr_upcast)
313                .fold(Zero::zero(), |a, b| a.wrapping_add(&b))
314        };
315        Scalar::new(sum_output_dtype(&T::get_static_dtype()), v.into())
316    }
317
318    fn max_reduce(&self) -> Scalar {
319        let v = ChunkAgg::max(self);
320        Scalar::new(T::get_static_dtype(), v.into())
321    }
322
323    fn min_reduce(&self) -> Scalar {
324        let v = ChunkAgg::min(self);
325        Scalar::new(T::get_static_dtype(), v.into())
326    }
327
328    fn prod_reduce(&self) -> Scalar {
329        let mut prod = T::Native::one();
330
331        for arr in self.downcast_iter() {
332            for v in arr.into_iter().flatten() {
333                prod = prod * *v
334            }
335        }
336        Scalar::new(T::get_static_dtype(), prod.into())
337    }
338}
339
340impl<T> VarAggSeries for ChunkedArray<T>
341where
342    T: PolarsIntegerType,
343    ChunkedArray<T>: ChunkVar,
344{
345    fn var_reduce(&self, ddof: u8) -> Scalar {
346        let v = self.var(ddof);
347        Scalar::new(DataType::Float64, v.into())
348    }
349
350    fn std_reduce(&self, ddof: u8) -> Scalar {
351        let v = self.std(ddof);
352        Scalar::new(DataType::Float64, v.into())
353    }
354}
355
#[cfg(feature = "dtype-f16")]
impl VarAggSeries for Float16Chunked {
    /// Variance with the given `ddof`, converted to `pf16` and wrapped in a
    /// `Float16` scalar.
    fn var_reduce(&self, ddof: u8) -> Scalar {
        let variance = self.var(ddof).map(|wide| AsPrimitive::<pf16>::as_(wide));
        Scalar::new(DataType::Float16, variance.into())
    }

    /// Standard deviation with the given `ddof`, converted to `pf16` and
    /// wrapped in a `Float16` scalar.
    fn std_reduce(&self, ddof: u8) -> Scalar {
        let std_dev = self.std(ddof).map(|wide| AsPrimitive::<pf16>::as_(wide));
        Scalar::new(DataType::Float16, std_dev.into())
    }
}
368
369impl VarAggSeries for Float32Chunked {
370    fn var_reduce(&self, ddof: u8) -> Scalar {
371        let v = self.var(ddof).map(|v| v as f32);
372        Scalar::new(DataType::Float32, v.into())
373    }
374
375    fn std_reduce(&self, ddof: u8) -> Scalar {
376        let v = self.std(ddof).map(|v| v as f32);
377        Scalar::new(DataType::Float32, v.into())
378    }
379}
380
381impl VarAggSeries for Float64Chunked {
382    fn var_reduce(&self, ddof: u8) -> Scalar {
383        let v = self.var(ddof);
384        Scalar::new(DataType::Float64, v.into())
385    }
386
387    fn std_reduce(&self, ddof: u8) -> Scalar {
388        let v = self.std(ddof);
389        Scalar::new(DataType::Float64, v.into())
390    }
391}
392
393impl<T> QuantileAggSeries for ChunkedArray<T>
394where
395    T: PolarsIntegerType,
396    T::Native: Ord + WrappingSum,
397{
398    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
399        let v = self.quantile(quantile, method)?;
400        Ok(Scalar::new(DataType::Float64, v.into()))
401    }
402
403    fn quantiles_reduce(&self, quantiles: &[f64], method: QuantileMethod) -> PolarsResult<Scalar> {
404        let v = self.quantiles(quantiles, method)?;
405        let s =
406            Float64Chunked::from_iter_options(PlSmallStr::from_static("quantiles"), v.into_iter())
407                .into_series();
408        let dtype = DataType::List(Box::new(s.dtype().clone()));
409        Ok(Scalar::new(dtype, AnyValue::List(s)))
410    }
411
412    fn median_reduce(&self) -> Scalar {
413        let v = self.median();
414        Scalar::new(DataType::Float64, v.into())
415    }
416}
417
#[cfg(feature = "dtype-f16")]
impl QuantileAggSeries for Float16Chunked {
    /// Single quantile as a `Float16` scalar; errors from `quantile` are propagated.
    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
        self.quantile(quantile, method)
            .map(|value| Scalar::new(DataType::Float16, value.into()))
    }

    /// Several quantiles at once, returned as a list scalar wrapping a
    /// `Float16` series named `"quantiles"`.
    fn quantiles_reduce(&self, quantiles: &[f64], method: QuantileMethod) -> PolarsResult<Scalar> {
        let values = self.quantiles(quantiles, method)?;
        let name = PlSmallStr::from_static("quantiles");
        let series = Float16Chunked::from_iter_options(name, values.into_iter()).into_series();
        let list_dtype = DataType::List(Box::new(series.dtype().clone()));
        Ok(Scalar::new(list_dtype, AnyValue::List(series)))
    }

    /// Median as a `Float16` scalar.
    fn median_reduce(&self) -> Scalar {
        Scalar::new(DataType::Float16, self.median().into())
    }
}
439
440impl QuantileAggSeries for Float32Chunked {
441    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
442        let v = self.quantile(quantile, method)?;
443        Ok(Scalar::new(DataType::Float32, v.into()))
444    }
445
446    fn quantiles_reduce(&self, quantiles: &[f64], method: QuantileMethod) -> PolarsResult<Scalar> {
447        let v = self.quantiles(quantiles, method)?;
448        let s =
449            Float32Chunked::from_iter_options(PlSmallStr::from_static("quantiles"), v.into_iter())
450                .into_series();
451        let dtype = DataType::List(Box::new(s.dtype().clone()));
452        Ok(Scalar::new(dtype, AnyValue::List(s)))
453    }
454
455    fn median_reduce(&self) -> Scalar {
456        let v = self.median();
457        Scalar::new(DataType::Float32, v.into())
458    }
459}
460
461impl QuantileAggSeries for Float64Chunked {
462    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
463        let v = self.quantile(quantile, method)?;
464        Ok(Scalar::new(DataType::Float64, v.into()))
465    }
466
467    fn quantiles_reduce(&self, quantiles: &[f64], method: QuantileMethod) -> PolarsResult<Scalar> {
468        let v = self.quantiles(quantiles, method)?;
469        let s =
470            Float64Chunked::from_iter_options(PlSmallStr::from_static("quantiles"), v.into_iter())
471                .into_series();
472        let dtype = DataType::List(Box::new(s.dtype().clone()));
473        Ok(Scalar::new(dtype, AnyValue::List(s)))
474    }
475
476    fn median_reduce(&self) -> Scalar {
477        let v = self.median();
478        Scalar::new(DataType::Float64, v.into())
479    }
480}
481
482impl ChunkAggSeries for BooleanChunked {
483    fn sum_reduce(&self) -> Scalar {
484        let v = self.sum();
485        Scalar::new(IDX_DTYPE, v.into())
486    }
487    fn max_reduce(&self) -> Scalar {
488        let v = self.max();
489        Scalar::new(DataType::Boolean, v.into())
490    }
491    fn min_reduce(&self) -> Scalar {
492        let v = self.min();
493        Scalar::new(DataType::Boolean, v.into())
494    }
495}
496
497impl StringChunked {
498    pub(crate) fn max_str(&self) -> Option<&str> {
499        if self.is_empty() {
500            return None;
501        }
502        match self.is_sorted_flag() {
503            IsSorted::Ascending => {
504                self.last_non_null().and_then(|idx| {
505                    // SAFETY: last_non_null returns in bound index
506                    unsafe { self.get_unchecked(idx) }
507                })
508            },
509            IsSorted::Descending => {
510                self.first_non_null().and_then(|idx| {
511                    // SAFETY: first_non_null returns in bound index
512                    unsafe { self.get_unchecked(idx) }
513                })
514            },
515            IsSorted::Not => self
516                .downcast_iter()
517                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
518                .reduce(MinMax::max_ignore_nan),
519        }
520    }
521    pub(crate) fn min_str(&self) -> Option<&str> {
522        if self.is_empty() {
523            return None;
524        }
525        match self.is_sorted_flag() {
526            IsSorted::Ascending => {
527                self.first_non_null().and_then(|idx| {
528                    // SAFETY: first_non_null returns in bound index
529                    unsafe { self.get_unchecked(idx) }
530                })
531            },
532            IsSorted::Descending => {
533                self.last_non_null().and_then(|idx| {
534                    // SAFETY: last_non_null returns in bound index
535                    unsafe { self.get_unchecked(idx) }
536                })
537            },
538            IsSorted::Not => self
539                .downcast_iter()
540                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
541                .reduce(MinMax::min_ignore_nan),
542        }
543    }
544}
545
546impl ChunkAggSeries for StringChunked {
547    fn max_reduce(&self) -> Scalar {
548        let av: AnyValue = self.max_str().into();
549        Scalar::new(DataType::String, av.into_static())
550    }
551    fn min_reduce(&self) -> Scalar {
552        let av: AnyValue = self.min_str().into();
553        Scalar::new(DataType::String, av.into_static())
554    }
555}
556
#[cfg(feature = "dtype-categorical")]
impl<T: PolarsCategoricalType> CategoricalChunked<T>
where
    ChunkedArray<T::PolarsPhysical>: ChunkAgg<T::Native>,
{
    /// Smallest category, or `None` for an empty or all-null array.
    ///
    /// With lexical ordering the categories are compared by their string
    /// values; otherwise the minimum of the physical values is used.
    fn min_categorical(&self) -> Option<CatSize> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if !self.uses_lexical_ordering() {
            return Some(self.physical().min()?.as_cat());
        }

        let mapping = self.get_mapping();
        let smallest = self
            .physical()
            .iter()
            .flatten()
            // NOTE(review): the unchecked lookup presumably requires each
            // category to be present in `mapping`; the categories come from
            // this array's own physical values -- confirm the invariant.
            .map(|cat| unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) })
            .min();
        // At least one non-null value exists (checked above), so `smallest` is `Some`.
        mapping.get_cat(smallest.unwrap())
    }

    /// Largest category, or `None` for an empty or all-null array.
    ///
    /// With lexical ordering the categories are compared by their string
    /// values; otherwise the maximum of the physical values is used.
    fn max_categorical(&self) -> Option<CatSize> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if !self.uses_lexical_ordering() {
            return Some(self.physical().max()?.as_cat());
        }

        let mapping = self.get_mapping();
        let largest = self
            .physical()
            .iter()
            .flatten()
            // NOTE(review): the unchecked lookup presumably requires each
            // category to be present in `mapping`; the categories come from
            // this array's own physical values -- confirm the invariant.
            .map(|cat| unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) })
            .max();
        // At least one non-null value exists (checked above), so `largest` is `Some`.
        mapping.get_cat(largest.unwrap())
    }
}
600
#[cfg(feature = "dtype-categorical")]
impl<T: PolarsCategoricalType> ChunkAggSeries for CategoricalChunked<T>
where
    ChunkedArray<T::PolarsPhysical>: ChunkAgg<T::Native>,
{
    /// Smallest category as a scalar of this array's dtype; a null scalar
    /// when `min_categorical` finds nothing.
    fn min_reduce(&self) -> Scalar {
        let dtype = self.dtype();
        let value = match (self.min_categorical(), dtype) {
            (None, _) => AnyValue::Null,
            (Some(cat), DataType::Enum(_, mapping)) => AnyValue::EnumOwned(cat, mapping.clone()),
            (Some(cat), DataType::Categorical(_, mapping)) => {
                AnyValue::CategoricalOwned(cat, mapping.clone())
            },
            _ => unreachable!(),
        };
        Scalar::new(dtype.clone(), value)
    }

    /// Largest category as a scalar of this array's dtype; a null scalar
    /// when `max_categorical` finds nothing.
    fn max_reduce(&self) -> Scalar {
        let dtype = self.dtype();
        let value = match (self.max_categorical(), dtype) {
            (None, _) => AnyValue::Null,
            (Some(cat), DataType::Enum(_, mapping)) => AnyValue::EnumOwned(cat, mapping.clone()),
            (Some(cat), DataType::Categorical(_, mapping)) => {
                AnyValue::CategoricalOwned(cat, mapping.clone())
            },
            _ => unreachable!(),
        };
        Scalar::new(dtype.clone(), value)
    }
}
630
631impl BinaryChunked {
632    pub fn max_binary(&self) -> Option<&[u8]> {
633        if self.is_empty() {
634            return None;
635        }
636        match self.is_sorted_flag() {
637            IsSorted::Ascending => {
638                self.last_non_null().and_then(|idx| {
639                    // SAFETY: last_non_null returns in bound index.
640                    unsafe { self.get_unchecked(idx) }
641                })
642            },
643            IsSorted::Descending => {
644                self.first_non_null().and_then(|idx| {
645                    // SAFETY: first_non_null returns in bound index.
646                    unsafe { self.get_unchecked(idx) }
647                })
648            },
649            IsSorted::Not => self
650                .downcast_iter()
651                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
652                .reduce(MinMax::max_ignore_nan),
653        }
654    }
655
656    pub fn min_binary(&self) -> Option<&[u8]> {
657        if self.is_empty() {
658            return None;
659        }
660        match self.is_sorted_flag() {
661            IsSorted::Ascending => {
662                self.first_non_null().and_then(|idx| {
663                    // SAFETY: first_non_null returns in bound index.
664                    unsafe { self.get_unchecked(idx) }
665                })
666            },
667            IsSorted::Descending => {
668                self.last_non_null().and_then(|idx| {
669                    // SAFETY: last_non_null returns in bound index.
670                    unsafe { self.get_unchecked(idx) }
671                })
672            },
673            IsSorted::Not => self
674                .downcast_iter()
675                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
676                .reduce(MinMax::min_ignore_nan),
677        }
678    }
679    pub fn arg_min_binary(&self) -> Option<usize> {
680        if self.is_empty() || self.null_count() == self.len() {
681            return None;
682        }
683
684        match self.is_sorted_flag() {
685            IsSorted::Ascending => self.first_non_null(),
686            IsSorted::Descending => self.last_non_null(),
687            IsSorted::Not => arg_min_binary(self),
688        }
689    }
690
691    pub fn arg_max_binary(&self) -> Option<usize> {
692        if self.is_empty() || self.null_count() == self.len() {
693            return None;
694        }
695
696        match self.is_sorted_flag() {
697            IsSorted::Ascending => self.last_non_null(),
698            IsSorted::Descending => self.first_non_null(),
699            IsSorted::Not => arg_max_binary(self),
700        }
701    }
702}
703
704impl ChunkAggSeries for BinaryChunked {
705    fn sum_reduce(&self) -> Scalar {
706        unimplemented!()
707    }
708    fn max_reduce(&self) -> Scalar {
709        let av: AnyValue = self.max_binary().into();
710        Scalar::new(self.dtype().clone(), av.into_static())
711    }
712    fn min_reduce(&self) -> Scalar {
713        let av: AnyValue = self.min_binary().into();
714        Scalar::new(self.dtype().clone(), av.into_static())
715    }
716}
717
// Object arrays override none of the reductions, so every `ChunkAggSeries`
// method falls back to the trait default, which panics with `unimplemented!()`.
#[cfg(feature = "object")]
impl<T: PolarsObject> ChunkAggSeries for ObjectChunked<T> {}
720
721#[cfg(test)]
722mod test {
723    use polars_compute::rolling::QuantileMethod;
724
725    use crate::prelude::*;
726
727    #[test]
728    #[cfg(not(miri))]
729    fn test_var() {
730        // Validated with numpy. Note that numpy uses ddof as an argument which
731        // influences results. The default ddof=0, we chose ddof=1, which is
732        // standard in statistics.
733        let ca1 = Int32Chunked::new(PlSmallStr::EMPTY, &[5, 8, 9, 5, 0]);
734        let ca2 = Int32Chunked::new(
735            PlSmallStr::EMPTY,
736            &[
737                Some(5),
738                None,
739                Some(8),
740                Some(9),
741                None,
742                Some(5),
743                Some(0),
744                None,
745            ],
746        );
747        for ca in &[ca1, ca2] {
748            let out = ca.var(1);
749            assert_eq!(out, Some(12.3));
750            let out = ca.std(1).unwrap();
751            assert!((3.5071355833500366 - out).abs() < 0.000000001);
752        }
753    }
754
755    #[test]
756    fn test_agg_float() {
757        let ca1 = Float32Chunked::new(PlSmallStr::from_static("a"), &[1.0, f32::NAN]);
758        let ca2 = Float32Chunked::new(PlSmallStr::from_static("b"), &[f32::NAN, 1.0]);
759        assert_eq!(ca1.min(), ca2.min());
760        let ca1 = Float64Chunked::new(PlSmallStr::from_static("a"), &[1.0, f64::NAN]);
761        let ca2 = Float64Chunked::from_slice(PlSmallStr::from_static("b"), &[f64::NAN, 1.0]);
762        assert_eq!(ca1.min(), ca2.min());
763        println!("{:?}", (ca1.min(), ca2.min()))
764    }
765
766    #[test]
767    fn test_median() {
768        let ca = UInt32Chunked::new(
769            PlSmallStr::from_static("a"),
770            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
771        );
772        assert_eq!(ca.median(), Some(3.0));
773        let ca = UInt32Chunked::new(
774            PlSmallStr::from_static("a"),
775            &[
776                None,
777                Some(7),
778                Some(6),
779                Some(2),
780                Some(1),
781                None,
782                Some(3),
783                Some(5),
784                None,
785                Some(4),
786            ],
787        );
788        assert_eq!(ca.median(), Some(4.0));
789
790        let ca = Float32Chunked::from_slice(
791            PlSmallStr::EMPTY,
792            &[
793                0.166189, 0.166559, 0.168517, 0.169393, 0.175272, 0.233167, 0.238787, 0.266562,
794                0.26903, 0.285792, 0.292801, 0.293429, 0.301706, 0.308534, 0.331489, 0.346095,
795                0.367644, 0.369939, 0.372074, 0.41014, 0.415789, 0.421781, 0.427725, 0.465363,
796                0.500208, 2.621727, 2.803311, 3.868526,
797            ],
798        );
799        assert!((ca.median().unwrap() - 0.3200115).abs() < 0.0001)
800    }
801
802    #[test]
803    fn test_mean() {
804        let ca = Float32Chunked::new(PlSmallStr::EMPTY, &[Some(1.0), Some(2.0), None]);
805        assert_eq!(ca.mean().unwrap(), 1.5);
806        assert_eq!(
807            ca.into_series()
808                .mean_reduce()
809                .unwrap()
810                .value()
811                .extract::<f32>()
812                .unwrap(),
813            1.5
814        );
815        // all null values case
816        let ca = Float32Chunked::full_null(PlSmallStr::EMPTY, 3);
817        assert_eq!(ca.mean(), None);
818        assert_eq!(
819            ca.into_series()
820                .mean_reduce()
821                .unwrap()
822                .value()
823                .extract::<f32>(),
824            None
825        );
826    }
827
828    #[test]
829    fn test_quantile_all_null() {
830        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
831        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
832        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
833        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
834
835        let methods = vec![
836            QuantileMethod::Nearest,
837            QuantileMethod::Lower,
838            QuantileMethod::Higher,
839            QuantileMethod::Midpoint,
840            QuantileMethod::Linear,
841            QuantileMethod::Equiprobable,
842        ];
843
844        for method in methods {
845            assert_eq!(test_f32.quantile(0.9, method).unwrap(), None);
846            assert_eq!(test_i32.quantile(0.9, method).unwrap(), None);
847            assert_eq!(test_f64.quantile(0.9, method).unwrap(), None);
848            assert_eq!(test_i64.quantile(0.9, method).unwrap(), None);
849        }
850    }
851
852    #[test]
853    fn test_quantile_single_value() {
854        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
855        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);
856        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
857        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);
858
859        let methods = vec![
860            QuantileMethod::Nearest,
861            QuantileMethod::Lower,
862            QuantileMethod::Higher,
863            QuantileMethod::Midpoint,
864            QuantileMethod::Linear,
865            QuantileMethod::Equiprobable,
866        ];
867
868        for method in methods {
869            assert_eq!(test_f32.quantile(0.5, method).unwrap(), Some(1.0));
870            assert_eq!(test_i32.quantile(0.5, method).unwrap(), Some(1.0));
871            assert_eq!(test_f64.quantile(0.5, method).unwrap(), Some(1.0));
872            assert_eq!(test_i64.quantile(0.5, method).unwrap(), Some(1.0));
873        }
874    }
875
876    #[test]
877    fn test_quantile_min_max() {
878        let test_f32 = Float32Chunked::from_slice_options(
879            PlSmallStr::EMPTY,
880            &[None, Some(1f32), Some(5f32), Some(1f32)],
881        );
882        let test_i32 = Int32Chunked::from_slice_options(
883            PlSmallStr::EMPTY,
884            &[None, Some(1i32), Some(5i32), Some(1i32)],
885        );
886        let test_f64 = Float64Chunked::from_slice_options(
887            PlSmallStr::EMPTY,
888            &[None, Some(1f64), Some(5f64), Some(1f64)],
889        );
890        let test_i64 = Int64Chunked::from_slice_options(
891            PlSmallStr::EMPTY,
892            &[None, Some(1i64), Some(5i64), Some(1i64)],
893        );
894
895        let methods = vec![
896            QuantileMethod::Nearest,
897            QuantileMethod::Lower,
898            QuantileMethod::Higher,
899            QuantileMethod::Midpoint,
900            QuantileMethod::Linear,
901            QuantileMethod::Equiprobable,
902        ];
903
904        for method in methods {
905            assert_eq!(test_f32.quantile(0.0, method).unwrap(), test_f32.min());
906            assert_eq!(test_f32.quantile(1.0, method).unwrap(), test_f32.max());
907
908            assert_eq!(
909                test_i32.quantile(0.0, method).unwrap().unwrap(),
910                test_i32.min().unwrap() as f64
911            );
912            assert_eq!(
913                test_i32.quantile(1.0, method).unwrap().unwrap(),
914                test_i32.max().unwrap() as f64
915            );
916
917            assert_eq!(test_f64.quantile(0.0, method).unwrap(), test_f64.min());
918            assert_eq!(test_f64.quantile(1.0, method).unwrap(), test_f64.max());
919            assert_eq!(test_f64.quantile(0.5, method).unwrap(), test_f64.median());
920
921            assert_eq!(
922                test_i64.quantile(0.0, method).unwrap().unwrap(),
923                test_i64.min().unwrap() as f64
924            );
925            assert_eq!(
926                test_i64.quantile(1.0, method).unwrap().unwrap(),
927                test_i64.max().unwrap() as f64
928            );
929        }
930    }
931
932    #[test]
933    fn test_quantile() {
934        let ca = UInt32Chunked::new(
935            PlSmallStr::from_static("a"),
936            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
937        );
938
939        assert_eq!(
940            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
941            Some(1.0)
942        );
943        assert_eq!(
944            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
945            Some(5.0)
946        );
947        assert_eq!(
948            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
949            Some(3.0)
950        );
951
952        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
953        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(4.0));
954        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(3.0));
955
956        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
957        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(5.0));
958        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(4.0));
959
960        assert_eq!(
961            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
962            Some(1.5)
963        );
964        assert_eq!(
965            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
966            Some(4.5)
967        );
968        assert_eq!(
969            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
970            Some(3.5)
971        );
972
973        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.4));
974        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(4.6));
975        assert!(
976            (ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap() - 3.4).abs() < 0.0000001
977        );
978
979        assert_eq!(
980            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
981            Some(1.0)
982        );
983        assert_eq!(
984            ca.quantile(0.25, QuantileMethod::Equiprobable).unwrap(),
985            Some(2.0)
986        );
987        assert_eq!(
988            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
989            Some(3.0)
990        );
991
992        let ca = UInt32Chunked::new(
993            PlSmallStr::from_static("a"),
994            &[
995                None,
996                Some(7),
997                Some(6),
998                Some(2),
999                Some(1),
1000                None,
1001                Some(3),
1002                Some(5),
1003                None,
1004                Some(4),
1005            ],
1006        );
1007
1008        assert_eq!(
1009            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
1010            Some(2.0)
1011        );
1012        assert_eq!(
1013            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
1014            Some(6.0)
1015        );
1016        assert_eq!(
1017            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
1018            Some(5.0)
1019        );
1020
1021        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
1022        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(6.0));
1023        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(4.0));
1024
1025        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
1026        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(7.0));
1027        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(5.0));
1028
1029        assert_eq!(
1030            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
1031            Some(1.5)
1032        );
1033        assert_eq!(
1034            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
1035            Some(6.5)
1036        );
1037        assert_eq!(
1038            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
1039            Some(4.5)
1040        );
1041
1042        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.6));
1043        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(6.4));
1044        assert_eq!(ca.quantile(0.6, QuantileMethod::Linear).unwrap(), Some(4.6));
1045
1046        assert_eq!(
1047            ca.quantile(0.14, QuantileMethod::Equiprobable).unwrap(),
1048            Some(1.0)
1049        );
1050        assert_eq!(
1051            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
1052            Some(2.0)
1053        );
1054        assert_eq!(
1055            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
1056            Some(5.0)
1057        );
1058    }
1059}