polars_core/frame/group_by/aggregations/
mod.rs

1mod agg_list;
2mod boolean;
3mod dispatch;
4mod string;
5
6use std::borrow::Cow;
7
8pub use agg_list::*;
9use arrow::bitmap::{Bitmap, MutableBitmap};
10use arrow::legacy::kernels::take_agg::*;
11use arrow::legacy::trusted_len::TrustedLenPush;
12use arrow::types::NativeType;
13use num_traits::pow::Pow;
14use num_traits::{Bounded, Float, Num, NumCast, ToPrimitive, Zero};
15use polars_compute::rolling::no_nulls::{
16    MaxWindow, MinWindow, MomentWindow, QuantileWindow, RollingAggWindowNoNulls,
17};
18use polars_compute::rolling::nulls::{RollingAggWindowNulls, VarianceMoment};
19use polars_compute::rolling::quantile_filter::SealedRolling;
20use polars_compute::rolling::{
21    self, MeanWindow, QuantileMethod, RollingFnParams, RollingQuantileParams, RollingVarParams,
22    SumWindow, quantile_filter,
23};
24use polars_utils::float::IsFloat;
25#[cfg(feature = "dtype-f16")]
26use polars_utils::float16::pf16;
27use polars_utils::idx_vec::IdxVec;
28use polars_utils::kahan_sum::KahanSum;
29use polars_utils::min_max::MinMax;
30use rayon::prelude::*;
31
32use crate::POOL;
33use crate::chunked_array::cast::CastOptions;
34#[cfg(feature = "object")]
35use crate::chunked_array::object::extension::create_extension;
36use crate::frame::group_by::GroupsIdx;
37#[cfg(feature = "object")]
38use crate::frame::group_by::GroupsIndicator;
39use crate::prelude::*;
40use crate::series::IsSorted;
41use crate::series::implementations::SeriesWrap;
42use crate::utils::NoNull;
43
44fn idx2usize(idx: &[IdxSize]) -> impl ExactSizeIterator<Item = usize> + '_ {
45    idx.iter().map(|i| *i as usize)
46}
47
48// if the windows overlap, we can use the rolling_<agg> kernels
49// they maintain state, which saves a lot of compute by not naively traversing all elements every
50// window
51//
52// if the windows don't overlap, we should not use these kernels as they are single threaded, so
53// we miss out on easy parallelization.
54pub fn _use_rolling_kernels(
55    groups: &GroupsSlice,
56    overlapping: bool,
57    monotonic: bool,
58    chunks: &[ArrayRef],
59) -> bool {
60    match groups.len() {
61        0 | 1 => false,
62        _ => overlapping && monotonic && chunks.len() == 1,
63    }
64}
65
/// Use an aggregation window that maintains state across consecutive windows,
/// for input that may contain nulls.
///
/// `offsets` yields one `(start, len)` pair per group; because consecutive
/// windows overlap, `Agg` updates its state incrementally instead of
/// recomputing each window from scratch.
pub fn _rolling_apply_agg_window_nulls<Agg, T, O>(
    values: &[T],
    validity: &Bitmap,
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    Agg: RollingAggWindowNulls<T>,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        // Empty input -> empty output array with no validity buffer.
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }

    // This iterators length can be trusted
    // these represent the number of groups in the group_by operation
    let output_len = offsets.size_hint().0;
    // start with a dummy index, will be overwritten on first iteration.
    let mut agg_window = Agg::new(values, validity, 0, 0, params, None);

    // Output validity: start all-valid and clear bits for windows whose
    // aggregation yields `None`.
    let mut validity = MutableBitmap::with_capacity(output_len);
    validity.extend_constant(output_len, true);

    let out = offsets
        .enumerate()
        .map(|(idx, (start, len))| {
            let end = start + len;

            // SAFETY:
            // we are in bounds
            let agg = unsafe { agg_window.update(start as usize, end as usize) };

            match agg {
                Some(val) => val,
                None => {
                    // SAFETY: we are in bounds
                    unsafe { validity.set_unchecked(idx, false) };
                    // Placeholder value; masked out by the validity bit above.
                    T::default()
                },
            }
        })
        .collect_trusted::<Vec<_>>();

    PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), Some(validity.into()))
}
114
/// Use an aggregation window that maintains state across consecutive windows,
/// for input without a validity mask.
///
/// `offsets` yields one `(start, len)` pair per group; `Agg` updates its
/// state incrementally across the (possibly overlapping) windows.
pub fn _rolling_apply_agg_window_no_nulls<Agg, T, O>(
    values: &[T],
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    // items (offset, len) -> so offsets are offset, offset + len
    Agg: RollingAggWindowNoNulls<T>,
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        // Empty input -> empty output array with no validity buffer.
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }
    // start with a dummy index, will be overwritten on first iteration.
    let mut agg_window = Agg::new(values, 0, 0, params, None);

    offsets
        .map(|(start, len)| {
            let end = start + len;

            // SAFETY: we are in bounds.
            unsafe { agg_window.update(start as usize, end as usize) }
        })
        .collect::<PrimitiveArray<T>>()
}
143
144pub fn _slice_from_offsets<T>(ca: &ChunkedArray<T>, first: IdxSize, len: IdxSize) -> ChunkedArray<T>
145where
146    T: PolarsDataType,
147{
148    ca.slice(first as i64, len as usize)
149}
150
151/// Helper that combines the groups into a parallel iterator over `(first, all): (u32, &Vec<u32>)`.
152pub fn _agg_helper_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
153where
154    F: Fn((IdxSize, &IdxVec)) -> Option<T::Native> + Send + Sync,
155    T: PolarsNumericType,
156{
157    let ca: ChunkedArray<T> = POOL.install(|| groups.into_par_iter().map(f).collect());
158    ca.into_series()
159}
160
161/// Same helper as `_agg_helper_idx` but for aggregations that don't return an Option.
162pub fn _agg_helper_idx_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
163where
164    F: Fn((IdxSize, &IdxVec)) -> T::Native + Send + Sync,
165    T: PolarsNumericType,
166{
167    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.into_par_iter().map(f).collect());
168    ca.into_inner().into_series()
169}
170
171/// Helper that iterates on the `all: Vec<Vec<u32>` collection,
172/// this doesn't have traverse the `first: Vec<u32>` memory and is therefore faster.
173fn agg_helper_idx_on_all<T, F>(groups: &GroupsIdx, f: F) -> Series
174where
175    F: Fn(&IdxVec) -> Option<T::Native> + Send + Sync,
176    T: PolarsNumericType,
177{
178    let ca: ChunkedArray<T> = POOL.install(|| groups.all().into_par_iter().map(f).collect());
179    ca.into_series()
180}
181
182pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
183where
184    F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
185    T: PolarsNumericType,
186{
187    let ca: ChunkedArray<T> = POOL.install(|| groups.par_iter().copied().map(f).collect());
188    ca.into_series()
189}
190
191pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
192where
193    F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
194    T: PolarsNumericType,
195{
196    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.par_iter().copied().map(f).collect());
197    ca.into_inner().into_series()
198}
199
/// Intermediate helper trait so we can have a single generic implementation
/// This trait will ensure the specific dispatch works without complicating
/// the trait bounds.
trait QuantileDispatcher<K> {
    /// Compute the requested quantile, consuming `self`.
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<K>>;

