polars_core/chunked_array/ops/aggregate/mod.rs

1//! Implementations of the ChunkAgg trait.
2mod quantile;
3mod var;
4
5use arrow::types::NativeType;
6use num_traits::{AsPrimitive, Float, One, ToPrimitive, Zero};
7use polars_compute::float_sum;
8use polars_compute::min_max::MinMaxKernel;
9use polars_compute::rolling::QuantileMethod;
10use polars_compute::sum::{WrappingSum, wrapping_sum_arr};
11use polars_utils::float16::pf16;
12use polars_utils::min_max::MinMax;
13pub use quantile::*;
14pub use var::*;
15
16use super::float_sorted_arg_max::{
17    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
18};
19use crate::chunked_array::{ChunkedArray, arg_max_binary, arg_min_binary};
20use crate::datatypes::{BooleanChunked, PolarsNumericType};
21use crate::prelude::*;
22use crate::series::IsSorted;
23
24/// Aggregations that return [`Series`] of unit length. Those can be used in broadcasting operations.
25pub trait ChunkAggSeries {
26    /// Get the sum of the [`ChunkedArray`] as a new [`Series`] of length 1.
27    fn sum_reduce(&self) -> Scalar {
28        unimplemented!()
29    }
30    /// Get the max of the [`ChunkedArray`] as a new [`Series`] of length 1.
31    fn max_reduce(&self) -> Scalar {
32        unimplemented!()
33    }
34    /// Get the min of the [`ChunkedArray`] as a new [`Series`] of length 1.
35    fn min_reduce(&self) -> Scalar {
36        unimplemented!()
37    }
38    /// Get the product of the [`ChunkedArray`] as a new [`Series`] of length 1.
39    fn prod_reduce(&self) -> Scalar {
40        unimplemented!()
41    }
42}
43
44fn sum<T>(array: &PrimitiveArray<T>) -> T
45where
46    T: NumericNative + NativeType + WrappingSum,
47{
48    if array.null_count() == array.len() {
49        return T::default();
50    }
51
52    if T::is_float() {
53        unsafe {
54            if T::is_f16() {
55                let f16_arr =
56                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<pf16>>(array);
57                // We do not trust the numerical accuracy of f16 summation
58                let sum: pf16 = float_sum::sum_arr_as_f32(f16_arr).as_();
59                std::mem::transmute_copy::<pf16, T>(&sum)
60            } else if T::is_f32() {
61                let f32_arr =
62                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f32>>(array);
63                let sum = float_sum::sum_arr_as_f32(f32_arr);
64                std::mem::transmute_copy::<f32, T>(&sum)
65            } else if T::is_f64() {
66                let f64_arr =
67                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f64>>(array);
68                let sum = float_sum::sum_arr_as_f64(f64_arr);
69                std::mem::transmute_copy::<f64, T>(&sum)
70            } else {
71                unreachable!("only supported float types are f16, f32 and f64");
72            }
73        }
74    } else {
75        wrapping_sum_arr(array)
76    }
77}
78
79impl<T> ChunkAgg<T::Native> for ChunkedArray<T>
80where
81    T: PolarsNumericType,
82    T::Native: WrappingSum,
83    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
84{
85    fn sum(&self) -> Option<T::Native> {
86        Some(
87            self.downcast_iter()
88                .map(sum)
89                .fold(T::Native::zero(), |acc, v| acc + v),
90        )
91    }
92
93    fn _sum_as_f64(&self) -> f64 {
94        self.downcast_iter().map(float_sum::sum_arr_as_f64).sum()
95    }
96
97    fn min(&self) -> Option<T::Native> {
98        if self.null_count() == self.len() {
99            return None;
100        }
101
102        // There is at least one non-null value.
103
104        match self.is_sorted_flag() {
105            IsSorted::Ascending => {
106                let idx = self.first_non_null().unwrap();
107                unsafe { self.get_unchecked(idx) }
108            },
109            IsSorted::Descending => {
110                let idx = self.last_non_null().unwrap();
111                unsafe { self.get_unchecked(idx) }
112            },
113            IsSorted::Not => self
114                .downcast_iter()
115                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
116                .reduce(MinMax::min_ignore_nan),
117        }
118    }
119
120    fn max(&self) -> Option<T::Native> {
121        if self.null_count() == self.len() {
122            return None;
123        }
124        // There is at least one non-null value.
125
126        match self.is_sorted_flag() {
127            IsSorted::Ascending => {
128                let idx = if T::get_static_dtype().is_float() {
129                    float_arg_max_sorted_ascending(self)
130                } else {
131                    self.last_non_null().unwrap()
132                };
133
134                unsafe { self.get_unchecked(idx) }
135            },
136            IsSorted::Descending => {
137                let idx = if T::get_static_dtype().is_float() {
138                    float_arg_max_sorted_descending(self)
139                } else {
140                    self.first_non_null().unwrap()
141                };
142
143                unsafe { self.get_unchecked(idx) }
144            },
145            IsSorted::Not => self
146                .downcast_iter()
147                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
148                .reduce(MinMax::max_ignore_nan),
149        }
150    }
151
152    fn min_max(&self) -> Option<(T::Native, T::Native)> {
153        if self.null_count() == self.len() {
154            return None;
155        }
156        // There is at least one non-null value.
157
158        match self.is_sorted_flag() {
159            IsSorted::Ascending => {
160                let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) };
161                let max = {
162                    let idx = if T::get_static_dtype().is_float() {
163                        float_arg_max_sorted_ascending(self)
164                    } else {
165                        self.last_non_null().unwrap()
166                    };
167
168                    unsafe { self.get_unchecked(idx) }
169                };
170                min.zip(max)
171            },
172            IsSorted::Descending => {
173                let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) };
174                let max = {
175                    let idx = if T::get_static_dtype().is_float() {
176                        float_arg_max_sorted_descending(self)
177                    } else {
178                        self.first_non_null().unwrap()
179                    };
180
181                    unsafe { self.get_unchecked(idx) }
182                };
183
184                min.zip(max)
185            },
186            IsSorted::Not => self
187                .downcast_iter()
188                .filter_map(MinMaxKernel::min_max_ignore_nan_kernel)
189                .reduce(|(min1, max1), (min2, max2)| {
190                    (
191                        MinMax::min_ignore_nan(min1, min2),
192                        MinMax::max_ignore_nan(max1, max2),
193                    )
194                }),
195        }
196    }
197
198    fn mean(&self) -> Option<f64> {
199        let count = self.len() - self.null_count();
200        if count == 0 {
201            return None;
202        }
203        Some(self._sum_as_f64() / count as f64)
204    }
205}
206
207/// Booleans are cast to 1 or 0.
208impl BooleanChunked {
209    pub fn sum(&self) -> Option<IdxSize> {
210        Some(if self.is_empty() {
211            0
212        } else {
213            self.downcast_iter()
214                .map(|arr| match arr.validity() {
215                    Some(validity) => {
216                        (arr.len() - (validity & arr.values()).unset_bits()) as IdxSize
217                    },
218                    None => (arr.len() - arr.values().unset_bits()) as IdxSize,
219                })
220                .sum()
221        })
222    }
223
224    pub fn min(&self) -> Option<bool> {
225        let nc = self.null_count();
226        let len = self.len();
227        if self.is_empty() || nc == len {
228            return None;
229        }
230        if nc == 0 {
231            if self.all() { Some(true) } else { Some(false) }
232        } else {
233            // we can unwrap as we already checked empty and all null above
234            if (self.sum().unwrap() + nc as IdxSize) == len as IdxSize {
235                Some(true)
236            } else {
237                Some(false)
238            }
239        }
240    }
241
242    pub fn max(&self) -> Option<bool> {
243        if self.is_empty() || self.null_count() == self.len() {
244            return None;
245        }
246        if self.any() { Some(true) } else { Some(false) }
247    }
248    pub fn mean(&self) -> Option<f64> {
249        if self.is_empty() || self.null_count() == self.len() {
250            return None;
251        }
252        self.sum()
253            .map(|sum| sum as f64 / (self.len() - self.null_count()) as f64)
254    }
255}
256
257// Needs the same trait bounds as the implementation of ChunkedArray<T> of dyn Series.
258impl<T> ChunkAggSeries for ChunkedArray<T>
259where
260    T: PolarsNumericType,
261    T::Native: WrappingSum,
262    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
263{
264    fn sum_reduce(&self) -> Scalar {
265        let v: Option<T::Native> = self.sum();
266        Scalar::new(T::get_static_dtype(), v.into())
267    }
268
269    fn max_reduce(&self) -> Scalar {
270        let v = ChunkAgg::max(self);
271        Scalar::new(T::get_static_dtype(), v.into())
272    }
273
274    fn min_reduce(&self) -> Scalar {
275        let v = ChunkAgg::min(self);
276        Scalar::new(T::get_static_dtype(), v.into())
277    }
278
279    fn prod_reduce(&self) -> Scalar {
280        let mut prod = T::Native::one();
281
282        for arr in self.downcast_iter() {
283            for v in arr.into_iter().flatten() {
284                prod = prod * *v
285            }
286        }
287        Scalar::new(T::get_static_dtype(), prod.into())
288    }
289}
290
291impl<T> VarAggSeries for ChunkedArray<T>
292where
293    T: PolarsIntegerType,
294    ChunkedArray<T>: ChunkVar,
295{
296    fn var_reduce(&self, ddof: u8) -> Scalar {
297        let v = self.var(ddof);
298        Scalar::new(DataType::Float64, v.into())
299    }
300
301    fn std_reduce(&self, ddof: u8) -> Scalar {
302        let v = self.std(ddof);
303        Scalar::new(DataType::Float64, v.into())
304    }
305}
306
#[cfg(feature = "dtype-f16")]
impl VarAggSeries for Float16Chunked {
    /// Variance narrowed to `f16`, returned as a `Float16` scalar.
    fn var_reduce(&self, ddof: u8) -> Scalar {
        let narrowed: Option<pf16> = self.var(ddof).map(|x| x.as_());
        Scalar::new(DataType::Float16, narrowed.into())
    }

