polars_core/chunked_array/ops/aggregate/
mod.rs

1//! Implementations of the ChunkAgg trait.
2mod quantile;
3mod var;
4
5use arrow::types::NativeType;
6use num_traits::{AsPrimitive, Float, One, ToPrimitive, Zero};
7use polars_compute::float_sum;
8use polars_compute::min_max::MinMaxKernel;
9use polars_compute::rolling::QuantileMethod;
10use polars_compute::sum::{WrappingSum, wrapping_sum_arr};
11use polars_utils::float16::pf16;
12use polars_utils::min_max::MinMax;
13pub use quantile::*;
14pub use var::*;
15
16use super::float_sorted_arg_max::{
17    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
18};
19use crate::chunked_array::ChunkedArray;
20use crate::datatypes::{BooleanChunked, PolarsNumericType};
21use crate::prelude::*;
22use crate::series::IsSorted;
23
24/// Aggregations that return [`Series`] of unit length. Those can be used in broadcasting operations.
25pub trait ChunkAggSeries {
26    /// Get the sum of the [`ChunkedArray`] as a new [`Series`] of length 1.
27    fn sum_reduce(&self) -> Scalar {
28        unimplemented!()
29    }
30    /// Get the max of the [`ChunkedArray`] as a new [`Series`] of length 1.
31    fn max_reduce(&self) -> Scalar {
32        unimplemented!()
33    }
34    /// Get the min of the [`ChunkedArray`] as a new [`Series`] of length 1.
35    fn min_reduce(&self) -> Scalar {
36        unimplemented!()
37    }
38    /// Get the product of the [`ChunkedArray`] as a new [`Series`] of length 1.
39    fn prod_reduce(&self) -> Scalar {
40        unimplemented!()
41    }
42}
43
44fn sum<T>(array: &PrimitiveArray<T>) -> T
45where
46    T: NumericNative + NativeType + WrappingSum,
47{
48    if array.null_count() == array.len() {
49        return T::default();
50    }
51
52    if T::is_float() {
53        unsafe {
54            if T::is_f16() {
55                let f16_arr =
56                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<pf16>>(array);
57                // We do not trust the numerical accuracy of f16 summation
58                let sum: pf16 = float_sum::sum_arr_as_f32(f16_arr).as_();
59                std::mem::transmute_copy::<pf16, T>(&sum)
60            } else if T::is_f32() {
61                let f32_arr =
62                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f32>>(array);
63                let sum = float_sum::sum_arr_as_f32(f32_arr);
64                std::mem::transmute_copy::<f32, T>(&sum)
65            } else if T::is_f64() {
66                let f64_arr =
67                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f64>>(array);
68                let sum = float_sum::sum_arr_as_f64(f64_arr);
69                std::mem::transmute_copy::<f64, T>(&sum)
70            } else {
71                unreachable!("only supported float types are f16, f32 and f64");
72            }
73        }
74    } else {
75        wrapping_sum_arr(array)
76    }
77}
78
79impl<T> ChunkAgg<T::Native> for ChunkedArray<T>
80where
81    T: PolarsNumericType,
82    T::Native: WrappingSum,
83    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
84{
85    fn sum(&self) -> Option<T::Native> {
86        Some(
87            self.downcast_iter()
88                .map(sum)
89                .fold(T::Native::zero(), |acc, v| acc + v),
90        )
91    }
92
93    fn _sum_as_f64(&self) -> f64 {
94        self.downcast_iter().map(float_sum::sum_arr_as_f64).sum()
95    }
96
97    fn min(&self) -> Option<T::Native> {
98        if self.null_count() == self.len() {
99            return None;
100        }
101
102        // There is at least one non-null value.
103
104        match self.is_sorted_flag() {
105            IsSorted::Ascending => {
106                let idx = self.first_non_null().unwrap();
107                unsafe { self.get_unchecked(idx) }
108            },
109            IsSorted::Descending => {
110                let idx = self.last_non_null().unwrap();
111                unsafe { self.get_unchecked(idx) }
112            },
113            IsSorted::Not => self
114                .downcast_iter()
115                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
116                .reduce(MinMax::min_ignore_nan),
117        }
118    }
119
120    fn max(&self) -> Option<T::Native> {
121        if self.null_count() == self.len() {
122            return None;
123        }
124        // There is at least one non-null value.
125
126        match self.is_sorted_flag() {
127            IsSorted::Ascending => {
128                let idx = if T::get_static_dtype().is_float() {
129                    float_arg_max_sorted_ascending(self)
130                } else {
131                    self.last_non_null().unwrap()
132                };
133
134                unsafe { self.get_unchecked(idx) }
135            },
136            IsSorted::Descending => {
137                let idx = if T::get_static_dtype().is_float() {
138                    float_arg_max_sorted_descending(self)
139                } else {
140                    self.first_non_null().unwrap()
141                };
142
143                unsafe { self.get_unchecked(idx) }
144            },
145            IsSorted::Not => self
146                .downcast_iter()
147                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
148                .reduce(MinMax::max_ignore_nan),
149        }
150    }
151
152    fn min_max(&self) -> Option<(T::Native, T::Native)> {
153        if self.null_count() == self.len() {
154            return None;
155        }
156        // There is at least one non-null value.
157
158        match self.is_sorted_flag() {
159            IsSorted::Ascending => {
160                let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) };
161                let max = {
162                    let idx = if T::get_static_dtype().is_float() {
163                        float_arg_max_sorted_ascending(self)
164                    } else {
165                        self.last_non_null().unwrap()
166                    };
167
168                    unsafe { self.get_unchecked(idx) }
169                };
170                min.zip(max)
171            },
172            IsSorted::Descending => {
173                let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) };
174                let max = {
175                    let idx = if T::get_static_dtype().is_float() {
176                        float_arg_max_sorted_descending(self)
177                    } else {
178                        self.first_non_null().unwrap()
179                    };
180
181                    unsafe { self.get_unchecked(idx) }
182                };
183
184                min.zip(max)
185            },
186            IsSorted::Not => self
187                .downcast_iter()
188                .filter_map(MinMaxKernel::min_max_ignore_nan_kernel)
189                .reduce(|(min1, max1), (min2, max2)| {
190                    (
191                        MinMax::min_ignore_nan(min1, min2),
192                        MinMax::max_ignore_nan(max1, max2),
193                    )
194                }),
195        }
196    }
197
198    fn mean(&self) -> Option<f64> {
199        let count = self.len() - self.null_count();
200        if count == 0 {
201            return None;
202        }
203        Some(self._sum_as_f64() / count as f64)
204    }
205}
206
207/// Booleans are cast to 1 or 0.
208impl BooleanChunked {
209    pub fn sum(&self) -> Option<IdxSize> {
210        Some(if self.is_empty() {
211            0
212        } else {
213            self.downcast_iter()
214                .map(|arr| match arr.validity() {
215                    Some(validity) => {
216                        (arr.len() - (validity & arr.values()).unset_bits()) as IdxSize
217                    },
218                    None => (arr.len() - arr.values().unset_bits()) as IdxSize,
219                })
220                .sum()
221        })
222    }
223
224    pub fn min(&self) -> Option<bool> {
225        let nc = self.null_count();
226        let len = self.len();
227        if self.is_empty() || nc == len {
228            return None;
229        }
230        if nc == 0 {
231            if self.all() { Some(true) } else { Some(false) }
232        } else {
233            // we can unwrap as we already checked empty and all null above
234            if (self.sum().unwrap() + nc as IdxSize) == len as IdxSize {
235                Some(true)
236            } else {
237                Some(false)
238            }
239        }
240    }
241
242    pub fn max(&self) -> Option<bool> {
243        if self.is_empty() || self.null_count() == self.len() {
244            return None;
245        }
246        if self.any() { Some(true) } else { Some(false) }
247    }
248    pub fn mean(&self) -> Option<f64> {
249        if self.is_empty() || self.null_count() == self.len() {
250            return None;
251        }
252        self.sum()
253            .map(|sum| sum as f64 / (self.len() - self.null_count()) as f64)
254    }
255}
256
257// Needs the same trait bounds as the implementation of ChunkedArray<T> of dyn Series.
258impl<T> ChunkAggSeries for ChunkedArray<T>
259where
260    T: PolarsNumericType,
261    T::Native: WrappingSum,
262    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
263{
264    fn sum_reduce(&self) -> Scalar {
265        let v: Option<T::Native> = self.sum();
266        Scalar::new(T::get_static_dtype(), v.into())
267    }
268
269    fn max_reduce(&self) -> Scalar {
270        let v = ChunkAgg::max(self);
271        Scalar::new(T::get_static_dtype(), v.into())
272    }
273
274    fn min_reduce(&self) -> Scalar {
275        let v = ChunkAgg::min(self);
276        Scalar::new(T::get_static_dtype(), v.into())
277    }
278
279    fn prod_reduce(&self) -> Scalar {
280        let mut prod = T::Native::one();
281
282        for arr in self.downcast_iter() {
283            for v in arr.into_iter().flatten() {
284                prod = prod * *v
285            }
286        }
287        Scalar::new(T::get_static_dtype(), prod.into())
288    }
289}
290
291impl<T> VarAggSeries for ChunkedArray<T>
292where
293    T: PolarsIntegerType,
294    ChunkedArray<T>: ChunkVar,
295{
296    fn var_reduce(&self, ddof: u8) -> Scalar {
297        let v = self.var(ddof);
298        Scalar::new(DataType::Float64, v.into())
299    }
300
301    fn std_reduce(&self, ddof: u8) -> Scalar {
302        let v = self.std(ddof);
303        Scalar::new(DataType::Float64, v.into())
304    }
305}
306
/// `Float16` arrays convert the computed statistic to `pf16` and keep the
/// `Float16` dtype.
#[cfg(feature = "dtype-f16")]
impl VarAggSeries for Float16Chunked {
    /// Variance with `ddof` delta degrees of freedom, as a `Float16` scalar.
    fn var_reduce(&self, ddof: u8) -> Scalar {
        let variance = self.var(ddof).map(|x| AsPrimitive::<pf16>::as_(x));
        Scalar::new(DataType::Float16, variance.into())
    }