    /// Compute the median, consuming `self`.
    fn _median(self) -> Option<K>;
}
208
// Integer inputs: quantile/median results are computed as `f64`.
impl<T> QuantileDispatcher<f64> for ChunkedArray<T>
where
    T: PolarsIntegerType,
    T::Native: Ord,
{
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}
221
// f16 inputs keep their native precision in the result.
#[cfg(feature = "dtype-f16")]
impl QuantileDispatcher<pf16> for Float16Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<pf16>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<pf16> {
        self.median_faster()
    }
}
231
// f32 inputs keep their native precision in the result.
impl QuantileDispatcher<f32> for Float32Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f32>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f32> {
        self.median_faster()
    }
}
// f64 inputs keep their native precision in the result.
impl QuantileDispatcher<f64> for Float64Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}
248
/// Compute a per-group quantile, dispatching on the group representation.
///
/// `K` is the float output type; for integer inputs the `QuantileDispatcher`
/// impls route the computation through `f64`.
///
/// # Safety
///
/// No bounds checks on the group indices/offsets; callers must guarantee
/// they are in bounds for `ca`.
unsafe fn agg_quantile_generic<T, K>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    quantile: f64,
    method: QuantileMethod,
) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
{
    // An out-of-range quantile yields an all-null column rather than an error.
    // This pre-check is also what makes the `unwrap_unchecked` calls on
    // `_quantile` below sound.
    let invalid_quantile = !(0.0..=1.0).contains(&quantile);
    if invalid_quantile {
        return Series::full_null(ca.name().clone(), groups.len(), ca.dtype());
    }
    match groups {
        GroupsType::Idx(groups) => {
            // Gather-based path: rechunk once up front so `take_unchecked`
            // operates on a single chunk per group.
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                // checked with invalid quantile check
                take._quantile(quantile, method).unwrap_unchecked()
            })
        },
        GroupsType::Slice {
            groups,
            overlapping,
            monotonic,
        } => {
            if _use_rolling_kernels(groups, *overlapping, *monotonic, ca.chunks()) {
                // this cast is a no-op for floats
                let s = ca
                    .cast_with_options(&K::get_static_dtype(), CastOptions::Overflowing)
                    .unwrap();
                let ca: &ChunkedArray<K> = s.as_ref().as_ref();
                let arr = ca.downcast_iter().next().unwrap();
                let values = arr.values().as_slice();
                let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                // Pick the nulls/no-nulls rolling kernel based on the
                // array's validity buffer.
                let arr = match arr.validity() {
                    None => _rolling_apply_agg_window_no_nulls::<QuantileWindow<_>, _, _>(
                        values,
                        offset_iter,
                        Some(RollingFnParams::Quantile(RollingQuantileParams {
                            prob: quantile,
                            method,
                        })),
                    ),
                    Some(validity) => {
                        _rolling_apply_agg_window_nulls::<rolling::nulls::QuantileWindow<_>, _, _>(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Quantile(RollingQuantileParams {
                                prob: quantile,
                                method,
                            })),
                        )
                    },
                };
                // The rolling kernels works on the dtype, this is not yet the
                // float output type we need.
                ChunkedArray::<K>::with_chunk(PlSmallStr::EMPTY, arr).into_series()
            } else {
                // Non-overlapping slices: aggregate each slice independently
                // (parallelized by `_agg_helper_slice`).
                _agg_helper_slice::<K, _>(groups, |[first, len]| {
                    debug_assert!(first + len <= ca.len() as IdxSize);
                    match len {
                        0 => None,
                        1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                        _ => {
                            let arr_group = _slice_from_offsets(ca, first, len);
                            // unwrap checked with invalid quantile check
                            arr_group
                                ._quantile(quantile, method)
                                .unwrap_unchecked()
                                .map(|flt| NumCast::from(flt).unwrap_unchecked())
                        },
                    }
                })
            }
        },
    }
}
336
/// Compute a per-group median.
///
/// For slice groups this delegates to [`agg_quantile_generic`] with
/// `quantile = 0.5` and linear interpolation.
///
/// # Safety
///
/// No bounds checks on the group indices; callers must guarantee they are in
/// bounds for `ca`.
unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + SealedRolling,
{
    match groups {
        GroupsType::Idx(groups) => {
            // Gather-based path: rechunk once so `take_unchecked` hits a
            // single chunk.
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                take._median()
            })
        },
        GroupsType::Slice { .. } => {
            agg_quantile_generic::<T, K>(ca, groups, 0.5, QuantileMethod::Linear)
        },
    }
}
361
/// Shared driver for the per-group bitwise reductions (`and`/`or`/`xor`);
/// `f` is the reduction applied to each group's values.
///
/// # Safety
///
/// No bounds checks on `groups`.
#[cfg(feature = "bitwise")]
unsafe fn bitwise_agg<T: PolarsNumericType>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    f: fn(&ChunkedArray<T>) -> Option<T::Native>,
) -> Series
where
    ChunkedArray<T>: ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native>,
{
    // Prevent a rechunk for every individual group.

    let s = if groups.len() > 1 {
        ca.rechunk()
    } else {
        // Single group: no need to pay for a rechunk.
        Cow::Borrowed(ca)
    };

    match groups {
        GroupsType::Idx(groups) => agg_helper_idx_on_all::<T, _>(groups, |idx| {
            debug_assert!(idx.len() <= s.len());
            if idx.is_empty() {
                None
            } else {
                // SAFETY: caller guarantees the indices are in bounds.
                let take = unsafe { s.take_unchecked(idx) };
                f(&take)
            }
        }),
        GroupsType::Slice { groups, .. } => _agg_helper_slice::<T, _>(groups, |[first, len]| {
            debug_assert!(len <= s.len() as IdxSize);
            if len == 0 {
                None
            } else {
                let take = _slice_from_offsets(&s, first, len);
                f(&take)
            }
        }),
    }
}
403
#[cfg(feature = "bitwise")]
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native>,
{
    /// Per-group bitwise AND reduction.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) }
    }

    /// Per-group bitwise OR reduction.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) }
    }

    /// Per-group bitwise XOR reduction.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) }
    }
}
431
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType + Sync,
    T::Native: NativeType + PartialOrd + Num + NumCast + Zero + Bounded + std::iter::Sum<T::Native>,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    /// Per-group minimum.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
        // faster paths
        // For sorted data the min of each group is its first (ascending) or
        // last (descending) non-null element.
        match self.is_sorted_flag() {
            IsSorted::Ascending => return self.clone().into_series().agg_first_non_null(groups),
            IsSorted::Descending => return self.clone().into_series().agg_last_non_null(groups),
            _ => {},
        }
        match groups {
            GroupsType::Idx(groups) => {
                // Gather-based path: rechunk once so the take kernels can
                // operate on a single contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        // Single-element group: just read that element.
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.min_ignore_nan(b))
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.min_ignore_nan(b))
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping, monotonic windows: reuse state via the
                    // rolling min kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MinWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    // Non-overlapping slices: aggregate each slice on its own
                    // (parallelized by `_agg_helper_slice`).
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::min(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

    /// Per-group maximum.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
        // faster paths
        // Sorted and null-free data: the max of each group is its last
        // (ascending) or first (descending) element.
        match (self.is_sorted_flag(), self.null_count()) {
            (IsSorted::Ascending, 0) => {
                return self.clone().into_series().agg_last(groups);
            },
            (IsSorted::Descending, 0) => {
                return self.clone().into_series().agg_first(groups);
            },
            _ => {},
        }

        match groups {
            GroupsType::Idx(groups) => {
                // Gather-based path: rechunk once so the take kernels can
                // operate on a single contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        // Single-element group: just read that element.
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.max_ignore_nan(b))
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.max_ignore_nan(b))
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping, monotonic windows: reuse state via the
                    // rolling max kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MaxWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    // Non-overlapping slices: aggregate each slice on its own.
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::max(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

    /// Per-group sum; empty groups sum to zero (never null).
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Gather-based path: rechunk once so the take kernels can
                // operate on a single contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx_no_null::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        T::Native::zero()
                    } else if idx.len() == 1 {
                        // Single-element group; a null element sums to zero.
                        arr.get(first as usize).unwrap_or(T::Native::zero())
                    } else if no_nulls {
                        if T::Native::is_float() {
                            // Kahan summation for floats to limit rounding error.
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(KahanSum::default(), |k, x| k + x)
                                .sum()
                        } else {
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(T::Native::zero(), |a, b| a + b)
                        }
                    } else if T::Native::is_float() {
                        // Kahan summation for floats to limit rounding error.
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .fold(KahanSum::default(), |k, x| k + x)
                            .sum()
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .fold(T::Native::zero(), |a, b| a + b)
                    }
                })
            },
            GroupsType::Slice {
                groups,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping, monotonic windows: reuse state via the
                    // rolling sum kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            SumWindow<T::Native, T::Native>,
                            _,
                            _,
                        >(values, offset_iter, None),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            SumWindow<T::Native, T::Native>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    // Non-overlapping slices: aggregate each slice on its own.
                    _agg_helper_slice_no_null::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => T::Native::zero(),
                            1 => self.get(first as usize).unwrap_or(T::Native::zero()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.sum().unwrap_or(T::Native::zero())
                            },
                        }
                    })
                }
            },
        }
    }
}
651
652impl<T> SeriesWrap<ChunkedArray<T>>
653where
654    T: PolarsFloatType,
655    ChunkedArray<T>: ChunkVar
656        + VarAggSeries
657        + ChunkQuantile<T::Native>
658        + QuantileAggSeries
659        + ChunkAgg<T::Native>,
660    T::Native: Pow<T::Native, Output = T::Native>,
661{
    /// Per-group arithmetic mean, computed in `f64` and cast back to
    /// `T::Native`; nulls are skipped (not counted in the divisor).
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Gather-based path: rechunk once so the take kernels can
                // operate on a single contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    // this can fail due to a bug in lazy code.
                    // here users can create filters in aggregations
                    // and thereby creating shorter columns than the original group tuples.
                    // the group tuples are modified, but if that's done incorrect there can be out of bounds
                    // access
                    debug_assert!(idx.len() <= self.len());
                    let out = if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        // Single-element group: its mean is the element itself.
                        arr.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else if no_nulls {
                        // Kahan summation in f64 to limit rounding error.
                        Some(
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(KahanSum::default(), |a, b| {
                                    a + b.to_f64().unwrap_unchecked()
                                })
                                .sum()
                                / idx.len() as f64,
                        )
                    } else {
                        // Track the null count so the divisor only counts
                        // valid elements.
                        take_agg_primitive_iter_unchecked_count_nulls(
                            arr,
                            idx2usize(idx),
                            KahanSum::default(),
                            |a, b| a + b.to_f64().unwrap_unchecked(),
                            idx.len() as IdxSize,
                        )
                        .map(|(sum, null_count)| sum.sum() / (idx.len() as f64 - null_count as f64))
                    };
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice {
                groups,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping, monotonic windows: reuse state via the
                    // rolling mean kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MeanWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<MeanWindow<_>, _, _>(
                            values,
                            validity,
                            offset_iter,
                            None,
                        ),
                    };
                    ChunkedArray::<T>::from(arr).into_series()
                } else {
                    // Non-overlapping slices: aggregate each slice on its own.
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean().map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
740
    /// Group-wise variance with `ddof` delta degrees of freedom, returned as
    /// a series of the same float type as `self`.
    ///
    /// # Safety
    /// The group indices/offsets in `groups` must be in bounds for this
    /// array; out-of-bounds groups are only caught by `debug_assert!`s in
    /// debug builds.
    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        // Work on a single contiguous chunk so the take/rolling kernels below
        // can operate on one array.
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                // NOTE(review): `ca` was already rechunked above, so this second
                // rechunk looks redundant (presumably a cheap no-op on a single
                // chunk) — confirm and consider removing.
                let ca = ca.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    // Use the faster kernel when the chunk has no validity mask.
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice {
                groups,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping windows: reuse the stateful rolling variance
                    // kernels instead of recomputing each window from scratch
                    // (see the module-level comment on `_use_rolling_kernels`).
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                    };
                    ChunkedArray::<T>::from(arr).into_series()
                } else {
                    // Non-overlapping windows: slice each group out and
                    // aggregate it independently via the (parallel) helper.
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        // NOTE(review): the integer impl asserts
                        // `first + len <= self.len()`; this weaker check may be
                        // deliberate (groups can be shorter than the column per
                        // the comment in `agg_mean`) — confirm.
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            // A single element has variance 0 for ddof == 0 and
                            // is undefined (null) otherwise.
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
    /// Group-wise standard deviation with `ddof` delta degrees of freedom:
    /// the square root of the group variance, returned as a series of the
    /// same float type as `self`.
    ///
    /// # Safety
    /// The group indices/offsets in `groups` must be in bounds for this
    /// array; out-of-bounds groups are only caught by `debug_assert!`s in
    /// debug builds.
    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        // Work on a single contiguous chunk so the take/rolling kernels below
        // can operate on one array.
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    // Compute the variance with the appropriate kernel, then
                    // take the square root to get the standard deviation.
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|flt| NumCast::from(flt.sqrt()).unwrap())
                })
            },
            GroupsType::Slice {
                groups,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping windows: run the rolling *variance* kernels,
                    // then take the square root element-wise below.
                    let arr = ca.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MomentWindow<_, rolling::nulls::VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                    };

                    // variance -> std: x^0.5 applied in place.
                    let mut ca = ChunkedArray::<T>::from(arr);
                    ca.apply_mut(|v| v.powf(NumCast::from(0.5).unwrap()));
                    ca.into_series()
                } else {
                    // Non-overlapping windows: slice each group out and
                    // aggregate it independently via the (parallel) helper.
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            // A single element has std 0 for ddof == 0 and is
                            // undefined (null) otherwise.
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
895}
896
897impl Float32Chunked {
898    pub(crate) unsafe fn agg_quantile(
899        &self,
900        groups: &GroupsType,
901        quantile: f64,
902        method: QuantileMethod,
903    ) -> Series {
904        agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
905    }
906    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
907        agg_median_generic::<_, Float32Type>(self, groups)
908    }
909}
910impl Float64Chunked {
911    pub(crate) unsafe fn agg_quantile(
912        &self,
913        groups: &GroupsType,
914        quantile: f64,
915        method: QuantileMethod,
916    ) -> Series {
917        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
918    }
919    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
920        agg_median_generic::<_, Float64Type>(self, groups)
921    }
922}
923
/// Group-by aggregations for integer chunked arrays. Mean/var/std are
/// computed in, and returned as, `Float64`.
impl<T> ChunkedArray<T>
where
    T: PolarsIntegerType,
    ChunkedArray<T>: ChunkAgg<T::Native> + ChunkVar,
    T::Native: NumericNative + Ord,
{
    /// Group-wise mean, returned as a `Float64` series.
    ///
    /// # Safety
    /// The group indices/offsets in `groups` must be in bounds for this
    /// array; out-of-bounds groups are only caught by `debug_assert!`s in
    /// debug builds.
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_get(0).unwrap();
                _agg_helper_idx::<Float64Type, _>(groups, |(first, idx)| {
                    // this can fail due to a bug in lazy code.
                    // here users can create filters in aggregations
                    // and thereby creating shorter columns than the original group tuples.
                    // the group tuples are modified, but if that's done incorrect there can be out of bounds
                    // access
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        // Single-element group: the mean is that element.
                        self.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else {
                        match (self.has_nulls(), self.chunks.len()) {
                            // Single chunk without nulls: sum with Kahan
                            // compensation to limit float error, divide by count.
                            (false, 1) => Some(
                                take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                    .fold(KahanSum::default(), |a, b| a + b.to_f64().unwrap())
                                    .sum()
                                    / idx.len() as f64,
                            ),
                            // Single chunk with nulls: the kernel also counts
                            // nulls so the divisor excludes them.
                            (_, 1) => {
                                take_agg_primitive_iter_unchecked_count_nulls(
                                    arr,
                                    idx2usize(idx),
                                    KahanSum::default(),
                                    |a, b| a + b.to_f64().unwrap(),
                                    idx.len() as IdxSize,
                                )
                            }
                            .map(|(sum, null_count)| {
                                sum.sum() / (idx.len() as f64 - null_count as f64)
                            }),
                            // Multiple chunks: fall back to gather + mean.
                            _ => {
                                let take = { self.take_unchecked(idx) };
                                take.mean()
                            },
                        }
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping windows: cast to f64 and defer to the float
                    // implementation, which uses the stateful rolling kernels.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_mean(groups)
                } else {
                    // Non-overlapping windows: slice each group out and
                    // aggregate it independently via the (parallel) helper.
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean()
                            },
                        }
                    })
                }
            },
        }
    }

    /// Group-wise variance with `ddof` delta degrees of freedom, returned as
    /// a `Float64` series.
    ///
    /// # Safety
    /// The group indices/offsets in `groups` must be in bounds for this
    /// array; out-of-bounds groups are only caught by `debug_assert!`s in
    /// debug builds.
    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Single contiguous chunk so the take kernels can run on one array.
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        return None;
                    }
                    // Use the faster kernel when the chunk has no validity mask.
                    if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping windows: cast to f64 and defer to the float
                    // implementation, which uses the stateful rolling kernels.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_var(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            // A single element has variance 0 for ddof == 0
                            // and is undefined (null) otherwise.
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof)
                            },
                        }
                    })
                }
            },
        }
    }
    /// Group-wise standard deviation with `ddof` delta degrees of freedom:
    /// the square root of the group variance, returned as a `Float64` series.
    ///
    /// # Safety
    /// The group indices/offsets in `groups` must be in bounds for this
    /// array; out-of-bounds groups are only caught by `debug_assert!`s in
    /// debug builds.
    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Single contiguous chunk so the take kernels can run on one array.
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        return None;
                    }
                    // Compute the variance, then take the square root.
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|v| v.sqrt())
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping windows: cast to f64 and defer to the float
                    // implementation, which uses the stateful rolling kernels.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_std(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            // A single element has std 0 for ddof == 0 and is
                            // undefined (null) otherwise.
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof)
                            },
                        }
                    })
                }
            },
        }
    }

    /// Group-wise quantile; computed in f64 via the generic float
    /// implementation.
    ///
    /// # Safety
    /// The group indices/offsets in `groups` must be in bounds for this array.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    /// Group-wise median; computed in f64 via the generic float
    /// implementation.
    ///
    /// # Safety
    /// The group indices/offsets in `groups` must be in bounds for this array.
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}