    /// Standard deviation narrowed to `f16`, returned as a `Float16` scalar.
    fn std_reduce(&self, ddof: u8) -> Scalar {
        let narrowed: Option<pf16> = self.std(ddof).map(|x| x.as_());
        Scalar::new(DataType::Float16, narrowed.into())
    }
}
319
320impl VarAggSeries for Float32Chunked {
321    fn var_reduce(&self, ddof: u8) -> Scalar {
322        let v = self.var(ddof).map(|v| v as f32);
323        Scalar::new(DataType::Float32, v.into())
324    }
325
326    fn std_reduce(&self, ddof: u8) -> Scalar {
327        let v = self.std(ddof).map(|v| v as f32);
328        Scalar::new(DataType::Float32, v.into())
329    }
330}
331
332impl VarAggSeries for Float64Chunked {
333    fn var_reduce(&self, ddof: u8) -> Scalar {
334        let v = self.var(ddof);
335        Scalar::new(DataType::Float64, v.into())
336    }
337
338    fn std_reduce(&self, ddof: u8) -> Scalar {
339        let v = self.std(ddof);
340        Scalar::new(DataType::Float64, v.into())
341    }
342}
343
344impl<T> QuantileAggSeries for ChunkedArray<T>
345where
346    T: PolarsIntegerType,
347    T::Native: Ord + WrappingSum,
348{
349    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
350        let v = self.quantile(quantile, method)?;
351        Ok(Scalar::new(DataType::Float64, v.into()))
352    }
353
354    fn quantiles_reduce(&self, quantiles: &[f64], method: QuantileMethod) -> PolarsResult<Scalar> {
355        let v = self.quantiles(quantiles, method)?;
356        let s =
357            Float64Chunked::from_iter_options(PlSmallStr::from_static("quantiles"), v.into_iter())
358                .into_series();
359        let dtype = DataType::List(Box::new(s.dtype().clone()));
360        Ok(Scalar::new(dtype, AnyValue::List(s)))
361    }
362
363    fn median_reduce(&self) -> Scalar {
364        let v = self.median();
365        Scalar::new(DataType::Float64, v.into())
366    }
367}
368
#[cfg(feature = "dtype-f16")]
impl QuantileAggSeries for Float16Chunked {
    /// Single quantile as a `Float16` scalar; errors from `quantile` are propagated.
    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
        self.quantile(quantile, method)
            .map(|value| Scalar::new(DataType::Float16, value.into()))
    }

    /// Several quantiles packed into a list scalar whose inner series is named "quantiles".
    fn quantiles_reduce(&self, quantiles: &[f64], method: QuantileMethod) -> PolarsResult<Scalar> {
        let values = self.quantiles(quantiles, method)?;
        let inner = Float16Chunked::from_iter_options(
            PlSmallStr::from_static("quantiles"),
            values.into_iter(),
        )
        .into_series();
        let list_dtype = DataType::List(Box::new(inner.dtype().clone()));
        Ok(Scalar::new(list_dtype, AnyValue::List(inner)))
    }

    /// Median as a `Float16` scalar.
    fn median_reduce(&self) -> Scalar {
        Scalar::new(DataType::Float16, self.median().into())
    }
}
390
391impl QuantileAggSeries for Float32Chunked {
392    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
393        let v = self.quantile(quantile, method)?;
394        Ok(Scalar::new(DataType::Float32, v.into()))
395    }
396
397    fn quantiles_reduce(&self, quantiles: &[f64], method: QuantileMethod) -> PolarsResult<Scalar> {
398        let v = self.quantiles(quantiles, method)?;
399        let s =
400            Float32Chunked::from_iter_options(PlSmallStr::from_static("quantiles"), v.into_iter())
401                .into_series();
402        let dtype = DataType::List(Box::new(s.dtype().clone()));
403        Ok(Scalar::new(dtype, AnyValue::List(s)))
404    }
405
406    fn median_reduce(&self) -> Scalar {
407        let v = self.median();
408        Scalar::new(DataType::Float32, v.into())
409    }
410}
411
412impl QuantileAggSeries for Float64Chunked {
413    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
414        let v = self.quantile(quantile, method)?;
415        Ok(Scalar::new(DataType::Float64, v.into()))
416    }
417
418    fn quantiles_reduce(&self, quantiles: &[f64], method: QuantileMethod) -> PolarsResult<Scalar> {
419        let v = self.quantiles(quantiles, method)?;
420        let s =
421            Float64Chunked::from_iter_options(PlSmallStr::from_static("quantiles"), v.into_iter())
422                .into_series();
423        let dtype = DataType::List(Box::new(s.dtype().clone()));
424        Ok(Scalar::new(dtype, AnyValue::List(s)))
425    }
426
427    fn median_reduce(&self) -> Scalar {
428        let v = self.median();
429        Scalar::new(DataType::Float64, v.into())
430    }
431}
432
433impl ChunkAggSeries for BooleanChunked {
434    fn sum_reduce(&self) -> Scalar {
435        let v = self.sum();
436        Scalar::new(IDX_DTYPE, v.into())
437    }
438    fn max_reduce(&self) -> Scalar {
439        let v = self.max();
440        Scalar::new(DataType::Boolean, v.into())
441    }
442    fn min_reduce(&self) -> Scalar {
443        let v = self.min();
444        Scalar::new(DataType::Boolean, v.into())
445    }
446}
447
448impl StringChunked {
449    pub(crate) fn max_str(&self) -> Option<&str> {
450        if self.is_empty() {
451            return None;
452        }
453        match self.is_sorted_flag() {
454            IsSorted::Ascending => {
455                self.last_non_null().and_then(|idx| {
456                    // SAFETY: last_non_null returns in bound index
457                    unsafe { self.get_unchecked(idx) }
458                })
459            },
460            IsSorted::Descending => {
461                self.first_non_null().and_then(|idx| {
462                    // SAFETY: first_non_null returns in bound index
463                    unsafe { self.get_unchecked(idx) }
464                })
465            },
466            IsSorted::Not => self
467                .downcast_iter()
468                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
469                .reduce(MinMax::max_ignore_nan),
470        }
471    }
472    pub(crate) fn min_str(&self) -> Option<&str> {
473        if self.is_empty() {
474            return None;
475        }
476        match self.is_sorted_flag() {
477            IsSorted::Ascending => {
478                self.first_non_null().and_then(|idx| {
479                    // SAFETY: first_non_null returns in bound index
480                    unsafe { self.get_unchecked(idx) }
481                })
482            },
483            IsSorted::Descending => {
484                self.last_non_null().and_then(|idx| {
485                    // SAFETY: last_non_null returns in bound index
486                    unsafe { self.get_unchecked(idx) }
487                })
488            },
489            IsSorted::Not => self
490                .downcast_iter()
491                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
492                .reduce(MinMax::min_ignore_nan),
493        }
494    }
495}
496
497impl ChunkAggSeries for StringChunked {
498    fn max_reduce(&self) -> Scalar {
499        let av: AnyValue = self.max_str().into();
500        Scalar::new(DataType::String, av.into_static())
501    }
502    fn min_reduce(&self) -> Scalar {
503        let av: AnyValue = self.min_str().into();
504        Scalar::new(DataType::String, av.into_static())
505    }
506}
507
#[cfg(feature = "dtype-categorical")]
impl<T: PolarsCategoricalType> CategoricalChunked<T>
where
    ChunkedArray<T::PolarsPhysical>: ChunkAgg<T::Native>,
{
    /// Smallest category id, or `None` for an empty or all-null array.
    ///
    /// With lexical ordering the category strings are compared; otherwise the
    /// physical category ids are compared directly.
    fn min_categorical(&self) -> Option<CatSize> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if self.uses_lexical_ordering() {
            let mapping = self.get_mapping();
            let smallest = self
                .physical()
                .iter()
                .flatten()
                // SAFETY: assumes every non-null physical value is a category id present
                // in `mapping` — NOTE(review): confirm this invariant is guaranteed.
                .map(|cat| unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) })
                .min()?;
            mapping.get_cat(smallest)
        } else {
            Some(self.physical().min()?.as_cat())
        }
    }

    /// Largest category id, or `None` for an empty or all-null array.
    ///
    /// With lexical ordering the category strings are compared; otherwise the
    /// physical category ids are compared directly.
    fn max_categorical(&self) -> Option<CatSize> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if self.uses_lexical_ordering() {
            let mapping = self.get_mapping();
            let largest = self
                .physical()
                .iter()
                .flatten()
                // SAFETY: assumes every non-null physical value is a category id present
                // in `mapping` — NOTE(review): confirm this invariant is guaranteed.
                .map(|cat| unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) })
                .max()?;
            mapping.get_cat(largest)
        } else {
            Some(self.physical().max()?.as_cat())
        }
    }
}
551
#[cfg(feature = "dtype-categorical")]
impl<T: PolarsCategoricalType> ChunkAggSeries for CategoricalChunked<T>
where
    ChunkedArray<T::PolarsPhysical>: ChunkAgg<T::Native>,
{
    /// Minimum category as an Enum/Categorical scalar of this array's dtype
    /// (null when there is none).
    fn min_reduce(&self) -> Scalar {
        let dtype = self.dtype();
        let av = match (self.min_categorical(), dtype) {
            (None, _) => AnyValue::Null,
            (Some(cat), DataType::Enum(_, mapping)) => AnyValue::EnumOwned(cat, mapping.clone()),
            (Some(cat), DataType::Categorical(_, mapping)) => {
                AnyValue::CategoricalOwned(cat, mapping.clone())
            },
            (Some(_), _) => unreachable!(),
        };
        Scalar::new(dtype.clone(), av)
    }

    /// Maximum category as an Enum/Categorical scalar of this array's dtype
    /// (null when there is none).
    fn max_reduce(&self) -> Scalar {
        let dtype = self.dtype();
        let av = match (self.max_categorical(), dtype) {
            (None, _) => AnyValue::Null,
            (Some(cat), DataType::Enum(_, mapping)) => AnyValue::EnumOwned(cat, mapping.clone()),
            (Some(cat), DataType::Categorical(_, mapping)) => {
                AnyValue::CategoricalOwned(cat, mapping.clone())
            },
            (Some(_), _) => unreachable!(),
        };
        Scalar::new(dtype.clone(), av)
    }
}
581
582impl BinaryChunked {
583    pub fn max_binary(&self) -> Option<&[u8]> {
584        if self.is_empty() {
585            return None;
586        }
587        match self.is_sorted_flag() {
588            IsSorted::Ascending => {
589                self.last_non_null().and_then(|idx| {
590                    // SAFETY: last_non_null returns in bound index.
591                    unsafe { self.get_unchecked(idx) }
592                })
593            },
594            IsSorted::Descending => {
595                self.first_non_null().and_then(|idx| {
596                    // SAFETY: first_non_null returns in bound index.
597                    unsafe { self.get_unchecked(idx) }
598                })
599            },
600            IsSorted::Not => self
601                .downcast_iter()
602                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
603                .reduce(MinMax::max_ignore_nan),
604        }
605    }
606
607    pub fn min_binary(&self) -> Option<&[u8]> {
608        if self.is_empty() {
609            return None;
610        }
611        match self.is_sorted_flag() {
612            IsSorted::Ascending => {
613                self.first_non_null().and_then(|idx| {
614                    // SAFETY: first_non_null returns in bound index.
615                    unsafe { self.get_unchecked(idx) }
616                })
617            },
618            IsSorted::Descending => {
619                self.last_non_null().and_then(|idx| {
620                    // SAFETY: last_non_null returns in bound index.
621                    unsafe { self.get_unchecked(idx) }
622                })
623            },
624            IsSorted::Not => self
625                .downcast_iter()
626                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
627                .reduce(MinMax::min_ignore_nan),
628        }
629    }
630    pub fn arg_min_binary(&self) -> Option<usize> {
631        if self.is_empty() || self.null_count() == self.len() {
632            return None;
633        }
634
635        match self.is_sorted_flag() {
636            IsSorted::Ascending => self.first_non_null(),
637            IsSorted::Descending => self.last_non_null(),
638            IsSorted::Not => arg_min_binary(self),
639        }
640    }
641
642    pub fn arg_max_binary(&self) -> Option<usize> {
643        if self.is_empty() || self.null_count() == self.len() {
644            return None;
645        }
646
647        match self.is_sorted_flag() {
648            IsSorted::Ascending => self.last_non_null(),
649            IsSorted::Descending => self.first_non_null(),
650            IsSorted::Not => arg_max_binary(self),
651        }
652    }
653}
654
655impl ChunkAggSeries for BinaryChunked {
656    fn sum_reduce(&self) -> Scalar {
657        unimplemented!()
658    }
659    fn max_reduce(&self) -> Scalar {
660        let av: AnyValue = self.max_binary().into();
661        Scalar::new(self.dtype().clone(), av.into_static())
662    }
663    fn min_reduce(&self) -> Scalar {
664        let av: AnyValue = self.min_binary().into();
665        Scalar::new(self.dtype().clone(), av.into_static())
666    }
667}
668
/// Object arrays support none of the reductions; every method uses the panicking default.
#[cfg(feature = "object")]
impl<T> ChunkAggSeries for ObjectChunked<T> where T: PolarsObject {}
671
672#[cfg(test)]
673mod test {
674    use polars_compute::rolling::QuantileMethod;
675
676    use crate::prelude::*;
677
678    #[test]
679    #[cfg(not(miri))]
680    fn test_var() {
681        // Validated with numpy. Note that numpy uses ddof as an argument which
682        // influences results. The default ddof=0, we chose ddof=1, which is
683        // standard in statistics.
684        let ca1 = Int32Chunked::new(PlSmallStr::EMPTY, &[5, 8, 9, 5, 0]);
685        let ca2 = Int32Chunked::new(
686            PlSmallStr::EMPTY,
687            &[
688                Some(5),
689                None,
690                Some(8),
691                Some(9),
692                None,
693                Some(5),
694                Some(0),
695                None,
696            ],
697        );
698        for ca in &[ca1, ca2] {
699            let out = ca.var(1);
700            assert_eq!(out, Some(12.3));
701            let out = ca.std(1).unwrap();
702            assert!((3.5071355833500366 - out).abs() < 0.000000001);
703        }
704    }
705
706    #[test]
707    fn test_agg_float() {
708        let ca1 = Float32Chunked::new(PlSmallStr::from_static("a"), &[1.0, f32::NAN]);
709        let ca2 = Float32Chunked::new(PlSmallStr::from_static("b"), &[f32::NAN, 1.0]);
710        assert_eq!(ca1.min(), ca2.min());
711        let ca1 = Float64Chunked::new(PlSmallStr::from_static("a"), &[1.0, f64::NAN]);
712        let ca2 = Float64Chunked::from_slice(PlSmallStr::from_static("b"), &[f64::NAN, 1.0]);
713        assert_eq!(ca1.min(), ca2.min());
714        println!("{:?}", (ca1.min(), ca2.min()))
715    }
716
717    #[test]
718    fn test_median() {
719        let ca = UInt32Chunked::new(
720            PlSmallStr::from_static("a"),
721            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
722        );
723        assert_eq!(ca.median(), Some(3.0));
724        let ca = UInt32Chunked::new(
725            PlSmallStr::from_static("a"),
726            &[
727                None,
728                Some(7),
729                Some(6),
730                Some(2),
731                Some(1),
732                None,
733                Some(3),
734                Some(5),
735                None,
736                Some(4),
737            ],
738        );
739        assert_eq!(ca.median(), Some(4.0));
740
741        let ca = Float32Chunked::from_slice(
742            PlSmallStr::EMPTY,
743            &[
744                0.166189, 0.166559, 0.168517, 0.169393, 0.175272, 0.233167, 0.238787, 0.266562,
745                0.26903, 0.285792, 0.292801, 0.293429, 0.301706, 0.308534, 0.331489, 0.346095,
746                0.367644, 0.369939, 0.372074, 0.41014, 0.415789, 0.421781, 0.427725, 0.465363,
747                0.500208, 2.621727, 2.803311, 3.868526,
748            ],
749        );
750        assert!((ca.median().unwrap() - 0.3200115).abs() < 0.0001)
751    }
752
753    #[test]
754    fn test_mean() {
755        let ca = Float32Chunked::new(PlSmallStr::EMPTY, &[Some(1.0), Some(2.0), None]);
756        assert_eq!(ca.mean().unwrap(), 1.5);
757        assert_eq!(
758            ca.into_series()
759                .mean_reduce()
760                .unwrap()
761                .value()
762                .extract::<f32>()
763                .unwrap(),
764            1.5
765        );
766        // all null values case
767        let ca = Float32Chunked::full_null(PlSmallStr::EMPTY, 3);
768        assert_eq!(ca.mean(), None);
769        assert_eq!(
770            ca.into_series()
771                .mean_reduce()
772                .unwrap()
773                .value()
774                .extract::<f32>(),
775            None
776        );
777    }
778
779    #[test]
780    fn test_quantile_all_null() {
781        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
782        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
783        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
784        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
785
786        let methods = vec![
787            QuantileMethod::Nearest,
788            QuantileMethod::Lower,
789            QuantileMethod::Higher,
790            QuantileMethod::Midpoint,
791            QuantileMethod::Linear,
792            QuantileMethod::Equiprobable,
793        ];
794
795        for method in methods {
796            assert_eq!(test_f32.quantile(0.9, method).unwrap(), None);
797            assert_eq!(test_i32.quantile(0.9, method).unwrap(), None);
798            assert_eq!(test_f64.quantile(0.9, method).unwrap(), None);
799            assert_eq!(test_i64.quantile(0.9, method).unwrap(), None);
800        }
801    }
802
803    #[test]
804    fn test_quantile_single_value() {
805        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
806        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);
807        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
808        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);
809
810        let methods = vec![
811            QuantileMethod::Nearest,
812            QuantileMethod::Lower,
813            QuantileMethod::Higher,
814            QuantileMethod::Midpoint,
815            QuantileMethod::Linear,
816            QuantileMethod::Equiprobable,
817        ];
818
819        for method in methods {
820            assert_eq!(test_f32.quantile(0.5, method).unwrap(), Some(1.0));
821            assert_eq!(test_i32.quantile(0.5, method).unwrap(), Some(1.0));
822            assert_eq!(test_f64.quantile(0.5, method).unwrap(), Some(1.0));
823            assert_eq!(test_i64.quantile(0.5, method).unwrap(), Some(1.0));
824        }
825    }
826
827    #[test]
828    fn test_quantile_min_max() {
829        let test_f32 = Float32Chunked::from_slice_options(
830            PlSmallStr::EMPTY,
831            &[None, Some(1f32), Some(5f32), Some(1f32)],
832        );
833        let test_i32 = Int32Chunked::from_slice_options(
834            PlSmallStr::EMPTY,
835            &[None, Some(1i32), Some(5i32), Some(1i32)],
836        );
837        let test_f64 = Float64Chunked::from_slice_options(
838            PlSmallStr::EMPTY,
839            &[None, Some(1f64), Some(5f64), Some(1f64)],
840        );
841        let test_i64 = Int64Chunked::from_slice_options(
842            PlSmallStr::EMPTY,
843            &[None, Some(1i64), Some(5i64), Some(1i64)],
844        );
845
846        let methods = vec![
847            QuantileMethod::Nearest,
848            QuantileMethod::Lower,
849            QuantileMethod::Higher,
850            QuantileMethod::Midpoint,
851            QuantileMethod::Linear,
852            QuantileMethod::Equiprobable,
853        ];
854
855        for method in methods {
856            assert_eq!(test_f32.quantile(0.0, method).unwrap(), test_f32.min());
857            assert_eq!(test_f32.quantile(1.0, method).unwrap(), test_f32.max());
858
859            assert_eq!(
860                test_i32.quantile(0.0, method).unwrap().unwrap(),
861                test_i32.min().unwrap() as f64
862            );
863            assert_eq!(
864                test_i32.quantile(1.0, method).unwrap().unwrap(),
865                test_i32.max().unwrap() as f64
866            );
867
868            assert_eq!(test_f64.quantile(0.0, method).unwrap(), test_f64.min());
869            assert_eq!(test_f64.quantile(1.0, method).unwrap(), test_f64.max());
870            assert_eq!(test_f64.quantile(0.5, method).unwrap(), test_f64.median());
871
872            assert_eq!(
873                test_i64.quantile(0.0, method).unwrap().unwrap(),
874                test_i64.min().unwrap() as f64
875            );
876            assert_eq!(
877                test_i64.quantile(1.0, method).unwrap().unwrap(),
878                test_i64.max().unwrap() as f64
879            );
880        }
881    }
882
883    #[test]
884    fn test_quantile() {
885        let ca = UInt32Chunked::new(
886            PlSmallStr::from_static("a"),
887            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
888        );
889
890        assert_eq!(
891            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
892            Some(1.0)
893        );
894        assert_eq!(
895            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
896            Some(5.0)
897        );
898        assert_eq!(
899            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
900            Some(3.0)
901        );
902
903        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
904        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(4.0));
905        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(3.0));
906
907        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
908        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(5.0));
909        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(4.0));
910
911        assert_eq!(
912            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
913            Some(1.5)
914        );
915        assert_eq!(
916            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
917            Some(4.5)
918        );
919        assert_eq!(
920            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
921            Some(3.5)
922        );
923
924        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.4));
925        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(4.6));
926        assert!(
927            (ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap() - 3.4).abs() < 0.0000001
928        );
929
930        assert_eq!(
931            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
932            Some(1.0)
933        );
934        assert_eq!(
935            ca.quantile(0.25, QuantileMethod::Equiprobable).unwrap(),
936            Some(2.0)
937        );
938        assert_eq!(
939            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
940            Some(3.0)
941        );
942
943        let ca = UInt32Chunked::new(
944            PlSmallStr::from_static("a"),
945            &[
946                None,
947                Some(7),
948                Some(6),
949                Some(2),
950                Some(1),
951                None,
952                Some(3),
953                Some(5),
954                None,
955                Some(4),
956            ],
957        );
958
959        assert_eq!(
960            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
961            Some(2.0)
962        );
963        assert_eq!(
964            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
965            Some(6.0)
966        );
967        assert_eq!(
968            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
969            Some(5.0)
970        );
971
972        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
973        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(6.0));
974        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(4.0));
975
976        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
977        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(7.0));
978        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(5.0));
979
980        assert_eq!(
981            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
982            Some(1.5)
983        );
984        assert_eq!(
985            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
986            Some(6.5)
987        );
988        assert_eq!(
989            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
990            Some(4.5)
991        );
992
993        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.6));
994        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(6.4));
995        assert_eq!(ca.quantile(0.6, QuantileMethod::Linear).unwrap(), Some(4.6));
996
997        assert_eq!(
998            ca.quantile(0.14, QuantileMethod::Equiprobable).unwrap(),
999            Some(1.0)
1000        );
1001        assert_eq!(
1002            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
1003            Some(2.0)
1004        );
1005        assert_eq!(
1006            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
1007            Some(5.0)
1008        );
1009    }
1010}