polars_core/chunked_array/ops/aggregate/
mod.rs

1//! Implementations of the ChunkAgg trait.
2mod quantile;
3mod var;
4
5use arrow::types::NativeType;
6use num_traits::{Float, One, ToPrimitive, Zero};
7use polars_compute::float_sum;
8use polars_compute::min_max::MinMaxKernel;
9use polars_compute::rolling::QuantileMethod;
10use polars_compute::sum::{WrappingSum, wrapping_sum_arr};
11use polars_utils::min_max::MinMax;
12pub use quantile::*;
13pub use var::*;
14
15use super::float_sorted_arg_max::{
16    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
17};
18use crate::chunked_array::ChunkedArray;
19use crate::datatypes::{BooleanChunked, PolarsNumericType};
20use crate::prelude::*;
21use crate::series::IsSorted;
22
23/// Aggregations that return [`Series`] of unit length. Those can be used in broadcasting operations.
24pub trait ChunkAggSeries {
25    /// Get the sum of the [`ChunkedArray`] as a new [`Series`] of length 1.
26    fn sum_reduce(&self) -> Scalar {
27        unimplemented!()
28    }
29    /// Get the max of the [`ChunkedArray`] as a new [`Series`] of length 1.
30    fn max_reduce(&self) -> Scalar {
31        unimplemented!()
32    }
33    /// Get the min of the [`ChunkedArray`] as a new [`Series`] of length 1.
34    fn min_reduce(&self) -> Scalar {
35        unimplemented!()
36    }
37    /// Get the product of the [`ChunkedArray`] as a new [`Series`] of length 1.
38    fn prod_reduce(&self) -> Scalar {
39        unimplemented!()
40    }
41}
42
43fn sum<T>(array: &PrimitiveArray<T>) -> T
44where
45    T: NumericNative + NativeType + WrappingSum,
46{
47    if array.null_count() == array.len() {
48        return T::default();
49    }
50
51    if T::is_float() {
52        unsafe {
53            if T::is_f32() {
54                let f32_arr =
55                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f32>>(array);
56                let sum = float_sum::sum_arr_as_f32(f32_arr);
57                std::mem::transmute_copy::<f32, T>(&sum)
58            } else if T::is_f64() {
59                let f64_arr =
60                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f64>>(array);
61                let sum = float_sum::sum_arr_as_f64(f64_arr);
62                std::mem::transmute_copy::<f64, T>(&sum)
63            } else {
64                unreachable!("only supported float types are f32 and f64");
65            }
66        }
67    } else {
68        wrapping_sum_arr(array)
69    }
70}
71
72impl<T> ChunkAgg<T::Native> for ChunkedArray<T>
73where
74    T: PolarsNumericType,
75    T::Native: WrappingSum,
76    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
77{
78    fn sum(&self) -> Option<T::Native> {
79        Some(
80            self.downcast_iter()
81                .map(sum)
82                .fold(T::Native::zero(), |acc, v| acc + v),
83        )
84    }
85
86    fn _sum_as_f64(&self) -> f64 {
87        self.downcast_iter().map(float_sum::sum_arr_as_f64).sum()
88    }
89
90    fn min(&self) -> Option<T::Native> {
91        if self.null_count() == self.len() {
92            return None;
93        }
94
95        // There is at least one non-null value.
96
97        match self.is_sorted_flag() {
98            IsSorted::Ascending => {
99                let idx = self.first_non_null().unwrap();
100                unsafe { self.get_unchecked(idx) }
101            },
102            IsSorted::Descending => {
103                let idx = self.last_non_null().unwrap();
104                unsafe { self.get_unchecked(idx) }
105            },
106            IsSorted::Not => self
107                .downcast_iter()
108                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
109                .reduce(MinMax::min_ignore_nan),
110        }
111    }
112
113    fn max(&self) -> Option<T::Native> {
114        if self.null_count() == self.len() {
115            return None;
116        }
117        // There is at least one non-null value.
118
119        match self.is_sorted_flag() {
120            IsSorted::Ascending => {
121                let idx = if T::get_static_dtype().is_float() {
122                    float_arg_max_sorted_ascending(self)
123                } else {
124                    self.last_non_null().unwrap()
125                };
126
127                unsafe { self.get_unchecked(idx) }
128            },
129            IsSorted::Descending => {
130                let idx = if T::get_static_dtype().is_float() {
131                    float_arg_max_sorted_descending(self)
132                } else {
133                    self.first_non_null().unwrap()
134                };
135
136                unsafe { self.get_unchecked(idx) }
137            },
138            IsSorted::Not => self
139                .downcast_iter()
140                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
141                .reduce(MinMax::max_ignore_nan),
142        }
143    }
144
145    fn min_max(&self) -> Option<(T::Native, T::Native)> {
146        if self.null_count() == self.len() {
147            return None;
148        }
149        // There is at least one non-null value.
150
151        match self.is_sorted_flag() {
152            IsSorted::Ascending => {
153                let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) };
154                let max = {
155                    let idx = if T::get_static_dtype().is_float() {
156                        float_arg_max_sorted_ascending(self)
157                    } else {
158                        self.last_non_null().unwrap()
159                    };
160
161                    unsafe { self.get_unchecked(idx) }
162                };
163                min.zip(max)
164            },
165            IsSorted::Descending => {
166                let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) };
167                let max = {
168                    let idx = if T::get_static_dtype().is_float() {
169                        float_arg_max_sorted_descending(self)
170                    } else {
171                        self.first_non_null().unwrap()
172                    };
173
174                    unsafe { self.get_unchecked(idx) }
175                };
176
177                min.zip(max)
178            },
179            IsSorted::Not => self
180                .downcast_iter()
181                .filter_map(MinMaxKernel::min_max_ignore_nan_kernel)
182                .reduce(|(min1, max1), (min2, max2)| {
183                    (
184                        MinMax::min_ignore_nan(min1, min2),
185                        MinMax::max_ignore_nan(max1, max2),
186                    )
187                }),
188        }
189    }
190
191    fn mean(&self) -> Option<f64> {
192        let count = self.len() - self.null_count();
193        if count == 0 {
194            return None;
195        }
196        Some(self._sum_as_f64() / count as f64)
197    }
198}
199
200/// Booleans are cast to 1 or 0.
201impl BooleanChunked {
202    pub fn sum(&self) -> Option<IdxSize> {
203        Some(if self.is_empty() {
204            0
205        } else {
206            self.downcast_iter()
207                .map(|arr| match arr.validity() {
208                    Some(validity) => {
209                        (arr.len() - (validity & arr.values()).unset_bits()) as IdxSize
210                    },
211                    None => (arr.len() - arr.values().unset_bits()) as IdxSize,
212                })
213                .sum()
214        })
215    }
216
217    pub fn min(&self) -> Option<bool> {
218        let nc = self.null_count();
219        let len = self.len();
220        if self.is_empty() || nc == len {
221            return None;
222        }
223        if nc == 0 {
224            if self.all() { Some(true) } else { Some(false) }
225        } else {
226            // we can unwrap as we already checked empty and all null above
227            if (self.sum().unwrap() + nc as IdxSize) == len as IdxSize {
228                Some(true)
229            } else {
230                Some(false)
231            }
232        }
233    }
234
235    pub fn max(&self) -> Option<bool> {
236        if self.is_empty() || self.null_count() == self.len() {
237            return None;
238        }
239        if self.any() { Some(true) } else { Some(false) }
240    }
241    pub fn mean(&self) -> Option<f64> {
242        if self.is_empty() || self.null_count() == self.len() {
243            return None;
244        }
245        self.sum()
246            .map(|sum| sum as f64 / (self.len() - self.null_count()) as f64)
247    }
248}
249
250// Needs the same trait bounds as the implementation of ChunkedArray<T> of dyn Series.
251impl<T> ChunkAggSeries for ChunkedArray<T>
252where
253    T: PolarsNumericType,
254    T::Native: WrappingSum,
255    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
256{
257    fn sum_reduce(&self) -> Scalar {
258        let v: Option<T::Native> = self.sum();
259        Scalar::new(T::get_static_dtype(), v.into())
260    }
261
262    fn max_reduce(&self) -> Scalar {
263        let v = ChunkAgg::max(self);
264        Scalar::new(T::get_static_dtype(), v.into())
265    }
266
267    fn min_reduce(&self) -> Scalar {
268        let v = ChunkAgg::min(self);
269        Scalar::new(T::get_static_dtype(), v.into())
270    }
271
272    fn prod_reduce(&self) -> Scalar {
273        let mut prod = T::Native::one();
274
275        for arr in self.downcast_iter() {
276            for v in arr.into_iter().flatten() {
277                prod = prod * *v
278            }
279        }
280        Scalar::new(T::get_static_dtype(), prod.into())
281    }
282}
283
284impl<T> VarAggSeries for ChunkedArray<T>
285where
286    T: PolarsIntegerType,
287    ChunkedArray<T>: ChunkVar,
288{
289    fn var_reduce(&self, ddof: u8) -> Scalar {
290        let v = self.var(ddof);
291        Scalar::new(DataType::Float64, v.into())
292    }
293
294    fn std_reduce(&self, ddof: u8) -> Scalar {
295        let v = self.std(ddof);
296        Scalar::new(DataType::Float64, v.into())
297    }
298}
299
300impl VarAggSeries for Float32Chunked {
301    fn var_reduce(&self, ddof: u8) -> Scalar {
302        let v = self.var(ddof).map(|v| v as f32);
303        Scalar::new(DataType::Float32, v.into())
304    }
305
306    fn std_reduce(&self, ddof: u8) -> Scalar {
307        let v = self.std(ddof).map(|v| v as f32);
308        Scalar::new(DataType::Float32, v.into())
309    }
310}
311
312impl VarAggSeries for Float64Chunked {
313    fn var_reduce(&self, ddof: u8) -> Scalar {
314        let v = self.var(ddof);
315        Scalar::new(DataType::Float64, v.into())
316    }
317
318    fn std_reduce(&self, ddof: u8) -> Scalar {
319        let v = self.std(ddof);
320        Scalar::new(DataType::Float64, v.into())
321    }
322}
323
324impl<T> QuantileAggSeries for ChunkedArray<T>
325where
326    T: PolarsIntegerType,
327    T::Native: Ord + WrappingSum,
328{
329    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
330        let v = self.quantile(quantile, method)?;
331        Ok(Scalar::new(DataType::Float64, v.into()))
332    }
333
334    fn median_reduce(&self) -> Scalar {
335        let v = self.median();
336        Scalar::new(DataType::Float64, v.into())
337    }
338}
339
340impl QuantileAggSeries for Float32Chunked {
341    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
342        let v = self.quantile(quantile, method)?;
343        Ok(Scalar::new(DataType::Float32, v.into()))
344    }
345
346    fn median_reduce(&self) -> Scalar {
347        let v = self.median();
348        Scalar::new(DataType::Float32, v.into())
349    }
350}
351
352impl QuantileAggSeries for Float64Chunked {
353    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
354        let v = self.quantile(quantile, method)?;
355        Ok(Scalar::new(DataType::Float64, v.into()))
356    }
357
358    fn median_reduce(&self) -> Scalar {
359        let v = self.median();
360        Scalar::new(DataType::Float64, v.into())
361    }
362}
363
364impl ChunkAggSeries for BooleanChunked {
365    fn sum_reduce(&self) -> Scalar {
366        let v = self.sum();
367        Scalar::new(IDX_DTYPE, v.into())
368    }
369    fn max_reduce(&self) -> Scalar {
370        let v = self.max();
371        Scalar::new(DataType::Boolean, v.into())
372    }
373    fn min_reduce(&self) -> Scalar {
374        let v = self.min();
375        Scalar::new(DataType::Boolean, v.into())
376    }
377}
378
379impl StringChunked {
380    pub(crate) fn max_str(&self) -> Option<&str> {
381        if self.is_empty() {
382            return None;
383        }
384        match self.is_sorted_flag() {
385            IsSorted::Ascending => {
386                self.last_non_null().and_then(|idx| {
387                    // SAFETY: last_non_null returns in bound index
388                    unsafe { self.get_unchecked(idx) }
389                })
390            },
391            IsSorted::Descending => {
392                self.first_non_null().and_then(|idx| {
393                    // SAFETY: first_non_null returns in bound index
394                    unsafe { self.get_unchecked(idx) }
395                })
396            },
397            IsSorted::Not => self
398                .downcast_iter()
399                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
400                .reduce(MinMax::max_ignore_nan),
401        }
402    }
403    pub(crate) fn min_str(&self) -> Option<&str> {
404        if self.is_empty() {
405            return None;
406        }
407        match self.is_sorted_flag() {
408            IsSorted::Ascending => {
409                self.first_non_null().and_then(|idx| {
410                    // SAFETY: first_non_null returns in bound index
411                    unsafe { self.get_unchecked(idx) }
412                })
413            },
414            IsSorted::Descending => {
415                self.last_non_null().and_then(|idx| {
416                    // SAFETY: last_non_null returns in bound index
417                    unsafe { self.get_unchecked(idx) }
418                })
419            },
420            IsSorted::Not => self
421                .downcast_iter()
422                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
423                .reduce(MinMax::min_ignore_nan),
424        }
425    }
426}
427
428impl ChunkAggSeries for StringChunked {
429    fn max_reduce(&self) -> Scalar {
430        let av: AnyValue = self.max_str().into();
431        Scalar::new(DataType::String, av.into_static())
432    }
433    fn min_reduce(&self) -> Scalar {
434        let av: AnyValue = self.min_str().into();
435        Scalar::new(DataType::String, av.into_static())
436    }
437}
438
#[cfg(feature = "dtype-categorical")]
impl<T: PolarsCategoricalType> CategoricalChunked<T>
where
    ChunkedArray<T::PolarsPhysical>: ChunkAgg<T::Native>,
{
    /// Category id of the minimum, or `None` when the array is empty or
    /// entirely null.
    ///
    /// With lexical ordering the category strings are compared; otherwise the
    /// physical values are compared.
    fn min_categorical(&self) -> Option<CatSize> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }

        if !self.uses_lexical_ordering() {
            let physical_min = self.physical().min()?;
            return Some(physical_min.as_cat());
        }

        let mapping = self.get_mapping();
        let smallest = self
            .physical()
            .iter()
            .filter_map(|opt_cat| {
                // NOTE(review): presumably sound because the ids come from this
                // array's own physical values — confirm against the mapping contract.
                opt_cat.map(|cat| unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) })
            })
            .min();
        mapping.get_cat(smallest.unwrap())
    }

    /// Category id of the maximum, or `None` when the array is empty or
    /// entirely null.
    ///
    /// With lexical ordering the category strings are compared; otherwise the
    /// physical values are compared.
    fn max_categorical(&self) -> Option<CatSize> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }

        if !self.uses_lexical_ordering() {
            let physical_max = self.physical().max()?;
            return Some(physical_max.as_cat());
        }

        let mapping = self.get_mapping();
        let largest = self
            .physical()
            .iter()
            .filter_map(|opt_cat| {
                // NOTE(review): presumably sound because the ids come from this
                // array's own physical values — confirm against the mapping contract.
                opt_cat.map(|cat| unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) })
            })
            .max();
        mapping.get_cat(largest.unwrap())
    }
}
482
#[cfg(feature = "dtype-categorical")]
impl<T: PolarsCategoricalType> ChunkAggSeries for CategoricalChunked<T>
where
    ChunkedArray<T::PolarsPhysical>: ChunkAgg<T::Native>,
{
    /// Minimum category as an owned enum/categorical value, or a null scalar
    /// when there is no minimum.
    fn min_reduce(&self) -> Scalar {
        let value = match self.min_categorical() {
            None => AnyValue::Null,
            Some(cat) => match self.dtype() {
                DataType::Enum(_, mapping) => AnyValue::EnumOwned(cat, mapping.clone()),
                DataType::Categorical(_, mapping) => {
                    AnyValue::CategoricalOwned(cat, mapping.clone())
                },
                _ => unreachable!(),
            },
        };
        Scalar::new(self.dtype().clone(), value)
    }

    /// Maximum category as an owned enum/categorical value, or a null scalar
    /// when there is no maximum.
    fn max_reduce(&self) -> Scalar {
        let value = match self.max_categorical() {
            None => AnyValue::Null,
            Some(cat) => match self.dtype() {
                DataType::Enum(_, mapping) => AnyValue::EnumOwned(cat, mapping.clone()),
                DataType::Categorical(_, mapping) => {
                    AnyValue::CategoricalOwned(cat, mapping.clone())
                },
                _ => unreachable!(),
            },
        };
        Scalar::new(self.dtype().clone(), value)
    }
}
512
513impl BinaryChunked {
514    pub fn max_binary(&self) -> Option<&[u8]> {
515        if self.is_empty() {
516            return None;
517        }
518        match self.is_sorted_flag() {
519            IsSorted::Ascending => {
520                self.last_non_null().and_then(|idx| {
521                    // SAFETY: last_non_null returns in bound index.
522                    unsafe { self.get_unchecked(idx) }
523                })
524            },
525            IsSorted::Descending => {
526                self.first_non_null().and_then(|idx| {
527                    // SAFETY: first_non_null returns in bound index.
528                    unsafe { self.get_unchecked(idx) }
529                })
530            },
531            IsSorted::Not => self
532                .downcast_iter()
533                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
534                .reduce(MinMax::max_ignore_nan),
535        }
536    }
537
538    pub fn min_binary(&self) -> Option<&[u8]> {
539        if self.is_empty() {
540            return None;
541        }
542        match self.is_sorted_flag() {
543            IsSorted::Ascending => {
544                self.first_non_null().and_then(|idx| {
545                    // SAFETY: first_non_null returns in bound index.
546                    unsafe { self.get_unchecked(idx) }
547                })
548            },
549            IsSorted::Descending => {
550                self.last_non_null().and_then(|idx| {
551                    // SAFETY: last_non_null returns in bound index.
552                    unsafe { self.get_unchecked(idx) }
553                })
554            },
555            IsSorted::Not => self
556                .downcast_iter()
557                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
558                .reduce(MinMax::min_ignore_nan),
559        }
560    }
561}
562
563impl ChunkAggSeries for BinaryChunked {
564    fn sum_reduce(&self) -> Scalar {
565        unimplemented!()
566    }
567    fn max_reduce(&self) -> Scalar {
568        let av: AnyValue = self.max_binary().into();
569        Scalar::new(self.dtype().clone(), av.into_static())
570    }
571    fn min_reduce(&self) -> Scalar {
572        let av: AnyValue = self.min_binary().into();
573        Scalar::new(self.dtype().clone(), av.into_static())
574    }
575}
576
// Object arrays override nothing, so every reduction keeps the trait's
// default body.
#[cfg(feature = "object")]
impl<T> ChunkAggSeries for ObjectChunked<T> where T: PolarsObject {}
579
#[cfg(test)]
mod test {
    use polars_compute::rolling::QuantileMethod;