    /// Standard deviation with `ddof` delta degrees of freedom, as a `Float16` scalar.
    fn std_reduce(&self, ddof: u8) -> Scalar {
        let deviation = self.std(ddof).map(|x| AsPrimitive::<pf16>::as_(x));
        Scalar::new(DataType::Float16, deviation.into())
    }
}
319
320impl VarAggSeries for Float32Chunked {
321    fn var_reduce(&self, ddof: u8) -> Scalar {
322        let v = self.var(ddof).map(|v| v as f32);
323        Scalar::new(DataType::Float32, v.into())
324    }
325
326    fn std_reduce(&self, ddof: u8) -> Scalar {
327        let v = self.std(ddof).map(|v| v as f32);
328        Scalar::new(DataType::Float32, v.into())
329    }
330}
331
332impl VarAggSeries for Float64Chunked {
333    fn var_reduce(&self, ddof: u8) -> Scalar {
334        let v = self.var(ddof);
335        Scalar::new(DataType::Float64, v.into())
336    }
337
338    fn std_reduce(&self, ddof: u8) -> Scalar {
339        let v = self.std(ddof);
340        Scalar::new(DataType::Float64, v.into())
341    }
342}
343
344impl<T> QuantileAggSeries for ChunkedArray<T>
345where
346    T: PolarsIntegerType,
347    T::Native: Ord + WrappingSum,
348{
349    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
350        let v = self.quantile(quantile, method)?;
351        Ok(Scalar::new(DataType::Float64, v.into()))
352    }
353
354    fn median_reduce(&self) -> Scalar {
355        let v = self.median();
356        Scalar::new(DataType::Float64, v.into())
357    }
358}
359
/// `Float16` arrays keep the `Float16` dtype for quantiles and medians.
#[cfg(feature = "dtype-f16")]
impl QuantileAggSeries for Float16Chunked {
    /// Computes the requested quantile and wraps it in a `Float16` scalar.
    ///
    /// # Errors
    /// Forwards any error returned by `quantile`.
    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
        self.quantile(quantile, method)
            .map(|q| Scalar::new(DataType::Float16, q.into()))
    }

    /// Computes the median and wraps it in a `Float16` scalar.
    fn median_reduce(&self) -> Scalar {
        Scalar::new(DataType::Float16, self.median().into())
    }
}
372
373impl QuantileAggSeries for Float32Chunked {
374    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
375        let v = self.quantile(quantile, method)?;
376        Ok(Scalar::new(DataType::Float32, v.into()))
377    }
378
379    fn median_reduce(&self) -> Scalar {
380        let v = self.median();
381        Scalar::new(DataType::Float32, v.into())
382    }
383}
384
385impl QuantileAggSeries for Float64Chunked {
386    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
387        let v = self.quantile(quantile, method)?;
388        Ok(Scalar::new(DataType::Float64, v.into()))
389    }
390
391    fn median_reduce(&self) -> Scalar {
392        let v = self.median();
393        Scalar::new(DataType::Float64, v.into())
394    }
395}
396
397impl ChunkAggSeries for BooleanChunked {
398    fn sum_reduce(&self) -> Scalar {
399        let v = self.sum();
400        Scalar::new(IDX_DTYPE, v.into())
401    }
402    fn max_reduce(&self) -> Scalar {
403        let v = self.max();
404        Scalar::new(DataType::Boolean, v.into())
405    }
406    fn min_reduce(&self) -> Scalar {
407        let v = self.min();
408        Scalar::new(DataType::Boolean, v.into())
409    }
410}
411
412impl StringChunked {
413    pub(crate) fn max_str(&self) -> Option<&str> {
414        if self.is_empty() {
415            return None;
416        }
417        match self.is_sorted_flag() {
418            IsSorted::Ascending => {
419                self.last_non_null().and_then(|idx| {
420                    // SAFETY: last_non_null returns in bound index
421                    unsafe { self.get_unchecked(idx) }
422                })
423            },
424            IsSorted::Descending => {
425                self.first_non_null().and_then(|idx| {
426                    // SAFETY: first_non_null returns in bound index
427                    unsafe { self.get_unchecked(idx) }
428                })
429            },
430            IsSorted::Not => self
431                .downcast_iter()
432                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
433                .reduce(MinMax::max_ignore_nan),
434        }
435    }
436    pub(crate) fn min_str(&self) -> Option<&str> {
437        if self.is_empty() {
438            return None;
439        }
440        match self.is_sorted_flag() {
441            IsSorted::Ascending => {
442                self.first_non_null().and_then(|idx| {
443                    // SAFETY: first_non_null returns in bound index
444                    unsafe { self.get_unchecked(idx) }
445                })
446            },
447            IsSorted::Descending => {
448                self.last_non_null().and_then(|idx| {
449                    // SAFETY: last_non_null returns in bound index
450                    unsafe { self.get_unchecked(idx) }
451                })
452            },
453            IsSorted::Not => self
454                .downcast_iter()
455                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
456                .reduce(MinMax::min_ignore_nan),
457        }
458    }
459}
460
461impl ChunkAggSeries for StringChunked {
462    fn max_reduce(&self) -> Scalar {
463        let av: AnyValue = self.max_str().into();
464        Scalar::new(DataType::String, av.into_static())
465    }
466    fn min_reduce(&self) -> Scalar {
467        let av: AnyValue = self.min_str().into();
468        Scalar::new(DataType::String, av.into_static())
469    }
470}
471
#[cfg(feature = "dtype-categorical")]
impl<T: PolarsCategoricalType> CategoricalChunked<T>
where
    ChunkedArray<T::PolarsPhysical>: ChunkAgg<T::Native>,
{
    /// Category id of the minimum, or `None` for empty or all-null input.
    ///
    /// Under lexical ordering the categories are compared by their string
    /// value; otherwise the physical ids are compared directly.
    fn min_categorical(&self) -> Option<CatSize> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if !self.uses_lexical_ordering() {
            return self.physical().min().map(|cat| cat.as_cat());
        }

        let mapping = self.get_mapping();
        // NOTE(review): the unchecked lookup presumes every physical value is a
        // valid category of `mapping` — confirm against the type's invariants.
        let smallest = self
            .physical()
            .iter()
            .flatten()
            .map(|cat| unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) })
            .min();
        // At least one non-null value exists, so `smallest` is `Some`.
        mapping.get_cat(smallest.unwrap())
    }

    /// Category id of the maximum, or `None` for empty or all-null input.
    ///
    /// Under lexical ordering the categories are compared by their string
    /// value; otherwise the physical ids are compared directly.
    fn max_categorical(&self) -> Option<CatSize> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if !self.uses_lexical_ordering() {
            return self.physical().max().map(|cat| cat.as_cat());
        }

        let mapping = self.get_mapping();
        // NOTE(review): the unchecked lookup presumes every physical value is a
        // valid category of `mapping` — confirm against the type's invariants.
        let largest = self
            .physical()
            .iter()
            .flatten()
            .map(|cat| unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) })
            .max();
        // At least one non-null value exists, so `largest` is `Some`.
        mapping.get_cat(largest.unwrap())
    }
}
515
#[cfg(feature = "dtype-categorical")]
impl<T: PolarsCategoricalType> ChunkAggSeries for CategoricalChunked<T>
where
    ChunkedArray<T::PolarsPhysical>: ChunkAgg<T::Native>,
{
    /// Minimum category as a scalar of this array's dtype; a null scalar when
    /// `min_categorical` finds nothing.
    fn min_reduce(&self) -> Scalar {
        let av = match (self.min_categorical(), self.dtype()) {
            (None, _) => AnyValue::Null,
            (Some(cat), DataType::Enum(_, mapping)) => AnyValue::EnumOwned(cat, mapping.clone()),
            (Some(cat), DataType::Categorical(_, mapping)) => {
                AnyValue::CategoricalOwned(cat, mapping.clone())
            },
            (Some(_), _) => unreachable!(),
        };
        Scalar::new(self.dtype().clone(), av)
    }

    /// Maximum category as a scalar of this array's dtype; a null scalar when
    /// `max_categorical` finds nothing.
    fn max_reduce(&self) -> Scalar {
        let av = match (self.max_categorical(), self.dtype()) {
            (None, _) => AnyValue::Null,
            (Some(cat), DataType::Enum(_, mapping)) => AnyValue::EnumOwned(cat, mapping.clone()),
            (Some(cat), DataType::Categorical(_, mapping)) => {
                AnyValue::CategoricalOwned(cat, mapping.clone())
            },
            (Some(_), _) => unreachable!(),
        };
        Scalar::new(self.dtype().clone(), av)
    }
}
545
546impl BinaryChunked {
547    pub fn max_binary(&self) -> Option<&[u8]> {
548        if self.is_empty() {
549            return None;
550        }
551        match self.is_sorted_flag() {
552            IsSorted::Ascending => {
553                self.last_non_null().and_then(|idx| {
554                    // SAFETY: last_non_null returns in bound index.
555                    unsafe { self.get_unchecked(idx) }
556                })
557            },
558            IsSorted::Descending => {
559                self.first_non_null().and_then(|idx| {
560                    // SAFETY: first_non_null returns in bound index.
561                    unsafe { self.get_unchecked(idx) }
562                })
563            },
564            IsSorted::Not => self
565                .downcast_iter()
566                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
567                .reduce(MinMax::max_ignore_nan),
568        }
569    }
570
571    pub fn min_binary(&self) -> Option<&[u8]> {
572        if self.is_empty() {
573            return None;
574        }
575        match self.is_sorted_flag() {
576            IsSorted::Ascending => {
577                self.first_non_null().and_then(|idx| {
578                    // SAFETY: first_non_null returns in bound index.
579                    unsafe { self.get_unchecked(idx) }
580                })
581            },
582            IsSorted::Descending => {
583                self.last_non_null().and_then(|idx| {
584                    // SAFETY: last_non_null returns in bound index.
585                    unsafe { self.get_unchecked(idx) }
586                })
587            },
588            IsSorted::Not => self
589                .downcast_iter()
590                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
591                .reduce(MinMax::min_ignore_nan),
592        }
593    }
594}
595
596impl ChunkAggSeries for BinaryChunked {
597    fn sum_reduce(&self) -> Scalar {
598        unimplemented!()
599    }
600    fn max_reduce(&self) -> Scalar {
601        let av: AnyValue = self.max_binary().into();
602        Scalar::new(self.dtype().clone(), av.into_static())
603    }
604    fn min_reduce(&self) -> Scalar {
605        let av: AnyValue = self.min_binary().into();
606        Scalar::new(self.dtype().clone(), av.into_static())
607    }
608}
609
/// Object arrays override nothing, so every reduction keeps the panicking default.
#[cfg(feature = "object")]
impl<T> ChunkAggSeries for ObjectChunked<T> where T: PolarsObject {}
612
#[cfg(test)]
mod test {
    use polars_compute::rolling::QuantileMethod;

