polars_core/chunked_array/ops/aggregate/mod.rs

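//! Aggregation implementations for chunked arrays: the `ChunkAgg` trait, the unit-length `ChunkAggSeries`/`VarAggSeries`/`QuantileAggSeries` reductions, and inherent min/max/sum helpers for non-numeric types.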
2mod quantile;
3mod var;
4
5use arrow::types::NativeType;
6use num_traits::{Float, One, ToPrimitive, Zero};
7use polars_compute::float_sum;
8use polars_compute::min_max::MinMaxKernel;
9use polars_compute::rolling::QuantileMethod;
10use polars_compute::sum::{WrappingSum, wrapping_sum_arr};
11use polars_utils::min_max::MinMax;
12use polars_utils::sync::SyncPtr;
13pub use quantile::*;
14pub use var::*;
15
16use super::float_sorted_arg_max::{
17    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
18};
19use crate::chunked_array::ChunkedArray;
20use crate::datatypes::{BooleanChunked, PolarsNumericType};
21use crate::prelude::*;
22use crate::series::IsSorted;
23
24/// Aggregations that return [`Series`] of unit length. Those can be used in broadcasting operations.
25pub trait ChunkAggSeries {
26    /// Get the sum of the [`ChunkedArray`] as a new [`Series`] of length 1.
27    fn sum_reduce(&self) -> Scalar {
28        unimplemented!()
29    }
30    /// Get the max of the [`ChunkedArray`] as a new [`Series`] of length 1.
31    fn max_reduce(&self) -> Scalar {
32        unimplemented!()
33    }
34    /// Get the min of the [`ChunkedArray`] as a new [`Series`] of length 1.
35    fn min_reduce(&self) -> Scalar {
36        unimplemented!()
37    }
38    /// Get the product of the [`ChunkedArray`] as a new [`Series`] of length 1.
39    fn prod_reduce(&self) -> Scalar {
40        unimplemented!()
41    }
42}
43
44fn sum<T>(array: &PrimitiveArray<T>) -> T
45where
46    T: NumericNative + NativeType + WrappingSum,
47{
48    if array.null_count() == array.len() {
49        return T::default();
50    }
51
52    if T::is_float() {
53        unsafe {
54            if T::is_f32() {
55                let f32_arr =
56                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f32>>(array);
57                let sum = float_sum::sum_arr_as_f32(f32_arr);
58                std::mem::transmute_copy::<f32, T>(&sum)
59            } else if T::is_f64() {
60                let f64_arr =
61                    std::mem::transmute::<&PrimitiveArray<T>, &PrimitiveArray<f64>>(array);
62                let sum = float_sum::sum_arr_as_f64(f64_arr);
63                std::mem::transmute_copy::<f64, T>(&sum)
64            } else {
65                unreachable!("only supported float types are f32 and f64");
66            }
67        }
68    } else {
69        wrapping_sum_arr(array)
70    }
71}
72
73impl<T> ChunkAgg<T::Native> for ChunkedArray<T>
74where
75    T: PolarsNumericType,
76    T::Native: WrappingSum,
77    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
78{
79    fn sum(&self) -> Option<T::Native> {
80        Some(
81            self.downcast_iter()
82                .map(sum)
83                .fold(T::Native::zero(), |acc, v| acc + v),
84        )
85    }
86
87    fn _sum_as_f64(&self) -> f64 {
88        self.downcast_iter().map(float_sum::sum_arr_as_f64).sum()
89    }
90
91    fn min(&self) -> Option<T::Native> {
92        if self.null_count() == self.len() {
93            return None;
94        }
95
96        // There is at least one non-null value.
97
98        match self.is_sorted_flag() {
99            IsSorted::Ascending => {
100                let idx = self.first_non_null().unwrap();
101                unsafe { self.get_unchecked(idx) }
102            },
103            IsSorted::Descending => {
104                let idx = self.last_non_null().unwrap();
105                unsafe { self.get_unchecked(idx) }
106            },
107            IsSorted::Not => self
108                .downcast_iter()
109                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
110                .reduce(MinMax::min_ignore_nan),
111        }
112    }
113
114    fn max(&self) -> Option<T::Native> {
115        if self.null_count() == self.len() {
116            return None;
117        }
118        // There is at least one non-null value.
119
120        match self.is_sorted_flag() {
121            IsSorted::Ascending => {
122                let idx = if T::get_dtype().is_float() {
123                    float_arg_max_sorted_ascending(self)
124                } else {
125                    self.last_non_null().unwrap()
126                };
127
128                unsafe { self.get_unchecked(idx) }
129            },
130            IsSorted::Descending => {
131                let idx = if T::get_dtype().is_float() {
132                    float_arg_max_sorted_descending(self)
133                } else {
134                    self.first_non_null().unwrap()
135                };
136
137                unsafe { self.get_unchecked(idx) }
138            },
139            IsSorted::Not => self
140                .downcast_iter()
141                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
142                .reduce(MinMax::max_ignore_nan),
143        }
144    }
145
146    fn min_max(&self) -> Option<(T::Native, T::Native)> {
147        if self.null_count() == self.len() {
148            return None;
149        }
150        // There is at least one non-null value.
151
152        match self.is_sorted_flag() {
153            IsSorted::Ascending => {
154                let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) };
155                let max = {
156                    let idx = if T::get_dtype().is_float() {
157                        float_arg_max_sorted_ascending(self)
158                    } else {
159                        self.last_non_null().unwrap()
160                    };
161
162                    unsafe { self.get_unchecked(idx) }
163                };
164                min.zip(max)
165            },
166            IsSorted::Descending => {
167                let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) };
168                let max = {
169                    let idx = if T::get_dtype().is_float() {
170                        float_arg_max_sorted_descending(self)
171                    } else {
172                        self.first_non_null().unwrap()
173                    };
174
175                    unsafe { self.get_unchecked(idx) }
176                };
177
178                min.zip(max)
179            },
180            IsSorted::Not => self
181                .downcast_iter()
182                .filter_map(MinMaxKernel::min_max_ignore_nan_kernel)
183                .reduce(|(min1, max1), (min2, max2)| {
184                    (
185                        MinMax::min_ignore_nan(min1, min2),
186                        MinMax::max_ignore_nan(max1, max2),
187                    )
188                }),
189        }
190    }
191
192    fn mean(&self) -> Option<f64> {
193        let count = self.len() - self.null_count();
194        if count == 0 {
195            return None;
196        }
197        Some(self._sum_as_f64() / count as f64)
198    }
199}
200
201/// Booleans are cast to 1 or 0.
202impl BooleanChunked {
203    pub fn sum(&self) -> Option<IdxSize> {
204        Some(if self.is_empty() {
205            0
206        } else {
207            self.downcast_iter()
208                .map(|arr| match arr.validity() {
209                    Some(validity) => {
210                        (arr.len() - (validity & arr.values()).unset_bits()) as IdxSize
211                    },
212                    None => (arr.len() - arr.values().unset_bits()) as IdxSize,
213                })
214                .sum()
215        })
216    }
217
218    pub fn min(&self) -> Option<bool> {
219        let nc = self.null_count();
220        let len = self.len();
221        if self.is_empty() || nc == len {
222            return None;
223        }
224        if nc == 0 {
225            if self.all() { Some(true) } else { Some(false) }
226        } else {
227            // we can unwrap as we already checked empty and all null above
228            if (self.sum().unwrap() + nc as IdxSize) == len as IdxSize {
229                Some(true)
230            } else {
231                Some(false)
232            }
233        }
234    }
235
236    pub fn max(&self) -> Option<bool> {
237        if self.is_empty() || self.null_count() == self.len() {
238            return None;
239        }
240        if self.any() { Some(true) } else { Some(false) }
241    }
242    pub fn mean(&self) -> Option<f64> {
243        if self.is_empty() || self.null_count() == self.len() {
244            return None;
245        }
246        self.sum()
247            .map(|sum| sum as f64 / (self.len() - self.null_count()) as f64)
248    }
249}
250
251// Needs the same trait bounds as the implementation of ChunkedArray<T> of dyn Series.
252impl<T> ChunkAggSeries for ChunkedArray<T>
253where
254    T: PolarsNumericType,
255    T::Native: WrappingSum,
256    PrimitiveArray<T::Native>: for<'a> MinMaxKernel<Scalar<'a> = T::Native>,
257    ChunkedArray<T>: IntoSeries,
258{
259    fn sum_reduce(&self) -> Scalar {
260        let v: Option<T::Native> = self.sum();
261        Scalar::new(T::get_dtype(), v.into())
262    }
263
264    fn max_reduce(&self) -> Scalar {
265        let v = ChunkAgg::max(self);
266        Scalar::new(T::get_dtype(), v.into())
267    }
268
269    fn min_reduce(&self) -> Scalar {
270        let v = ChunkAgg::min(self);
271        Scalar::new(T::get_dtype(), v.into())
272    }
273
274    fn prod_reduce(&self) -> Scalar {
275        let mut prod = T::Native::one();
276
277        for arr in self.downcast_iter() {
278            for v in arr.into_iter().flatten() {
279                prod = prod * *v
280            }
281        }
282        Scalar::new(T::get_dtype(), prod.into())
283    }
284}
285
286impl<T> VarAggSeries for ChunkedArray<T>
287where
288    T: PolarsIntegerType,
289    ChunkedArray<T>: ChunkVar,
290{
291    fn var_reduce(&self, ddof: u8) -> Scalar {
292        let v = self.var(ddof);
293        Scalar::new(DataType::Float64, v.into())
294    }
295
296    fn std_reduce(&self, ddof: u8) -> Scalar {
297        let v = self.std(ddof);
298        Scalar::new(DataType::Float64, v.into())
299    }
300}
301
302impl VarAggSeries for Float32Chunked {
303    fn var_reduce(&self, ddof: u8) -> Scalar {
304        let v = self.var(ddof).map(|v| v as f32);
305        Scalar::new(DataType::Float32, v.into())
306    }
307
308    fn std_reduce(&self, ddof: u8) -> Scalar {
309        let v = self.std(ddof).map(|v| v as f32);
310        Scalar::new(DataType::Float32, v.into())
311    }
312}
313
314impl VarAggSeries for Float64Chunked {
315    fn var_reduce(&self, ddof: u8) -> Scalar {
316        let v = self.var(ddof);
317        Scalar::new(DataType::Float64, v.into())
318    }
319
320    fn std_reduce(&self, ddof: u8) -> Scalar {
321        let v = self.std(ddof);
322        Scalar::new(DataType::Float64, v.into())
323    }
324}
325
326impl<T> QuantileAggSeries for ChunkedArray<T>
327where
328    T: PolarsIntegerType,
329    T::Native: Ord + WrappingSum,
330{
331    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
332        let v = self.quantile(quantile, method)?;
333        Ok(Scalar::new(DataType::Float64, v.into()))
334    }
335
336    fn median_reduce(&self) -> Scalar {
337        let v = self.median();
338        Scalar::new(DataType::Float64, v.into())
339    }
340}
341
342impl QuantileAggSeries for Float32Chunked {
343    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
344        let v = self.quantile(quantile, method)?;
345        Ok(Scalar::new(DataType::Float32, v.into()))
346    }
347
348    fn median_reduce(&self) -> Scalar {
349        let v = self.median();
350        Scalar::new(DataType::Float32, v.into())
351    }
352}
353
354impl QuantileAggSeries for Float64Chunked {
355    fn quantile_reduce(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<Scalar> {
356        let v = self.quantile(quantile, method)?;
357        Ok(Scalar::new(DataType::Float64, v.into()))
358    }
359
360    fn median_reduce(&self) -> Scalar {
361        let v = self.median();
362        Scalar::new(DataType::Float64, v.into())
363    }
364}
365
366impl ChunkAggSeries for BooleanChunked {
367    fn sum_reduce(&self) -> Scalar {
368        let v = self.sum();
369        Scalar::new(IDX_DTYPE, v.into())
370    }
371    fn max_reduce(&self) -> Scalar {
372        let v = self.max();
373        Scalar::new(DataType::Boolean, v.into())
374    }
375    fn min_reduce(&self) -> Scalar {
376        let v = self.min();
377        Scalar::new(DataType::Boolean, v.into())
378    }
379}
380
381impl StringChunked {
382    pub(crate) fn max_str(&self) -> Option<&str> {
383        if self.is_empty() {
384            return None;
385        }
386        match self.is_sorted_flag() {
387            IsSorted::Ascending => {
388                self.last_non_null().and_then(|idx| {
389                    // SAFETY: last_non_null returns in bound index
390                    unsafe { self.get_unchecked(idx) }
391                })
392            },
393            IsSorted::Descending => {
394                self.first_non_null().and_then(|idx| {
395                    // SAFETY: first_non_null returns in bound index
396                    unsafe { self.get_unchecked(idx) }
397                })
398            },
399            IsSorted::Not => self
400                .downcast_iter()
401                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
402                .reduce(MinMax::max_ignore_nan),
403        }
404    }
405    pub(crate) fn min_str(&self) -> Option<&str> {
406        if self.is_empty() {
407            return None;
408        }
409        match self.is_sorted_flag() {
410            IsSorted::Ascending => {
411                self.first_non_null().and_then(|idx| {
412                    // SAFETY: first_non_null returns in bound index
413                    unsafe { self.get_unchecked(idx) }
414                })
415            },
416            IsSorted::Descending => {
417                self.last_non_null().and_then(|idx| {
418                    // SAFETY: last_non_null returns in bound index
419                    unsafe { self.get_unchecked(idx) }
420                })
421            },
422            IsSorted::Not => self
423                .downcast_iter()
424                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
425                .reduce(MinMax::min_ignore_nan),
426        }
427    }
428}
429
430impl ChunkAggSeries for StringChunked {
431    fn max_reduce(&self) -> Scalar {
432        let av: AnyValue = self.max_str().into();
433        Scalar::new(DataType::String, av.into_static())
434    }
435    fn min_reduce(&self) -> Scalar {
436        let av: AnyValue = self.min_str().into();
437        Scalar::new(DataType::String, av.into_static())
438    }
439}
440
#[cfg(feature = "dtype-categorical")]
impl CategoricalChunked {
    /// Physical code of the smallest category, or `None` for empty/all-null input.
    ///
    /// With lexical ordering the comparison is done on the category strings and the
    /// winner is mapped back to its code; otherwise the physical codes are compared.
    fn min_categorical(&self) -> Option<u32> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if !self.uses_lexical_ordering() {
            return self.physical().min();
        }

        let rev_map = self.get_rev_map();
        let smallest = if self._can_fast_unique() {
            // Fast path where all categories are used.
            rev_map.get_categories().min_ignore_nan_kernel()
        } else {
            self.physical()
                .iter()
                .flatten()
                // SAFETY: indices are in bounds.
                .map(|code| unsafe { rev_map.get_unchecked(code) })
                .min()
        };
        rev_map.find(smallest.unwrap())
    }

    /// Physical code of the largest category, or `None` for empty/all-null input.
    ///
    /// With lexical ordering the comparison is done on the category strings and the
    /// winner is mapped back to its code; otherwise the physical codes are compared.
    fn max_categorical(&self) -> Option<u32> {
        if self.is_empty() || self.null_count() == self.len() {
            return None;
        }
        if !self.uses_lexical_ordering() {
            return self.physical().max();
        }

        let rev_map = self.get_rev_map();
        let largest = if self._can_fast_unique() {
            // Fast path where all categories are used.
            rev_map.get_categories().max_ignore_nan_kernel()
        } else {
            self.physical()
                .iter()
                .flatten()
                // SAFETY: indices are in bounds.
                .map(|code| unsafe { rev_map.get_unchecked(code) })
                .max()
        };
        rev_map.find(largest.unwrap())
    }
}
#[cfg(feature = "dtype-categorical")]
impl ChunkAggSeries for CategoricalChunked {
    /// Minimum category as a scalar of the column's own dtype.
    ///
    /// Enums compare by physical code; categoricals use `min_categorical`.
    fn min_reduce(&self) -> Scalar {
        let code = match self.dtype() {
            DataType::Enum(_, _) => self.physical().min(),
            DataType::Categorical(_, _) => self.min_categorical(),
            _ => unreachable!(),
        };
        rev_mapped_code_to_scalar(self.dtype(), code)
    }

    /// Maximum category as a scalar of the column's own dtype.
    ///
    /// Enums compare by physical code; categoricals use `max_categorical`.
    fn max_reduce(&self) -> Scalar {
        let code = match self.dtype() {
            DataType::Enum(_, _) => self.physical().max(),
            DataType::Categorical(_, _) => self.max_categorical(),
            _ => unreachable!(),
        };
        rev_mapped_code_to_scalar(self.dtype(), code)
    }
}

/// Wrap an optional physical category code into a [`Scalar`] of `dtype`.
///
/// `None` becomes a null scalar. `Some(code)` becomes an owned enum/categorical
/// value carrying the dtype's rev-map and a pointer to its category array.
#[cfg(feature = "dtype-categorical")]
fn rev_mapped_code_to_scalar(dtype: &DataType, code: Option<u32>) -> Scalar {
    let Some(code) = code else {
        return Scalar::new(dtype.clone(), AnyValue::Null);
    };
    match dtype {
        DataType::Enum(r, _) => {
            let rev_map = r.as_ref().unwrap();
            let RevMapping::Local(arr, _) = &**rev_map else {
                unreachable!()
            };
            Scalar::new(
                dtype.clone(),
                AnyValue::EnumOwned(code, rev_map.clone(), SyncPtr::from_const(arr as *const _)),
            )
        },
        DataType::Categorical(r, _) => {
            let rev_map = r.as_ref().unwrap();
            let arr = match &**rev_map {
                RevMapping::Local(arr, _) => arr,
                RevMapping::Global(_, arr, _) => arr,
            };
            Scalar::new(
                dtype.clone(),
                AnyValue::CategoricalOwned(
                    code,
                    rev_map.clone(),
                    SyncPtr::from_const(arr as *const _),
                ),
            )
        },
        _ => unreachable!(),
    }
}
575
576impl BinaryChunked {
577    pub fn max_binary(&self) -> Option<&[u8]> {
578        if self.is_empty() {
579            return None;
580        }
581        match self.is_sorted_flag() {
582            IsSorted::Ascending => {
583                self.last_non_null().and_then(|idx| {
584                    // SAFETY: last_non_null returns in bound index.
585                    unsafe { self.get_unchecked(idx) }
586                })
587            },
588            IsSorted::Descending => {
589                self.first_non_null().and_then(|idx| {
590                    // SAFETY: first_non_null returns in bound index.
591                    unsafe { self.get_unchecked(idx) }
592                })
593            },
594            IsSorted::Not => self
595                .downcast_iter()
596                .filter_map(MinMaxKernel::max_ignore_nan_kernel)
597                .reduce(MinMax::max_ignore_nan),
598        }
599    }
600
601    pub fn min_binary(&self) -> Option<&[u8]> {
602        if self.is_empty() {
603            return None;
604        }
605        match self.is_sorted_flag() {
606            IsSorted::Ascending => {
607                self.first_non_null().and_then(|idx| {
608                    // SAFETY: first_non_null returns in bound index.
609                    unsafe { self.get_unchecked(idx) }
610                })
611            },
612            IsSorted::Descending => {
613                self.last_non_null().and_then(|idx| {
614                    // SAFETY: last_non_null returns in bound index.
615                    unsafe { self.get_unchecked(idx) }
616                })
617            },
618            IsSorted::Not => self
619                .downcast_iter()
620                .filter_map(MinMaxKernel::min_ignore_nan_kernel)
621                .reduce(MinMax::min_ignore_nan),
622        }
623    }
624}
625
626impl ChunkAggSeries for BinaryChunked {
627    fn sum_reduce(&self) -> Scalar {
628        unimplemented!()
629    }
630    fn max_reduce(&self) -> Scalar {
631        let av: AnyValue = self.max_binary().into();
632        Scalar::new(self.dtype().clone(), av.into_static())
633    }
634    fn min_reduce(&self) -> Scalar {
635        let av: AnyValue = self.min_binary().into();
636        Scalar::new(self.dtype().clone(), av.into_static())
637    }
638}
639
// Object columns override none of the reductions, so every method keeps the
// trait's panicking default.
#[cfg(feature = "object")]
impl<T> ChunkAggSeries for ObjectChunked<T> where T: PolarsObject {}
642
#[cfg(test)]
mod test {
    use polars_compute::rolling::QuantileMethod;

    use crate::prelude::*;

    /// Every quantile interpolation method; used by tests that must hold for all of them.
    const ALL_METHODS: [QuantileMethod; 6] = [
        QuantileMethod::Nearest,
        QuantileMethod::Lower,
        QuantileMethod::Higher,
        QuantileMethod::Midpoint,
        QuantileMethod::Linear,
        QuantileMethod::Equiprobable,
    ];

    #[test]
    #[cfg(not(miri))]
    fn test_var() {
        // Validated with numpy. Note that numpy uses ddof as an argument which
        // influences results. The default ddof=0, we chose ddof=1, which is
        // standard in statistics.
        let ca1 = Int32Chunked::new(PlSmallStr::EMPTY, &[5, 8, 9, 5, 0]);
        let ca2 = Int32Chunked::new(
            PlSmallStr::EMPTY,
            &[
                Some(5),
                None,
                Some(8),
                Some(9),
                None,
                Some(5),
                Some(0),
                None,
            ],
        );
        for ca in &[ca1, ca2] {
            let out = ca.var(1);
            assert_eq!(out, Some(12.3));
            let out = ca.std(1).unwrap();
            assert!((3.5071355833500366 - out).abs() < 0.000000001);
        }
    }

    #[test]
    fn test_agg_float() {
        // The position of a NaN must not change the (NaN-ignoring) minimum.
        let ca1 = Float32Chunked::new(PlSmallStr::from_static("a"), &[1.0, f32::NAN]);
        let ca2 = Float32Chunked::new(PlSmallStr::from_static("b"), &[f32::NAN, 1.0]);
        assert_eq!(ca1.min(), ca2.min());
        let ca1 = Float64Chunked::new(PlSmallStr::from_static("a"), &[1.0, f64::NAN]);
        let ca2 = Float64Chunked::from_slice(PlSmallStr::from_static("b"), &[f64::NAN, 1.0]);
        assert_eq!(ca1.min(), ca2.min());
    }

    #[test]
    fn test_median() {
        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
        );
        assert_eq!(ca.median(), Some(3.0));
        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                None,
                Some(7),
                Some(6),
                Some(2),
                Some(1),
                None,
                Some(3),
                Some(5),
                None,
                Some(4),
            ],
        );
        assert_eq!(ca.median(), Some(4.0));

        let ca = Float32Chunked::from_slice(
            PlSmallStr::EMPTY,
            &[
                0.166189, 0.166559, 0.168517, 0.169393, 0.175272, 0.233167, 0.238787, 0.266562,
                0.26903, 0.285792, 0.292801, 0.293429, 0.301706, 0.308534, 0.331489, 0.346095,
                0.367644, 0.369939, 0.372074, 0.41014, 0.415789, 0.421781, 0.427725, 0.465363,
                0.500208, 2.621727, 2.803311, 3.868526,
            ],
        );
        assert!((ca.median().unwrap() - 0.3200115).abs() < 0.0001)
    }

    #[test]
    fn test_mean() {
        let ca = Float32Chunked::new(PlSmallStr::EMPTY, &[Some(1.0), Some(2.0), None]);
        assert_eq!(ca.mean().unwrap(), 1.5);
        assert_eq!(
            ca.into_series()
                .mean_reduce()
                .value()
                .extract::<f32>()
                .unwrap(),
            1.5
        );
        // all null values case
        let ca = Float32Chunked::full_null(PlSmallStr::EMPTY, 3);
        assert_eq!(ca.mean(), None);
        assert_eq!(
            ca.into_series().mean_reduce().value().extract::<f32>(),
            None
        );
    }

    #[test]
    fn test_quantile_all_null() {
        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);
        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]);

        for method in ALL_METHODS {
            assert_eq!(test_f32.quantile(0.9, method).unwrap(), None);
            assert_eq!(test_i32.quantile(0.9, method).unwrap(), None);
            assert_eq!(test_f64.quantile(0.9, method).unwrap(), None);
            assert_eq!(test_i64.quantile(0.9, method).unwrap(), None);
        }
    }

    #[test]
    fn test_quantile_single_value() {
        let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
        let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);
        let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]);
        let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]);

        for method in ALL_METHODS {
            assert_eq!(test_f32.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(test_i32.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(test_f64.quantile(0.5, method).unwrap(), Some(1.0));
            assert_eq!(test_i64.quantile(0.5, method).unwrap(), Some(1.0));
        }
    }

    #[test]
    fn test_quantile_min_max() {
        let test_f32 = Float32Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1f32), Some(5f32), Some(1f32)],
        );
        let test_i32 = Int32Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1i32), Some(5i32), Some(1i32)],
        );
        let test_f64 = Float64Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1f64), Some(5f64), Some(1f64)],
        );
        let test_i64 = Int64Chunked::from_slice_options(
            PlSmallStr::EMPTY,
            &[None, Some(1i64), Some(5i64), Some(1i64)],
        );

        for method in ALL_METHODS {
            assert_eq!(test_f32.quantile(0.0, method).unwrap(), test_f32.min());
            assert_eq!(test_f32.quantile(1.0, method).unwrap(), test_f32.max());

            assert_eq!(
                test_i32.quantile(0.0, method).unwrap().unwrap(),
                test_i32.min().unwrap() as f64
            );
            assert_eq!(
                test_i32.quantile(1.0, method).unwrap().unwrap(),
                test_i32.max().unwrap() as f64
            );

            assert_eq!(test_f64.quantile(0.0, method).unwrap(), test_f64.min());
            assert_eq!(test_f64.quantile(1.0, method).unwrap(), test_f64.max());
            assert_eq!(test_f64.quantile(0.5, method).unwrap(), test_f64.median());

            assert_eq!(
                test_i64.quantile(0.0, method).unwrap().unwrap(),
                test_i64.min().unwrap() as f64
            );
            assert_eq!(
                test_i64.quantile(1.0, method).unwrap().unwrap(),
                test_i64.max().unwrap() as f64
            );
        }
    }

    #[test]
    fn test_quantile() {
        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)],
        );

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
            Some(1.0)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
            Some(5.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
            Some(3.0)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(4.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(3.0));

        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(5.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(4.0));

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
            Some(1.5)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
            Some(4.5)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
            Some(3.5)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.4));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(4.6));
        assert!(
            (ca.quantile(0.6, QuantileMethod::Linear).unwrap().unwrap() - 3.4).abs() < 0.0000001
        );

        assert_eq!(
            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
            Some(1.0)
        );
        assert_eq!(
            ca.quantile(0.25, QuantileMethod::Equiprobable).unwrap(),
            Some(2.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
            Some(3.0)
        );

        let ca = UInt32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                None,
                Some(7),
                Some(6),
                Some(2),
                Some(1),
                None,
                Some(3),
                Some(5),
                None,
                Some(4),
            ],
        );

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Nearest).unwrap(),
            Some(2.0)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Nearest).unwrap(),
            Some(6.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Nearest).unwrap(),
            Some(5.0)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Lower).unwrap(), Some(1.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Lower).unwrap(), Some(6.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Lower).unwrap(), Some(4.0));

        assert_eq!(ca.quantile(0.1, QuantileMethod::Higher).unwrap(), Some(2.0));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Higher).unwrap(), Some(7.0));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Higher).unwrap(), Some(5.0));

        assert_eq!(
            ca.quantile(0.1, QuantileMethod::Midpoint).unwrap(),
            Some(1.5)
        );
        assert_eq!(
            ca.quantile(0.9, QuantileMethod::Midpoint).unwrap(),
            Some(6.5)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Midpoint).unwrap(),
            Some(4.5)
        );

        assert_eq!(ca.quantile(0.1, QuantileMethod::Linear).unwrap(), Some(1.6));
        assert_eq!(ca.quantile(0.9, QuantileMethod::Linear).unwrap(), Some(6.4));
        assert_eq!(ca.quantile(0.6, QuantileMethod::Linear).unwrap(), Some(4.6));

        assert_eq!(
            ca.quantile(0.14, QuantileMethod::Equiprobable).unwrap(),
            Some(1.0)
        );
        assert_eq!(
            ca.quantile(0.15, QuantileMethod::Equiprobable).unwrap(),
            Some(2.0)
        );
        assert_eq!(
            ca.quantile(0.6, QuantileMethod::Equiprobable).unwrap(),
            Some(5.0)
        );
    }
}