    use crate::prelude::*;

    /// Every quantile method exercised by the method-sweeping tests below.
    const ALL_METHODS: [QuantileMethod; 6] = [
        QuantileMethod::Nearest,
        QuantileMethod::Lower,
        QuantileMethod::Higher,
        QuantileMethod::Midpoint,
        QuantileMethod::Linear,
        QuantileMethod::Equiprobable,
    ];

    /// Asserts that each `(quantile, method, expected)` row matches exactly.
    fn assert_quantiles(ca: &UInt32Chunked, cases: &[(f64, QuantileMethod, f64)]) {
        for &(q, method, expected) in cases {
            assert_eq!(
                ca.quantile(q, method).unwrap(),
                Some(expected),
                "quantile {q} with {method:?}"
            );
        }
    }

    #[test]
    #[cfg(not(miri))]
    fn test_var() {
        // Validated with numpy. Note that numpy uses ddof as an argument which
        // influences results. The default ddof=0, we chose ddof=1, which is
        // standard in statistics.
        let dense = Int32Chunked::new(PlSmallStr::EMPTY, &[5, 8, 9, 5, 0]);
        let with_nulls = Int32Chunked::new(
            PlSmallStr::EMPTY,
            &[
                Some(5),
                None,
                Some(8),
                Some(9),
                None,
                Some(5),
                Some(0),
                None,
            ],
        );
        for ca in [&dense, &with_nulls] {
            assert_eq!(ca.var(1), Some(12.3));
            let std_dev = ca.std(1).unwrap();
            assert!((3.5071355833500366 - std_dev).abs() < 0.000000001);
        }
    }