    use crate::prelude::*;

    /// Every interpolation method the quantile tests iterate over.
    const ALL_METHODS: [QuantileMethod; 6] = [
        QuantileMethod::Nearest,
        QuantileMethod::Lower,
        QuantileMethod::Higher,
        QuantileMethod::Midpoint,
        QuantileMethod::Linear,
        QuantileMethod::Equiprobable,
    ];

    /// Asserts that each `(quantile, method, expected)` row holds exactly for `ca`.
    fn check_quantiles(ca: &UInt32Chunked, cases: &[(f64, QuantileMethod, f64)]) {
        for &(q, method, expected) in cases {
            assert_eq!(ca.quantile(q, method).unwrap(), Some(expected));
        }
    }

    #[test]
    #[cfg(not(miri))]
    fn test_var() {
        // Reference values come from numpy. numpy takes ddof as an argument and
        // defaults to ddof=0; these checks use ddof=1, the usual choice in
        // statistics.
        let dense = Int32Chunked::new(PlSmallStr::EMPTY, &[5, 8, 9, 5, 0]);
        let sparse = Int32Chunked::new(
            PlSmallStr::EMPTY,
            &[
                Some(5),
                None,
                Some(8),
                Some(9),
                None,
                Some(5),
                Some(0),
                None,
            ],
        );
        for ca in [&dense, &sparse] {
            assert_eq!(ca.var(1), Some(12.3));
            let std_dev = ca.std(1).unwrap();
            assert!((3.5071355833500366 - std_dev).abs() < 0.000000001);
        }
    }

