polars_core/chunked_array/ops/aggregate/
mod.rs

1//! Implementations of the ChunkAgg trait.
2mod quantile;
3mod var;
4
5use arrow::types::NativeType;
6use num_traits::{Float, One, ToPrimitive, Zero};
7use polars_compute::float_sum;
8use polars_compute::min_max::MinMaxKernel;
9use polars_compute::rolling::QuantileMethod;
10use polars_compute::sum::{WrappingSum, wrapping_sum_arr};
11use polars_utils::min_max::MinMax;
12pub use quantile::*;
13pub use var::*;
14
15use super::float_sorted_arg_max::{
16    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
17};
18use crate::chunked_array::ChunkedArray;
19use crate::datatypes::{BooleanChunked, PolarsNumericType};
20use crate::prelude::*;
21use crate::series::IsSorted;
22
/// Aggregations that reduce a [`ChunkedArray`] to a single [`Scalar`]. Those can be used in
/// broadcasting operations.
///
/// Every method has a default body that panics via `unimplemented!()`, so an implementor only
/// needs to override the reductions it actually supports.
pub trait ChunkAggSeries {
    /// Get the sum of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics with `unimplemented!()`.
    fn sum_reduce(&self) -> Scalar {
        unimplemented!()
    }
    /// Get the max of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics with `unimplemented!()`.
    fn max_reduce(&self) -> Scalar {
        unimplemented!()
    }
    /// Get the min of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics with `unimplemented!()`.
    fn min_reduce(&self) -> Scalar {
        unimplemented!()
    }
    /// Get the product of the [`ChunkedArray`] as a [`Scalar`].
    ///
    /// # Panics
    /// The default implementation always panics with `unimplemented!()`.
    fn prod_reduce(&self) -> Scalar {
        unimplemented!()
    }
}
42
43fn sum<T>(array: &PrimitiveArray<T>) -> T
44where
45    T: NumericNative + NativeType + WrappingSum,
46{
47    if array.null_count() == array.len() {
48        return T::default();
49    }
50
51    if T::is_float() {
52        unsafe {
53            if T::is_f32() {
54                let f32_arr =
55                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f32>>(array);
56                let sum = float_sum::sum_arr_as_f32(f32_arr);
57                std::mem::transmute_copy::<f32, T>(&sum)
58            } else if T::is_f64() {
59                let f64_arr =
60                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f64>>(array);
61                let sum = float_sum::sum_arr_as_f64(f64_arr);
62                std::mem::transmute_copy::<f64, T>(&sum)
63            } else {
64                unreachable!("only supported float types are f32 and f64");
65            }
66        }
67    } else {
68        wrapping_sum_arr(array)
69    }
70}
71
72impl<T> ChunkAgg<T::Native> for ChunkedArray<T>
73where
74    T: PolarsNumericType,
75    T::Native: WrappingSum,
76    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
77{
78    fn sum(&self) -> Option<T::Native> {
79        Some(
80            self.downcast_iter()
81                .map(sum)
82                .fold(T::Native::zero(), |acc, v| acc + v),
83        )
84    }
85
86    fn _sum_as_f64(&self) -> f64 {
87        self.downcast_iter().map(float_sum::sum_arr_as_f64).sum()
88    }
89
90    fn min(&self) -> Option<T::Native> {
91        if self.null_count() == self.len() {
92            return None;
93        }
94
95        // There is at least one non-null value.
96
97        match self.is_sorted_flag() {
98            IsSorted::Ascending => {
99                let idx = self.first_non_null().unwrap();
100                unsafe { self.get_unchecked(idx) }
101            },
102            IsSorted::Descending => {
103                let idx = self.last_non_null().unwrap();
104                unsafe { self.get_unchecked(idx) }
105            },
106            IsSorted::Not => self
107                .downcast_iter()
108                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
109                .reduce(MinMax::min_ignore_nan),
110        }
111    }
112
113    fn max(&self) -> Option<T::Native> {
114        if self.null_count() == self.len() {
115            return None;
116        }
117        // There is at least one non-null value.
118
119        match self.is_sorted_flag() {
120            IsSorted::Ascending => {
121                let idx = if T::get_static_dtype().is_float() {
122                    float_arg_max_sorted_ascending(self)
123                } else {
124                    self.last_non_null().unwrap()
125                };
126
127                unsafe { self.get_unchecked(idx) }
128            },
129            IsSorted::Descending => {
130                let idx = if T::get_static_dtype().is_float() {
131                    float_arg_max_sorted_descending(self)
132                } else {
133                    self.first_non_null().unwrap()
134                };
135
136                unsafe { self.get_unchecked(idx) }
137            },
138            IsSorted::Not => self
139                .downcast_iter()
140                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
141                .reduce(MinMax::max_ignore_nan),
142        }
143    }
144
145    fn min_max(&self) -> Option<(T::Native, T::Native)> {
146        if self.null_count() == self.len() {
147            return None;
148        }
149        // There is at least one non-null value.
150
151        match self.is_sorted_flag() {
152            IsSorted::Ascending => {
153                let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) };
154                let max = {
155                    let idx = if T::get_static_dtype().is_float() {
156                        float_arg_max_sorted_ascending(self)
157                    } else {
158                        self.last_non_null().unwrap()
159                    };
160
161                    unsafe { self.get_unchecked(idx) }
162                };
163                min.zip(max)
164            },
165            IsSorted::Descending => {
166                let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) };
167                let max = {
168                    let idx = if T::get_static_dtype().is_float() {
169                        float_arg_max_sorted_descending(self)
170                    } else {
171                        self.first_non_null().unwrap()
172                    };
173
174                    unsafe { self.get_unchecked(idx) }
175                };
176
177                min.zip(max)
178            },
179            IsSorted::Not => self
180                .downcast_iter()
181                .filter_map(MinMaxKernel::min_max_ignore_nan_kernel)
182                .reduce(|(min1, max1), (min2, max2)| {
183                    (
184                        MinMax::min_ignore_nan(min1, min2),
185                        MinMax::max_ignore_nan(max1, max2),
186                    )
187                }),
188        }
189    }
190
191    fn mean(&self) -> Option<f64> {
192        let count = self.len() - self.null_count();
193        if count == 0 {
194            return None;
195        }
196        Some(self._sum_as_f64() / count as f64)
197    }
198}
199
200/// Booleans are cast to 1 or 0.
201impl BooleanChunked {
202    pub fn sum(&self) -> Option<IdxSize> {
203        Some(if self.is_empty() {
204            0
205        } else {
206            self.downcast_iter()
207                .map(|arr| match arr.validity() {
208                    Some(validity) => {
209                        (arr.len() - (validity & arr.values()).unset_bits()) as IdxSize
210                    },
211                    None => (arr.len() - arr.values().unset_bits()) as IdxSize,
212                })
213                .sum()
214        })
215    }
216
217    pub fn min(&self) -> Option<bool> {
218        let nc = self.null_count();
219        let len = self.len();
220        if self.is_empty() || nc == len {
221            return None;
222        }
223        if nc == 0 {
224            if self.all() { Some(true) } else { Some(false) }
225        } else {
226            // we can unwrap as we already checked empty and all null above
227            if (self.sum().unwrap() + nc as IdxSize) == len as IdxSize {
228                Some(true)
229            } else {
230                Some(false)
231            }
232        }
233    }
234
235    pub fn max(&self) -> Option<bool> {
236        if self.is_empty() || self.null_count() == self.len() {
237            return None;
238        }
239        if self.any() { Some(true) } else { Some(false) }
240    }
241    pub fn mean(&self) -> Option<f64> {
242        if self.is_empty() || self.null_count() == self.len() {
243            return None;
244        }
245        self.sum()
246            .map(|sum| sum as f64 / (self.len() - self.null_count()) as f64)
247    }
248}
249
250// Needs the same trait bounds as the implementation of ChunkedArray<T> of dyn Series.
251impl<T> ChunkAggSeries for ChunkedArray<T>
252where
253    T: PolarsNumericType,
254    T::Native: WrappingSum,
255    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
256{
257    fn sum_reduce(&self) -> Scalar {
258        let v: Option<T::Native> = self.sum();
259        Scalar::new(T::get_static_dtype(), v.into())
260    }
261
262    fn max_reduce(&self) -> Scalar {
263        let v = ChunkAgg::max(self);
264        Scalar::new(T::get_static_dtype(), v.into())
265    }
266
267    fn min_reduce(&self) -> Scalar {
268        let v = ChunkAgg::min(self);
269        Scalar::new(T::get_static_dtype(), v.into())
270    }
271
272    fn prod_reduce(&self) -> Scalar {
273        let mut prod = T::Native::one();
274
275        for arr in self.downcast_iter() {
276            for v in arr.into_iter().flatten() {
277                prod = prod * *v
278            }
279        }
280        Scalar::new(T::get_static_dtype(), prod.into())
281    }
282}
283
284impl<T> VarAggSeries for ChunkedArray<T>
285where
286    T: PolarsIntegerType,
287    ChunkedArray<T>: ChunkVar,
288{
289    fn var_reduce(&self, ddof: u8) -> Scalar {
290        let v = self.var(ddof);
291        Scalar::new(DataType::Float64, v.into())
292    }
293
294    fn std_reduce(&self, ddof: u8) -> Scalar {
295        let v = self.std(ddof);
296        Scalar::new(DataType::Float64, v.into())
297    }
298}
299
300impl VarAggSeries for Float32Chunked {
301    fn var_reduce(&self, ddof: u8) -> Scalar {
302        let v = self.var(ddof).map(|v| v as f32);
303        Scalar::new(DataType::Float32, v.into())
304    }
305
306    fn std_reduce(&self, ddof: u8) -> Scalar {
307        let v = self.std(ddof).map(|v| v as f32);
308        Scalar::new(DataType::Float32, v.into())
309    }
310}
311
312impl VarAggSeries for Float64Chunked {
313    fn var_reduce(&self, ddof: u8) -> Scalar {
314        let v = self.var(ddof);
315        Scalar::new(DataType::Float64, v.into())
316    }
317
318    fn std_reduce(&self, ddof: u8) -> Scalar {
319        let v = self.std(ddof);
320        Scalar::new(DataType::Float64, v.into())
321    }
322}
323
324impl<T> QuantileAggSeries for ChunkedArray<T>
325where
326    T: PolarsIntegerType,
327    T::Native: Ord + WrappingSum,
328{
329    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
330        let v = self.quantile(quantile, method)?;
331        Ok(Scalar::new(DataType::Float64, v.into()))
332    }
333
334    fn median_reduce(&self) -> Scalar {
335        let v = self.median();
336        Scalar::new(DataType::Float64, v.into())
337    }
338}
339
340impl QuantileAggSeries for Float32Chunked {
341    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
342        let v = self.quantile(quantile, method)?;
343        Ok(Scalar::new(DataType::Float32, v.into()))
344    }
345
346    fn median_reduce(&self) -> Scalar {
347        let v = self.median();
348        Scalar::new(DataType::Float32, v.into())
349    }
350}
351
352impl QuantileAggSeries for Float64Chunked {
353    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
354        let v = self.quantile(quantile, method)?;
355        Ok(Scalar::new(DataType::Float64, v.into()))
356    }
357
358    fn median_reduce(&self) -> Scalar {
359        let v = self.median();
360        Scalar::new(DataType::Float64, v.into())
361    }
362}
363
364impl ChunkAggSeries for BooleanChunked {
365    fn sum_reduce(&self) -> Scalar {
366        let v = self.sum();
367        Scalar::new(IDX_DTYPE, v.into())
368    }
369    fn max_reduce(&self) -> Scalar {
370        let v = self.max();
371        Scalar::new(DataType::Boolean, v.into())
372    }
373    fn min_reduce(&self) -> Scalar {
374        let v = self.min();
375        Scalar::new(DataType::Boolean, v.into())
376    }
377}
378
379impl StringChunked {
380    pub(crate) fn max_str(&self) -> Option<&str> {
381        if self.is_empty() {
382            return None;
383        }
384        match self.is_sorted_flag() {
385            IsSorted::Ascending => {
386                self.last_non_null().and_then(|idx| {
387                    // SAFETY: last_non_null returns in bound index
388                    unsafe { self.get_unchecked(idx) }
389                })
390            },
391            IsSorted::Descending => {
392                self.first_non_null().and_then(|idx| {
393                    // SAFETY: first_non_null returns in bound index
394                    unsafe { self.get_unchecked(idx) }
395                })
396            },
397            IsSorted::Not => self
398                .downcast_iter()
399                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
400                .reduce(MinMax::max_ignore_nan),
401        }
402    }
403    pub(crate) fn min_str(&self) -> Option<&str> {
404        if self.is_empty() {
405            return None;
406        }
407        match self.is_sorted_flag() {
408            IsSorted::Ascending => {
409                self.first_non_null().and_then(|idx| {
410                    // SAFETY: first_non_null returns in bound index
411                    unsafe { self.get_unchecked(idx) }
412                })
413            },
414            IsSorted::Descending => {
415                self.last_non_null().and_then(|idx| {
416                    // SAFETY: last_non_null returns in bound index
417                    unsafe { self.get_unchecked(idx) }
418                })
419            },
420            IsSorted::Not => self
421                .downcast_iter()
422                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
423                .reduce(MinMax::min_ignore_nan),
424        }
425    }
426}
427
428impl ChunkAggSeries for StringChunked {
429    fn max_reduce(&self) -> Scalar {
430        let av: AnyValue = self.max_str().into();
431        Scalar::new(DataType::String, av.into_static())
432    }
433    fn min_reduce(&self) -> Scalar {
434        let av: AnyValue = self.min_str().into();
435        Scalar::new(DataType::String, av.into_static())
436    }
437}
438
#[cfg(feature = "dtype-categorical")]
impl<T: PolarsCategoricalType> CategoricalChunked<T>
where
    ChunkedArray<T::PolarsPhysical>: ChunkAgg<T::Native>,
{
    /// Category id of the minimum value, or `None` when the array is empty or entirely null.
    ///
    /// With lexical ordering the minimum is taken over the category strings and mapped back to
    /// its id; otherwise the minimum of the physical values is used directly.
    fn min_categorical(&self) -> Option<CatSize> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if !self.uses_lexical_ordering() {
            return Some(self.physical().min()?.as_cat());
        }

        let mapping = self.get_mapping();
        let smallest = self
            .physical()
            .iter()
            .filter_map(|opt_cat| {
                let cat = opt_cat?.as_cat();
                // SAFETY: assumes every non-null physical value is a category known to
                // `mapping`.
                Some(unsafe { mapping.cat_to_str_unchecked(cat) })
            })
            .min();
        // At least one value is non-null (checked above), so a minimum exists.
        mapping.get_cat(smallest.unwrap())
    }

    /// Category id of the maximum value, or `None` when the array is empty or entirely null.
    ///
    /// With lexical ordering the maximum is taken over the category strings and mapped back to
    /// its id; otherwise the maximum of the physical values is used directly.
    fn max_categorical(&self) -> Option<CatSize> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if !self.uses_lexical_ordering() {
            return Some(self.physical().max()?.as_cat());
        }

        let mapping = self.get_mapping();
        let largest = self
            .physical()
            .iter()
            .filter_map(|opt_cat| {
                let cat = opt_cat?.as_cat();
                // SAFETY: assumes every non-null physical value is a category known to
                // `mapping`.
                Some(unsafe { mapping.cat_to_str_unchecked(cat) })
            })
            .max();
        // At least one value is non-null (checked above), so a maximum exists.
        mapping.get_cat(largest.unwrap())
    }
}
482
#[cfg(feature = "dtype-categorical")]
impl<T: PolarsCategoricalType> ChunkAggSeries for CategoricalChunked<T>
where
    ChunkedArray<T::PolarsPhysical>: ChunkAgg<T::Native>,
{
    /// Reports the minimum category as a [`Scalar`] of this array's dtype; a null scalar when
    /// there is no minimum.
    fn min_reduce(&self) -> Scalar {
        let av = match (self.min_categorical(), self.dtype()) {
            (None, _) => AnyValue::Null,
            (Some(min), DataType::Enum(_, mapping)) => AnyValue::EnumOwned(min, mapping.clone()),
            (Some(min), DataType::Categorical(_, mapping)) => {
                AnyValue::CategoricalOwned(min, mapping.clone())
            },
            (Some(_), _) => unreachable!(),
        };
        Scalar::new(self.dtype().clone(), av)
    }

    /// Reports the maximum category as a [`Scalar`] of this array's dtype; a null scalar when
    /// there is no maximum.
    fn max_reduce(&self) -> Scalar {
        let av = match (self.max_categorical(), self.dtype()) {
            (None, _) => AnyValue::Null,
            (Some(max), DataType::Enum(_, mapping)) => AnyValue::EnumOwned(max, mapping.clone()),
            (Some(max), DataType::Categorical(_, mapping)) => {
                AnyValue::CategoricalOwned(max, mapping.clone())
            },
            (Some(_), _) => unreachable!(),
        };
        Scalar::new(self.dtype().clone(), av)
    }
}
512
513impl BinaryChunked {
514    pub fn max_binary(&self) -> Option<&[u8]> {
515        if self.is_empty() {
516            return None;
517        }
518        match self.is_sorted_flag() {
519            IsSorted::Ascending => {
520                self.last_non_null().and_then(|idx| {
521                    // SAFETY: last_non_null returns in bound index.
522                    unsafe { self.get_unchecked(idx) }
523                })
524            },
525            IsSorted::Descending => {
526                self.first_non_null().and_then(|idx| {
527                    // SAFETY: first_non_null returns in bound index.
528                    unsafe { self.get_unchecked(idx) }
529                })
530            },
531            IsSorted::Not => self
532                .downcast_iter()
533                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
534                .reduce(MinMax::max_ignore_nan),
535        }
536    }
537
538    pub fn min_binary(&self) -> Option<&[u8]> {
539        if self.is_empty() {
540            return None;
541        }
542        match self.is_sorted_flag() {
543            IsSorted::Ascending => {
544                self.first_non_null().and_then(|idx| {
545                    // SAFETY: first_non_null returns in bound index.
546                    unsafe { self.get_unchecked(idx) }
547                })
548            },
549            IsSorted::Descending => {
550                self.last_non_null().and_then(|idx| {
551                    // SAFETY: last_non_null returns in bound index.
552                    unsafe { self.get_unchecked(idx) }
553                })
554            },
555            IsSorted::Not => self
556                .downcast_iter()
557                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
558                .reduce(MinMax::min_ignore_nan),
559        }
560    }
561}
562
563impl ChunkAggSeries for BinaryChunked {
564    fn sum_reduce(&self) -> Scalar {
565        unimplemented!()
566    }
567    fn max_reduce(&self) -> Scalar {
568        let av: AnyValue = self.max_binary().into();
569        Scalar::new(self.dtype().clone(), av.into_static())
570    }
571    fn min_reduce(&self) -> Scalar {
572        let av: AnyValue = self.min_binary().into();
573        Scalar::new(self.dtype().clone(), av.into_static())
574    }
575}
576
// `ObjectChunked` overrides none of the reductions, so every `ChunkAggSeries` method keeps the
// trait's default body, which panics with `unimplemented!()`.
#[cfg(feature = "object")]
impl<T: PolarsObject> ChunkAggSeries for ObjectChunked<T> {}
579
580#[cfg(test)]
581mod test {
582    use polars_compute::rolling::QuantileMethod;
583
584    use crate::prelude::*;
585
586    #[test]
587    #[cfg(not(miri))]
588    fn test_var() {
589        // Validated with numpy. Note that numpy uses ddof as an argument which
590        // influences results. The default ddof=0, we chose ddof=1, which is
591        // standard in statistics.
592        let ca1 = Int32Chunked::new(PlSmallStr::EMPTY, &[5, 8, 9, 5, 0]);
593        let ca2 = Int32Chunked::new(
594            PlSmallStr::EMPTY,
595            &[
596                Some(5),
597                None,
598                Some(8),
599                Some(9),
600                None,
601                Some(5),
602                Some(0),
603                None,
604            ],
605        );
606        for ca in &[ca1, ca2] {
607            let out = ca.var(1);
608            assert_eq!(out, Some(12.3));
609            let out = ca.std(1).unwrap();
610            assert!((3.5071355833500366 - out).abs() < 0.000000001);
611        }
612    }
613
614    #[test]
615    fn test_agg_float() {
616        let ca1 = Float32Chunked::new(PlSmallStr::from_static("a"), &[1.0, f32::NAN]);
617        let ca2 = Float32Chunked::new(PlSmallStr::from_static("b"), &[f32::NAN, 1.0]);
618        assert_eq!(ca1.min(), ca2.min());
619        let ca1 = Float64Chunked::new(PlSmallStr::from_static("a"), &[1.0, f64::NAN]);
620        let ca2 = Float64Chunked::from_slice(PlSmallStr::from_static("b"), &[f64::NAN, 1.0]);
621        assert_eq!(ca1.min(), ca2.min());
622        println!("{:?}", (ca1.min(), ca2.min()))
623    }
624
625    #[test]
626    fn test_median() {
627        let ca = UInt32Chunked::new(
628            PlSmallStr::from_static("a"),
629            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
630        );
631        assert_eq!(ca.median(), Some(3.0));
632        let ca = UInt32Chunked::new(
633            PlSmallStr::from_static("a"),
634            &[
635                None,
636                Some(7),
637                Some(6),
638                Some(2),
639                Some(1),
640                None,
641                Some(3),
642                Some(5),
643                None,
644                Some(4),
645            ],
646        );
647        assert_eq!(ca.median(), Some(4.0));
648
649        let ca = Float32Chunked::from_slice(
650            PlSmallStr::EMPTY,
651            &[
652                0.166189, 0.166559, 0.168517, 0.169393, 0.175272, 0.233167, 0.238787, 0.266562,
653                0.26903, 0.285792, 0.292801, 0.293429, 0.301706, 0.308534, 0.331489, 0.346095,
654                0.367644, 0.369939, 0.372074, 0.41014, 0.415789, 0.421781, 0.427725, 0.465363,
655                0.500208, 2.621727, 2.803311, 3.868526,
656            ],
657        );
658        assert!((ca.median().unwrap() - 0.3200115).abs() < 0.0001)
659    }
660
661    #[test]
662    fn test_mean() {
663        let ca = Float32Chunked::new(PlSmallStr::EMPTY, &[Some(1.0), Some(2.0), None]);
664        assert_eq!(ca.mean().unwrap(), 1.5);
665        assert_eq!(
666            ca.into_series()
667                .mean_reduce()
668                .unwrap()
669                .value()
670                .extract::<f32>()
671                .unwrap(),
672            1.5
673        );
674        // all null values case
675        let ca = Float32Chunked::full_null(PlSmallStr::EMPTY, 3);
676        assert_eq!(ca.mean(), None);
677        assert_eq!(
678            ca.into_series()
679                .mean_reduce()
680                .unwrap()
681                .value()
682                .extract::<f32>(),
683            None
684        );
685    }
686
687    #[test]
688    fn test_quantile_all_null() {
689        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
690        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
691        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
692        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
693
694        let methods = vec![
695            QuantileMethod::Nearest,
696            QuantileMethod::Lower,
697            QuantileMethod::Higher,
698            QuantileMethod::Midpoint,
699            QuantileMethod::Linear,
700            QuantileMethod::Equiprobable,
701        ];
702
703        for method in methods {
704            assert_eq!(test_f32.quantile(0.9, method).unwrap(), None);
705            assert_eq!(test_i32.quantile(0.9, method).unwrap(), None);
706            assert_eq!(test_f64.quantile(0.9, method).unwrap(), None);
707            assert_eq!(test_i64.quantile(0.9, method).unwrap(), None);
708        }
709    }
710
711    #[test]
712    fn test_quantile_single_value() {
713        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
714        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);
715        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
716        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);
717
718        let methods = vec![
719            QuantileMethod::Nearest,
720            QuantileMethod::Lower,
721            QuantileMethod::Higher,
722            QuantileMethod::Midpoint,
723            QuantileMethod::Linear,
724            QuantileMethod::Equiprobable,
725        ];
726
727        for method in methods {
728            assert_eq!(test_f32.quantile(0.5, method).unwrap(), Some(1.0));
729            assert_eq!(test_i32.quantile(0.5, method).unwrap(), Some(1.0));
730            assert_eq!(test_f64.quantile(0.5, method).unwrap(), Some(1.0));
731            assert_eq!(test_i64.quantile(0.5, method).unwrap(), Some(1.0));
732        }
733    }
734
735    #[test]
736    fn test_quantile_min_max() {
737        let test_f32 = Float32Chunked::from_slice_options(
738            PlSmallStr::EMPTY,
739            &[None, Some(1f32), Some(5f32), Some(1f32)],
740        );
741        let test_i32 = Int32Chunked::from_slice_options(
742            PlSmallStr::EMPTY,
743            &[None, Some(1i32), Some(5i32), Some(1i32)],
744        );
745        let test_f64 = Float64Chunked::from_slice_options(
746            PlSmallStr::EMPTY,
747            &[None, Some(1f64), Some(5f64), Some(1f64)],
748        );
749        let test_i64 = Int64Chunked::from_slice_options(
750            PlSmallStr::EMPTY,
751            &[None, Some(1i64), Some(5i64), Some(1i64)],
752        );
753
754        let methods = vec![
755            QuantileMethod::Nearest,
756            QuantileMethod::Lower,
757            QuantileMethod::Higher,
758            QuantileMethod::Midpoint,
759            QuantileMethod::Linear,
760            QuantileMethod::Equiprobable,
761        ];
762
763        for method in methods {
764            assert_eq!(test_f32.quantile(0.0, method).unwrap(), test_f32.min());
765            assert_eq!(test_f32.quantile(1.0, method).unwrap(), test_f32.max());
766
767            assert_eq!(
768                test_i32.quantile(0.0, method).unwrap().unwrap(),
769                test_i32.min().unwrap() as f64
770            );
771            assert_eq!(
772                test_i32.quantile(1.0, method).unwrap().unwrap(),
773                test_i32.max().unwrap() as f64
774            );
775
776            assert_eq!(test_f64.quantile(0.0, method).unwrap(), test_f64.min());
777            assert_eq!(test_f64.quantile(1.0, method).unwrap(), test_f64.max());
778            assert_eq!(test_f64.quantile(0.5, method).unwrap(), test_f64.median());
779
780            assert_eq!(
781                test_i64.quantile(0.0, method).unwrap().unwrap(),
782                test_i64.min().unwrap() as f64
783            );
784            assert_eq!(
785                test_i64.quantile(1.0, method).unwrap().unwrap(),
786                test_i64.max().unwrap() as f64
787            );
788        }
789    }
790
791    #[test]
792    fn test_quantile() {
793        let ca = UInt32Chunked::new(
794            PlSmallStr::from_static("a"),
795            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
796        );
797
798        assert_eq!(
799            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
800            Some(1.0)
801        );
802        assert_eq!(
803            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
804            Some(5.0)
805        );
806        assert_eq!(
807            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
808            Some(3.0)
809        );
810
811        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
812        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(4.0));
813        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(3.0));
814
815        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
816        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(5.0));
817        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(4.0));
818
819        assert_eq!(
820            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
821            Some(1.5)
822        );
823        assert_eq!(
824            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
825            Some(4.5)
826        );
827        assert_eq!(
828            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
829            Some(3.5)
830        );
831
832        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.4));
833        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(4.6));
834        assert!(
835            (ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap() - 3.4).abs() < 0.0000001
836        );
837
838        assert_eq!(
839            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
840            Some(1.0)
841        );
842        assert_eq!(
843            ca.quantile(0.25, QuantileMethod::Equiprobable).unwrap(),
844            Some(2.0)
845        );
846        assert_eq!(
847            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
848            Some(3.0)
849        );
850
851        let ca = UInt32Chunked::new(
852            PlSmallStr::from_static("a"),
853            &[
854                None,
855                Some(7),
856                Some(6),
857                Some(2),
858                Some(1),
859                None,
860                Some(3),
861                Some(5),
862                None,
863                Some(4),
864            ],
865        );
866
867        assert_eq!(
868            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
869            Some(2.0)
870        );
871        assert_eq!(
872            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
873            Some(6.0)
874        );
875        assert_eq!(
876            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
877            Some(5.0)
878        );
879
880        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
881        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(6.0));
882        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(4.0));
883
884        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
885        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(7.0));
886        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(5.0));
887
888        assert_eq!(
889            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
890            Some(1.5)
891        );
892        assert_eq!(
893            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
894            Some(6.5)
895        );
896        assert_eq!(
897            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
898            Some(4.5)
899        );
900
901        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.6));
902        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(6.4));
903        assert_eq!(ca.quantile(0.6, QuantileMethod::Linear).unwrap(), Some(4.6));
904
905        assert_eq!(
906            ca.quantile(0.14, QuantileMethod::Equiprobable).unwrap(),
907            Some(1.0)
908        );
909        assert_eq!(
910            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
911            Some(2.0)
912        );
913        assert_eq!(
914            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
915            Some(5.0)
916        );
917    }
918}