    #[test]
    fn test_agg_float() {
        // The position of the NaN must not change the minimum.
        let nan_last = Float32Chunked::new(PlSmallStr::from_static("a"), &[1.0, f32::NAN]);
        let nan_first = Float32Chunked::new(PlSmallStr::from_static("b"), &[f32::NAN, 1.0]);
        assert_eq!(nan_last.min(), nan_first.min());

        let nan_last = Float64Chunked::new(PlSmallStr::from_static("a"), &[1.0, f64::NAN]);
        let nan_first = Float64Chunked::from_slice(PlSmallStr::from_static("b"), &[f64::NAN, 1.0]);
        assert_eq!(nan_last.min(), nan_first.min());
        println!("{:?}", (nan_last.min(), nan_first.min()))
    }

    #[test]
    fn test_median() {
        let five_values = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
        );
        assert_eq!(five_values.median(), Some(3.0));

        let seven_values = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                None,
                Some(7),
                Some(6),
                Some(2),
                Some(1),
                None,
                Some(3),
                Some(5),
                None,
                Some(4),
            ],
        );
        assert_eq!(seven_values.median(), Some(4.0));

        let floats = Float32Chunked::from_slice(
            PlSmallStr::EMPTY,
            &[
                0.166189, 0.166559, 0.168517, 0.169393, 0.175272, 0.233167, 0.238787, 0.266562,
                0.26903, 0.285792, 0.292801, 0.293429, 0.301706, 0.308534, 0.331489, 0.346095,
                0.367644, 0.369939, 0.372074, 0.41014, 0.415789, 0.421781, 0.427725, 0.465363,
                0.500208, 2.621727, 2.803311, 3.868526,
            ],
        );
        let median = floats.median().unwrap();
        assert!((median - 0.3200115).abs() < 0.0001)
    }

    #[test]
    fn test_mean() {
        let ca = Float32Chunked::new(PlSmallStr::EMPTY, &[Some(1.0), Some(2.0), None]);
        assert_eq!(ca.mean().unwrap(), 1.5);
        let reduced = ca.into_series().mean_reduce();
        assert_eq!(reduced.value().extract::<f32>().unwrap(), 1.5);

        // all null values case
        let ca = Float32Chunked::full_null(PlSmallStr::EMPTY, 3);
        assert_eq!(ca.mean(), None);
        let reduced = ca.into_series().mean_reduce();
        assert_eq!(reduced.value().extract::<f32>(), None);
    }

    #[test]
    fn test_quantile_all_null() {
        let f32s = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let i32s = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let f64s = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let i64s = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);

        for method in ALL_METHODS {
            assert_eq!(f32s.quantile(0.9, method).unwrap(), None);
            assert_eq!(i32s.quantile(0.9, method).unwrap(), None);
            assert_eq!(f64s.quantile(0.9, method).unwrap(), None);
            assert_eq!(i64s.quantile(0.9, method).unwrap(), None);
        }
    }

    #[test]
    fn test_quantile_single_value() {
        let f32s = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
        let i32s = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);
        let f64s = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
        let i64s = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);

        for method in ALL_METHODS {
            assert_eq!(f32s.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(i32s.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(f64s.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(i64s.quantile(0.5, method).unwrap(), Some(1.0));
        }
    }

    #[test]
    fn test_quantile_min_max() {
        let f32s = Float32Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1f32), Some(5f32), Some(1f32)],
        );
        let i32s = Int32Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1i32), Some(5i32), Some(1i32)],
        );
        let f64s = Float64Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1f64), Some(5f64), Some(1f64)],
        );
        let i64s = Int64Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1i64), Some(5i64), Some(1i64)],
        );

        // Quantile 0 must agree with `min`, quantile 1 with `max`.
        for method in ALL_METHODS {
            assert_eq!(f32s.quantile(0.0, method).unwrap(), f32s.min());
            assert_eq!(f32s.quantile(1.0, method).unwrap(), f32s.max());

            assert_eq!(
                i32s.quantile(0.0, method).unwrap().unwrap(),
                i32s.min().unwrap() as f64
            );
            assert_eq!(
                i32s.quantile(1.0, method).unwrap().unwrap(),
                i32s.max().unwrap() as f64
            );

            assert_eq!(f64s.quantile(0.0, method).unwrap(), f64s.min());
            assert_eq!(f64s.quantile(1.0, method).unwrap(), f64s.max());
            assert_eq!(f64s.quantile(0.5, method).unwrap(), f64s.median());

            assert_eq!(
                i64s.quantile(0.0, method).unwrap().unwrap(),
                i64s.min().unwrap() as f64
            );
            assert_eq!(
                i64s.quantile(1.0, method).unwrap().unwrap(),
                i64s.max().unwrap() as f64
            );
        }
    }

    #[test]
    fn test_quantile() {
        let five_values = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
        );
        assert_quantiles(
            &five_values,
            &[
                (0.1, QuantileMethod::Nearest, 1.0),
                (0.9, QuantileMethod::Nearest, 5.0),
                (0.6, QuantileMethod::Nearest, 3.0),
                (0.1, QuantileMethod::Lower, 1.0),
                (0.9, QuantileMethod::Lower, 4.0),
                (0.6, QuantileMethod::Lower, 3.0),
                (0.1, QuantileMethod::Higher, 2.0),
                (0.9, QuantileMethod::Higher, 5.0),
                (0.6, QuantileMethod::Higher, 4.0),
                (0.1, QuantileMethod::Midpoint, 1.5),
                (0.9, QuantileMethod::Midpoint, 4.5),
                (0.6, QuantileMethod::Midpoint, 3.5),
                (0.1, QuantileMethod::Linear, 1.4),
                (0.9, QuantileMethod::Linear, 4.6),
                (0.15, QuantileMethod::Equiprobable, 1.0),
                (0.25, QuantileMethod::Equiprobable, 2.0),
                (0.6, QuantileMethod::Equiprobable, 3.0),
            ],
        );
        // This one is only checked to within a tolerance.
        let linear_06 = five_values
            .quantile(0.6, QuantileMethod::Linear)
            .unwrap()
            .unwrap();
        assert!((linear_06 - 3.4).abs() < 0.0000001);

        let seven_values = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                None,
                Some(7),
                Some(6),
                Some(2),
                Some(1),
                None,
                Some(3),
                Some(5),
                None,
                Some(4),
            ],
        );
        assert_quantiles(
            &seven_values,
            &[
                (0.1, QuantileMethod::Nearest, 2.0),
                (0.9, QuantileMethod::Nearest, 6.0),
                (0.6, QuantileMethod::Nearest, 5.0),
                (0.1, QuantileMethod::Lower, 1.0),
                (0.9, QuantileMethod::Lower, 6.0),
                (0.6, QuantileMethod::Lower, 4.0),
                (0.1, QuantileMethod::Higher, 2.0),
                (0.9, QuantileMethod::Higher, 7.0),
                (0.6, QuantileMethod::Higher, 5.0),
                (0.1, QuantileMethod::Midpoint, 1.5),
                (0.9, QuantileMethod::Midpoint, 6.5),
                (0.6, QuantileMethod::Midpoint, 4.5),
                (0.1, QuantileMethod::Linear, 1.6),
                (0.9, QuantileMethod::Linear, 6.4),
                (0.6, QuantileMethod::Linear, 4.6),
                (0.14, QuantileMethod::Equiprobable, 1.0),
                (0.15, QuantileMethod::Equiprobable, 2.0),
                (0.6, QuantileMethod::Equiprobable, 5.0),
            ],
        );
    }
}
911        );
912    }
913}