    #[test]
    fn test_agg_float() {
        // The position of the NaN must not influence the minimum.
        let nan_last = Float32Chunked::new(PlSmallStr::from_static("a"), &[1.0, f32::NAN]);
        let nan_first = Float32Chunked::new(PlSmallStr::from_static("b"), &[f32::NAN, 1.0]);
        assert_eq!(nan_last.min(), nan_first.min());

        let nan_last = Float64Chunked::new(PlSmallStr::from_static("a"), &[1.0, f64::NAN]);
        let nan_first =
            Float64Chunked::from_slice(PlSmallStr::from_static("b"), &[f64::NAN, 1.0]);
        assert_eq!(nan_last.min(), nan_first.min());
        println!("{:?}", (nan_last.min(), nan_first.min()))
    }

    #[test]
    fn test_median() {
        let odd_count = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
        );
        assert_eq!(odd_count.median(), Some(3.0));

        let with_more_nulls = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                None,
                Some(7),
                Some(6),
                Some(2),
                Some(1),
                None,
                Some(3),
                Some(5),
                None,
                Some(4),
            ],
        );
        assert_eq!(with_more_nulls.median(), Some(4.0));

        let floats = Float32Chunked::from_slice(
            PlSmallStr::EMPTY,
            &[
                0.166189, 0.166559, 0.168517, 0.169393, 0.175272, 0.233167, 0.238787, 0.266562,
                0.26903, 0.285792, 0.292801, 0.293429, 0.301706, 0.308534, 0.331489, 0.346095,
                0.367644, 0.369939, 0.372074, 0.41014, 0.415789, 0.421781, 0.427725, 0.465363,
                0.500208, 2.621727, 2.803311, 3.868526,
            ],
        );
        let median = floats.median().unwrap();
        assert!((median - 0.3200115).abs() < 0.0001)
    }

    #[test]
    fn test_mean() {
        let ca = Float32Chunked::new(PlSmallStr::EMPTY, &[Some(1.0), Some(2.0), None]);
        assert_eq!(ca.mean().unwrap(), 1.5);
        let reduced = ca.into_series().mean_reduce().unwrap();
        assert_eq!(reduced.value().extract::<f32>().unwrap(), 1.5);

        // An all-null array has no mean.
        let ca = Float32Chunked::full_null(PlSmallStr::EMPTY, 3);
        assert_eq!(ca.mean(), None);
        let reduced = ca.into_series().mean_reduce().unwrap();
        assert_eq!(reduced.value().extract::<f32>(), None);
    }

    #[test]
    fn test_quantile_all_null() {
        let nulls_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let nulls_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let nulls_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let nulls_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);

        for method in ALL_METHODS {
            assert_eq!(nulls_f32.quantile(0.9, method).unwrap(), None);
            assert_eq!(nulls_i32.quantile(0.9, method).unwrap(), None);
            assert_eq!(nulls_f64.quantile(0.9, method).unwrap(), None);
            assert_eq!(nulls_i64.quantile(0.9, method).unwrap(), None);
        }
    }

    #[test]
    fn test_quantile_single_value() {
        let single_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
        let single_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);
        let single_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
        let single_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);

        for method in ALL_METHODS {
            assert_eq!(single_f32.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(single_i32.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(single_f64.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(single_i64.quantile(0.5, method).unwrap(), Some(1.0));
        }
    }

    #[test]
    fn test_quantile_min_max() {
        let ca_f32 = Float32Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1f32), Some(5f32), Some(1f32)],
        );
        let ca_i32 = Int32Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1i32), Some(5i32), Some(1i32)],
        );
        let ca_f64 = Float64Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1f64), Some(5f64), Some(1f64)],
        );
        let ca_i64 = Int64Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1i64), Some(5i64), Some(1i64)],
        );

        // The 0.0 and 1.0 quantiles must coincide with min and max for every method.
        for method in ALL_METHODS {
            assert_eq!(ca_f32.quantile(0.0, method).unwrap(), ca_f32.min());
            assert_eq!(ca_f32.quantile(1.0, method).unwrap(), ca_f32.max());

            assert_eq!(
                ca_i32.quantile(0.0, method).unwrap().unwrap(),
                ca_i32.min().unwrap() as f64
            );
            assert_eq!(
                ca_i32.quantile(1.0, method).unwrap().unwrap(),
                ca_i32.max().unwrap() as f64
            );

            assert_eq!(ca_f64.quantile(0.0, method).unwrap(), ca_f64.min());
            assert_eq!(ca_f64.quantile(1.0, method).unwrap(), ca_f64.max());
            assert_eq!(ca_f64.quantile(0.5, method).unwrap(), ca_f64.median());

            assert_eq!(
                ca_i64.quantile(0.0, method).unwrap().unwrap(),
                ca_i64.min().unwrap() as f64
            );
            assert_eq!(
                ca_i64.quantile(1.0, method).unwrap().unwrap(),
                ca_i64.max().unwrap() as f64
            );
        }
    }

    #[test]
    fn test_quantile() {
        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
        );

        check_quantiles(
            &ca,
            &[
                (0.1, QuantileMethod::Nearest, 1.0),
                (0.9, QuantileMethod::Nearest, 5.0),
                (0.6, QuantileMethod::Nearest, 3.0),
                (0.1, QuantileMethod::Lower, 1.0),
                (0.9, QuantileMethod::Lower, 4.0),
                (0.6, QuantileMethod::Lower, 3.0),
                (0.1, QuantileMethod::Higher, 2.0),
                (0.9, QuantileMethod::Higher, 5.0),
                (0.6, QuantileMethod::Higher, 4.0),
                (0.1, QuantileMethod::Midpoint, 1.5),
                (0.9, QuantileMethod::Midpoint, 4.5),
                (0.6, QuantileMethod::Midpoint, 3.5),
                (0.1, QuantileMethod::Linear, 1.4),
                (0.9, QuantileMethod::Linear, 4.6),
            ],
        );
        // This one is only checked up to a small tolerance.
        let linear_06 = ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap();
        assert!((linear_06 - 3.4).abs() < 0.0000001);
        check_quantiles(
            &ca,
            &[
                (0.15, QuantileMethod::Equiprobable, 1.0),
                (0.25, QuantileMethod::Equiprobable, 2.0),
                (0.6, QuantileMethod::Equiprobable, 3.0),
            ],
        );

        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                None,
                Some(7),
                Some(6),
                Some(2),
                Some(1),
                None,
                Some(3),
                Some(5),
                None,
                Some(4),
            ],
        );

        check_quantiles(
            &ca,
            &[
                (0.1, QuantileMethod::Nearest, 2.0),
                (0.9, QuantileMethod::Nearest, 6.0),
                (0.6, QuantileMethod::Nearest, 5.0),
                (0.1, QuantileMethod::Lower, 1.0),
                (0.9, QuantileMethod::Lower, 6.0),
                (0.6, QuantileMethod::Lower, 4.0),
                (0.1, QuantileMethod::Higher, 2.0),
                (0.9, QuantileMethod::Higher, 7.0),
                (0.6, QuantileMethod::Higher, 5.0),
                (0.1, QuantileMethod::Midpoint, 1.5),
                (0.9, QuantileMethod::Midpoint, 6.5),
                (0.6, QuantileMethod::Midpoint, 4.5),
                (0.1, QuantileMethod::Linear, 1.6),
                (0.9, QuantileMethod::Linear, 6.4),
                (0.6, QuantileMethod::Linear, 4.6),
                (0.14, QuantileMethod::Equiprobable, 1.0),
                (0.15, QuantileMethod::Equiprobable, 2.0),
                (0.6, QuantileMethod::Equiprobable, 5.0),
            ],
        );
    }
}
949        );
950    }
951}