polars_core/frame/group_by/aggregations/
mod.rs

1mod agg_list;
2mod boolean;
3mod dispatch;
4mod string;
5
6use std::borrow::Cow;
7use std::cmp::Ordering;
8
9pub use agg_list::*;
10use arrow::bitmap::{Bitmap, MutableBitmap};
11use arrow::legacy::kernels::take_agg::*;
12use arrow::legacy::trusted_len::TrustedLenPush;
13use arrow::types::NativeType;
14use num_traits::pow::Pow;
15use num_traits::{Bounded, Float, Num, NumCast, ToPrimitive, Zero};
16use polars_compute::rolling::no_nulls::{
17    MaxWindow, MeanWindow, MinWindow, QuantileWindow, RollingAggWindowNoNulls, SumWindow, VarWindow,
18};
19use polars_compute::rolling::nulls::RollingAggWindowNulls;
20use polars_compute::rolling::quantile_filter::SealedRolling;
21use polars_compute::rolling::{
22    self, QuantileMethod, RollingFnParams, RollingQuantileParams, RollingVarParams, quantile_filter,
23};
24use polars_utils::float::IsFloat;
25use polars_utils::idx_vec::IdxVec;
26use polars_utils::ord::{compare_fn_nan_max, compare_fn_nan_min};
27use rayon::prelude::*;
28
29use crate::chunked_array::cast::CastOptions;
30#[cfg(feature = "object")]
31use crate::chunked_array::object::extension::create_extension;
32use crate::frame::group_by::GroupsIdx;
33#[cfg(feature = "object")]
34use crate::frame::group_by::GroupsIndicator;
35use crate::prelude::*;
36use crate::series::IsSorted;
37use crate::series::implementations::SeriesWrap;
38use crate::utils::NoNull;
39use crate::{POOL, apply_method_physical_integer};
40
41fn idx2usize(idx: &[IdxSize]) -> impl ExactSizeIterator<Item = usize> + '_ {
42    idx.iter().map(|i| *i as usize)
43}
44
45// if the windows overlap, we can use the rolling_<agg> kernels
46// they maintain state, which saves a lot of compute by not naively traversing all elements every
47// window
48//
49// if the windows don't overlap, we should not use these kernels as they are single threaded, so
50// we miss out on easy parallelization.
51pub fn _use_rolling_kernels(groups: &GroupsSlice, chunks: &[ArrayRef]) -> bool {
52    match groups.len() {
53        0 | 1 => false,
54        _ => {
55            let [first_offset, first_len] = groups[0];
56            let second_offset = groups[1][0];
57
58            second_offset >= first_offset // Prevent false positive from regular group-by that has out of order slices.
59                                          // Rolling group-by is expected to have monotonically increasing slices.
60                && second_offset < (first_offset + first_len)
61                && chunks.len() == 1
62        },
63    }
64}
65
66// Use an aggregation window that maintains the state
67pub fn _rolling_apply_agg_window_nulls<'a, Agg, T, O>(
68    values: &'a [T],
69    validity: &'a Bitmap,
70    offsets: O,
71    params: Option<RollingFnParams>,
72) -> PrimitiveArray<T>
73where
74    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
75    Agg: RollingAggWindowNulls<'a, T>,
76    T: IsFloat + NativeType,
77{
78    if values.is_empty() {
79        let out: Vec<T> = vec![];
80        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
81    }
82
83    // This iterators length can be trusted
84    // these represent the number of groups in the group_by operation
85    let output_len = offsets.size_hint().0;
86    // start with a dummy index, will be overwritten on first iteration.
87    // SAFETY:
88    // we are in bounds
89    let mut agg_window = unsafe { Agg::new(values, validity, 0, 0, params) };
90
91    let mut validity = MutableBitmap::with_capacity(output_len);
92    validity.extend_constant(output_len, true);
93
94    let out = offsets
95        .enumerate()
96        .map(|(idx, (start, len))| {
97            let end = start + len;
98
99            // SAFETY:
100            // we are in bounds
101
102            let agg = if start == end {
103                None
104            } else {
105                unsafe { agg_window.update(start as usize, end as usize) }
106            };
107
108            match agg {
109                Some(val) => val,
110                None => {
111                    // SAFETY: we are in bounds
112                    unsafe { validity.set_unchecked(idx, false) };
113                    T::default()
114                },
115            }
116        })
117        .collect_trusted::<Vec<_>>();
118
119    PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), Some(validity.into()))
120}
121
122// Use an aggregation window that maintains the state.
123pub fn _rolling_apply_agg_window_no_nulls<'a, Agg, T, O>(
124    values: &'a [T],
125    offsets: O,
126    params: Option<RollingFnParams>,
127) -> PrimitiveArray<T>
128where
129    // items (offset, len) -> so offsets are offset, offset + len
130    Agg: RollingAggWindowNoNulls<'a, T>,
131    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
132    T: IsFloat + NativeType,
133{
134    if values.is_empty() {
135        let out: Vec<T> = vec![];
136        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
137    }
138    // start with a dummy index, will be overwritten on first iteration.
139    let mut agg_window = Agg::new(values, 0, 0, params);
140
141    offsets
142        .map(|(start, len)| {
143            let end = start + len;
144
145            if start == end {
146                None
147            } else {
148                // SAFETY: we are in bounds.
149                unsafe { agg_window.update(start as usize, end as usize) }
150            }
151        })
152        .collect::<PrimitiveArray<T>>()
153}
154
155pub fn _slice_from_offsets<T>(ca: &ChunkedArray<T>, first: IdxSize, len: IdxSize) -> ChunkedArray<T>
156where
157    T: PolarsDataType,
158{
159    ca.slice(first as i64, len as usize)
160}
161
162/// Helper that combines the groups into a parallel iterator over `(first, all): (u32, &Vec<u32>)`.
163pub fn _agg_helper_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
164where
165    F: Fn((IdxSize, &IdxVec)) -> Option<T::Native> + Send + Sync,
166    T: PolarsNumericType,
167    ChunkedArray<T>: IntoSeries,
168{
169    let ca: ChunkedArray<T> = POOL.install(|| groups.into_par_iter().map(f).collect());
170    ca.into_series()
171}
172
173/// Same helper as `_agg_helper_idx` but for aggregations that don't return an Option.
174pub fn _agg_helper_idx_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
175where
176    F: Fn((IdxSize, &IdxVec)) -> T::Native + Send + Sync,
177    T: PolarsNumericType,
178    ChunkedArray<T>: IntoSeries,
179{
180    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.into_par_iter().map(f).collect());
181    ca.into_inner().into_series()
182}
183
184/// Helper that iterates on the `all: Vec<Vec<u32>` collection,
185/// this doesn't have traverse the `first: Vec<u32>` memory and is therefore faster.
186fn agg_helper_idx_on_all<T, F>(groups: &GroupsIdx, f: F) -> Series
187where
188    F: Fn(&IdxVec) -> Option<T::Native> + Send + Sync,
189    T: PolarsNumericType,
190    ChunkedArray<T>: IntoSeries,
191{
192    let ca: ChunkedArray<T> = POOL.install(|| groups.all().into_par_iter().map(f).collect());
193    ca.into_series()
194}
195
196/// Same as `agg_helper_idx_on_all` but for aggregations that don't return an Option.
197fn agg_helper_idx_on_all_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
198where
199    F: Fn(&IdxVec) -> T::Native + Send + Sync,
200    T: PolarsNumericType,
201    ChunkedArray<T>: IntoSeries,
202{
203    let ca: NoNull<ChunkedArray<T>> =
204        POOL.install(|| groups.all().into_par_iter().map(f).collect());
205    ca.into_inner().into_series()
206}
207
208pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
209where
210    F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
211    T: PolarsNumericType,
212    ChunkedArray<T>: IntoSeries,
213{
214    let ca: ChunkedArray<T> = POOL.install(|| groups.par_iter().copied().map(f).collect());
215    ca.into_series()
216}
217
218pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
219where
220    F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
221    T: PolarsNumericType,
222    ChunkedArray<T>: IntoSeries,
223{
224    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.par_iter().copied().map(f).collect());
225    ca.into_inner().into_series()
226}
227
/// Pairwise selection of the smaller or larger of two values.
///
/// Used as the reduction step of the min/max group aggregations in this module.
pub trait TakeExtremum {
    /// Returns the smaller of `self` and `other`.
    fn take_min(self, other: Self) -> Self;

    /// Returns the larger of `self` and `other`.
    fn take_max(self, other: Self) -> Self;
}
233
/// Implements `TakeExtremum` for a primitive type.
///
/// * `impl_take_extremum!(ty)` is meant for integer types and uses their built-in order.
/// * `impl_take_extremum!(float: ty)` is meant for float types: the minimum is decided with
///   `compare_fn_nan_max` and the maximum with `compare_fn_nan_min`.
macro_rules! impl_take_extremum {
    (float: $tp:ty) => {
        impl TakeExtremum for $tp {
            #[inline(always)]
            fn take_min(self, other: Self) -> Self {
                match compare_fn_nan_max(&self, &other) {
                    Ordering::Less => self,
                    _ => other,
                }
            }

            #[inline(always)]
            fn take_max(self, other: Self) -> Self {
                match compare_fn_nan_min(&self, &other) {
                    Ordering::Greater => self,
                    _ => other,
                }
            }
        }
    };

    ($tp:ty) => {
        impl TakeExtremum for $tp {
            #[inline(always)]
            fn take_min(self, other: Self) -> Self {
                Ord::min(self, other)
            }

            #[inline(always)]
            fn take_max(self, other: Self) -> Self {
                Ord::max(self, other)
            }
        }
    };
}
271
// `TakeExtremum` implementations for the supported native types.
// The narrow integer types and `i128` are only compiled when their dtype feature is enabled.

// Unsigned integers.
#[cfg(feature = "dtype-u8")]
impl_take_extremum!(u8);
#[cfg(feature = "dtype-u16")]
impl_take_extremum!(u16);
impl_take_extremum!(u32);
impl_take_extremum!(u64);
// Signed integers.
#[cfg(feature = "dtype-i8")]
impl_take_extremum!(i8);
#[cfg(feature = "dtype-i16")]
impl_take_extremum!(i16);
impl_take_extremum!(i32);
impl_take_extremum!(i64);
#[cfg(any(feature = "dtype-decimal", feature = "dtype-i128"))]
impl_take_extremum!(i128);
// Floats use the `float:` arm of the macro.
impl_take_extremum!(float: f32);
impl_take_extremum!(float: f64);
288
/// Intermediate helper trait so we can have a single generic implementation of the quantile
/// and median aggregations.
///
/// This trait ensures the type-specific dispatch works without complicating the trait bounds
/// of the generic aggregation functions. `K` is the native type of the result.
trait QuantileDispatcher<K> {
    /// Computes the `quantile` of the array using the interpolation `method`.
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<K>>;

    /// Computes the median of the array.
    fn _median(self) -> Option<K>;
}
297
/// Integer arrays produce `f64` quantiles and medians.
impl<T> QuantileDispatcher<f64> for ChunkedArray<T>
where
    T: PolarsIntegerType,
    T::Native: Ord,
    ChunkedArray<T>: IntoSeries,
{
    // Delegates to `quantile_faster`, which takes the array by value.
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    // Delegates to `median_faster`, which takes the array by value.
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}
311
/// `Float32` arrays keep `f32` as the result type.
impl QuantileDispatcher<f32> for Float32Chunked {
    // Delegates to `quantile_faster`, which takes the array by value.
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f32>> {
        self.quantile_faster(quantile, method)
    }
    // Delegates to `median_faster`, which takes the array by value.
    fn _median(self) -> Option<f32> {
        self.median_faster()
    }
}
/// `Float64` arrays keep `f64` as the result type.
impl QuantileDispatcher<f64> for Float64Chunked {
    // Delegates to `quantile_faster`, which takes the array by value.
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    // Delegates to `median_faster`, which takes the array by value.
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}
328
329unsafe fn agg_quantile_generic<T, K>(
330    ca: &ChunkedArray<T>,
331    groups: &GroupsType,
332    quantile: f64,
333    method: QuantileMethod,
334) -> Series
335where
336    T: PolarsNumericType,
337    ChunkedArray<T>: QuantileDispatcher<K::Native>,
338    ChunkedArray<K>: IntoSeries,
339    K: PolarsNumericType,
340    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
341{
342    let invalid_quantile = !(0.0..=1.0).contains(&quantile);
343    if invalid_quantile {
344        return Series::full_null(ca.name().clone(), groups.len(), ca.dtype());
345    }
346    match groups {
347        GroupsType::Idx(groups) => {
348            let ca = ca.rechunk();
349            agg_helper_idx_on_all::<K, _>(groups, |idx| {
350                debug_assert!(idx.len() <= ca.len());
351                if idx.is_empty() {
352                    return None;
353                }
354                let take = { ca.take_unchecked(idx) };
355                // checked with invalid quantile check
356                take._quantile(quantile, method).unwrap_unchecked()
357            })
358        },
359        GroupsType::Slice { groups, .. } => {
360            if _use_rolling_kernels(groups, ca.chunks()) {
361                // this cast is a no-op for floats
362                let s = ca
363                    .cast_with_options(&K::get_dtype(), CastOptions::Overflowing)
364                    .unwrap();
365                let ca: &ChunkedArray<K> = s.as_ref().as_ref();
366                let arr = ca.downcast_iter().next().unwrap();
367                let values = arr.values().as_slice();
368                let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
369                let arr = match arr.validity() {
370                    None => _rolling_apply_agg_window_no_nulls::<QuantileWindow<_>, _, _>(
371                        values,
372                        offset_iter,
373                        Some(RollingFnParams::Quantile(RollingQuantileParams {
374                            prob: quantile,
375                            method,
376                        })),
377                    ),
378                    Some(validity) => {
379                        _rolling_apply_agg_window_nulls::<rolling::nulls::QuantileWindow<_>, _, _>(
380                            values,
381                            validity,
382                            offset_iter,
383                            Some(RollingFnParams::Quantile(RollingQuantileParams {
384                                prob: quantile,
385                                method,
386                            })),
387                        )
388                    },
389                };
390                // The rolling kernels works on the dtype, this is not yet the
391                // float output type we need.
392                ChunkedArray::from(arr).into_series()
393            } else {
394                _agg_helper_slice::<K, _>(groups, |[first, len]| {
395                    debug_assert!(first + len <= ca.len() as IdxSize);
396                    match len {
397                        0 => None,
398                        1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()),
399                        _ => {
400                            let arr_group = _slice_from_offsets(ca, first, len);
401                            // unwrap checked with invalid quantile check
402                            arr_group
403                                ._quantile(quantile, method)
404                                .unwrap_unchecked()
405                                .map(|flt| NumCast::from(flt).unwrap_unchecked())
406                        },
407                    }
408                })
409            }
410        },
411    }
412}
413
414unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
415where
416    T: PolarsNumericType,
417    ChunkedArray<T>: QuantileDispatcher<K::Native>,
418    ChunkedArray<K>: IntoSeries,
419    K: PolarsNumericType,
420    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + SealedRolling,
421{
422    match groups {
423        GroupsType::Idx(groups) => {
424            let ca = ca.rechunk();
425            agg_helper_idx_on_all::<K, _>(groups, |idx| {
426                debug_assert!(idx.len() <= ca.len());
427                if idx.is_empty() {
428                    return None;
429                }
430                let take = { ca.take_unchecked(idx) };
431                take._median()
432            })
433        },
434        GroupsType::Slice { .. } => {
435            agg_quantile_generic::<T, K>(ca, groups, 0.5, QuantileMethod::Linear)
436        },
437    }
438}
439
/// Applies the bitwise reduction `f` to every group of `ca`.
///
/// Empty groups produce `None`; every other group is materialized and handed to `f`.
///
/// # Safety
///
/// No bounds checks on `groups`.
#[cfg(feature = "bitwise")]
unsafe fn bitwise_agg<T: PolarsNumericType>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    f: fn(&ChunkedArray<T>) -> Option<T::Native>,
) -> Series
where
    ChunkedArray<T>:
        ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native> + IntoSeries,
{
    // Rechunk once up front to prevent a rechunk for every individual group. With at most one
    // group the input is borrowed as is.
    let rechunked = match groups.len() {
        0 | 1 => Cow::Borrowed(ca),
        _ => ca.rechunk(),
    };

    match groups {
        GroupsType::Idx(groups) => agg_helper_idx_on_all::<T, _>(groups, |idx| {
            debug_assert!(idx.len() <= rechunked.len());
            if idx.is_empty() {
                return None;
            }
            // SAFETY: the caller guarantees that the group indices are in bounds.
            let group = unsafe { rechunked.take_unchecked(idx) };
            f(&group)
        }),
        GroupsType::Slice { groups, .. } => _agg_helper_slice::<T, _>(groups, |[first, len]| {
            debug_assert!(len <= rechunked.len() as IdxSize);
            if len == 0 {
                return None;
            }
            let group = _slice_from_offsets(&rechunked, first, len);
            f(&group)
        }),
    }
}
482
/// Bitwise group aggregations; each method forwards to `bitwise_agg` with the matching
/// `ChunkBitwiseReduce` reduction.
#[cfg(feature = "bitwise")]
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    ChunkedArray<T>:
        ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native> + IntoSeries,
{
    /// Reduces every group with `and_reduce`.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series {
        // SAFETY: the contract is forwarded unchanged to `bitwise_agg`.
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) }
    }

    /// Reduces every group with `or_reduce`.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series {
        // SAFETY: the contract is forwarded unchanged to `bitwise_agg`.
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) }
    }

    /// Reduces every group with `xor_reduce`.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series {
        // SAFETY: the contract is forwarded unchanged to `bitwise_agg`.
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) }
    }
}
511
512impl<T> ChunkedArray<T>
513where
514    T: PolarsNumericType + Sync,
515    T::Native: NativeType
516        + PartialOrd
517        + Num
518        + NumCast
519        + Zero
520        + Bounded
521        + std::iter::Sum<T::Native>
522        + TakeExtremum,
523    ChunkedArray<T>: IntoSeries + ChunkAgg<T::Native>,
524{
525    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
526        // faster paths
527        match (self.is_sorted_flag(), self.null_count()) {
528            (IsSorted::Ascending, 0) => {
529                return self.clone().into_series().agg_first(groups);
530            },
531            (IsSorted::Descending, 0) => {
532                return self.clone().into_series().agg_last(groups);
533            },
534            _ => {},
535        }
536        match groups {
537            GroupsType::Idx(groups) => {
538                let ca = self.rechunk();
539                let arr = ca.downcast_iter().next().unwrap();
540                let no_nulls = arr.null_count() == 0;
541                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
542                    debug_assert!(idx.len() <= arr.len());
543                    if idx.is_empty() {
544                        None
545                    } else if idx.len() == 1 {
546                        arr.get(first as usize)
547                    } else if no_nulls {
548                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
549                            arr,
550                            idx2usize(idx),
551                            |a, b| a.take_min(b),
552                        )
553                    } else {
554                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a.take_min(b))
555                    }
556                })
557            },
558            GroupsType::Slice {
559                groups: groups_slice,
560                ..
561            } => {
562                if _use_rolling_kernels(groups_slice, self.chunks()) {
563                    let arr = self.downcast_iter().next().unwrap();
564                    let values = arr.values().as_slice();
565                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
566                    let arr = match arr.validity() {
567                        None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _>(
568                            values,
569                            offset_iter,
570                            None,
571                        ),
572                        Some(validity) => _rolling_apply_agg_window_nulls::<
573                            rolling::nulls::MinWindow<_>,
574                            _,
575                            _,
576                        >(
577                            values, validity, offset_iter, None
578                        ),
579                    };
580                    Self::from(arr).into_series()
581                } else {
582                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
583                        debug_assert!(len <= self.len() as IdxSize);
584                        match len {
585                            0 => None,
586                            1 => self.get(first as usize),
587                            _ => {
588                                let arr_group = _slice_from_offsets(self, first, len);
589                                ChunkAgg::min(&arr_group)
590                            },
591                        }
592                    })
593                }
594            },
595        }
596    }
597
598    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
599        // faster paths
600        match (self.is_sorted_flag(), self.null_count()) {
601            (IsSorted::Ascending, 0) => {
602                return self.clone().into_series().agg_last(groups);
603            },
604            (IsSorted::Descending, 0) => {
605                return self.clone().into_series().agg_first(groups);
606            },
607            _ => {},
608        }
609
610        match groups {
611            GroupsType::Idx(groups) => {
612                let ca = self.rechunk();
613                let arr = ca.downcast_iter().next().unwrap();
614                let no_nulls = arr.null_count() == 0;
615                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
616                    debug_assert!(idx.len() <= arr.len());
617                    if idx.is_empty() {
618                        None
619                    } else if idx.len() == 1 {
620                        arr.get(first as usize)
621                    } else if no_nulls {
622                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
623                            arr,
624                            idx2usize(idx),
625                            |a, b| a.take_max(b),
626                        )
627                    } else {
628                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a.take_max(b))
629                    }
630                })
631            },
632            GroupsType::Slice {
633                groups: groups_slice,
634                ..
635            } => {
636                if _use_rolling_kernels(groups_slice, self.chunks()) {
637                    let arr = self.downcast_iter().next().unwrap();
638                    let values = arr.values().as_slice();
639                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
640                    let arr = match arr.validity() {
641                        None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _>(
642                            values,
643                            offset_iter,
644                            None,
645                        ),
646                        Some(validity) => _rolling_apply_agg_window_nulls::<
647                            rolling::nulls::MaxWindow<_>,
648                            _,
649                            _,
650                        >(
651                            values, validity, offset_iter, None
652                        ),
653                    };
654                    Self::from(arr).into_series()
655                } else {
656                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
657                        debug_assert!(len <= self.len() as IdxSize);
658                        match len {
659                            0 => None,
660                            1 => self.get(first as usize),
661                            _ => {
662                                let arr_group = _slice_from_offsets(self, first, len);
663                                ChunkAgg::max(&arr_group)
664                            },
665                        }
666                    })
667                }
668            },
669        }
670    }
671
672    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
673        match groups {
674            GroupsType::Idx(groups) => {
675                let ca = self.rechunk();
676                let arr = ca.downcast_iter().next().unwrap();
677                let no_nulls = arr.null_count() == 0;
678                _agg_helper_idx_no_null::<T, _>(groups, |(first, idx)| {
679                    debug_assert!(idx.len() <= self.len());
680                    if idx.is_empty() {
681                        T::Native::zero()
682                    } else if idx.len() == 1 {
683                        arr.get(first as usize).unwrap_or(T::Native::zero())
684                    } else if no_nulls {
685                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b)
686                            .unwrap_or(T::Native::zero())
687                    } else {
688                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b)
689                            .unwrap_or(T::Native::zero())
690                    }
691                })
692            },
693            GroupsType::Slice { groups, .. } => {
694                if _use_rolling_kernels(groups, self.chunks()) {
695                    let arr = self.downcast_iter().next().unwrap();
696                    let values = arr.values().as_slice();
697                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
698                    let arr = match arr.validity() {
699                        None => _rolling_apply_agg_window_no_nulls::<SumWindow<_>, _, _>(
700                            values,
701                            offset_iter,
702                            None,
703                        ),
704                        Some(validity) => _rolling_apply_agg_window_nulls::<
705                            rolling::nulls::SumWindow<_>,
706                            _,
707                            _,
708                        >(
709                            values, validity, offset_iter, None
710                        ),
711                    };
712                    Self::from(arr).into_series()
713                } else {
714                    _agg_helper_slice_no_null::<T, _>(groups, |[first, len]| {
715                        debug_assert!(len <= self.len() as IdxSize);
716                        match len {
717                            0 => T::Native::zero(),
718                            1 => self.get(first as usize).unwrap_or(T::Native::zero()),
719                            _ => {
720                                let arr_group = _slice_from_offsets(self, first, len);
721                                arr_group.sum().unwrap_or(T::Native::zero())
722                            },
723                        }
724                    })
725                }
726            },
727        }
728    }
729}
730
731impl<T> SeriesWrap<ChunkedArray<T>>
732where
733    T: PolarsFloatType,
734    ChunkedArray<T>: IntoSeries
735        + ChunkVar
736        + VarAggSeries
737        + ChunkQuantile<T::Native>
738        + QuantileAggSeries
739        + ChunkAgg<T::Native>,
740    T::Native: Pow<T::Native, Output = T::Native>,
741{
742    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
743        match groups {
744            GroupsType::Idx(groups) => {
745                let ca = self.rechunk();
746                let arr = ca.downcast_iter().next().unwrap();
747                let no_nulls = arr.null_count() == 0;
748                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
749                    // this can fail due to a bug in lazy code.
750                    // here users can create filters in aggregations
751                    // and thereby creating shorter columns than the original group tuples.
752                    // the group tuples are modified, but if that's done incorrect there can be out of bounds
753                    // access
754                    debug_assert!(idx.len() <= self.len());
755                    let out = if idx.is_empty() {
756                        None
757                    } else if idx.len() == 1 {
758                        arr.get(first as usize).map(|sum| sum.to_f64().unwrap())
759                    } else if no_nulls {
760                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
761                            arr,
762                            idx2usize(idx),
763                            |a, b| a + b,
764                        )
765                        .unwrap()
766                        .to_f64()
767                        .map(|sum| sum / idx.len() as f64)
768                    } else {
769                        take_agg_primitive_iter_unchecked_count_nulls::<T::Native, _, _, _>(
770                            arr,
771                            idx2usize(idx),
772                            |a, b| a + b,
773                            T::Native::zero(),
774                            idx.len() as IdxSize,
775                        )
776                        .map(|(sum, null_count)| {
777                            sum.to_f64()
778                                .map(|sum| sum / (idx.len() as f64 - null_count as f64))
779                                .unwrap()
780                        })
781                    };
782                    out.map(|flt| NumCast::from(flt).unwrap())
783                })
784            },
785            GroupsType::Slice { groups, .. } => {
786                if _use_rolling_kernels(groups, self.chunks()) {
787                    let arr = self.downcast_iter().next().unwrap();
788                    let values = arr.values().as_slice();
789                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
790                    let arr = match arr.validity() {
791                        None => _rolling_apply_agg_window_no_nulls::<MeanWindow<_>, _, _>(
792                            values,
793                            offset_iter,
794                            None,
795                        ),
796                        Some(validity) => _rolling_apply_agg_window_nulls::<
797                            rolling::nulls::MeanWindow<_>,
798                            _,
799                            _,
800                        >(
801                            values, validity, offset_iter, None
802                        ),
803                    };
804                    ChunkedArray::from(arr).into_series()
805                } else {
806                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
807                        debug_assert!(len <= self.len() as IdxSize);
808                        match len {
809                            0 => None,
810                            1 => self.get(first as usize),
811                            _ => {
812                                let arr_group = _slice_from_offsets(self, first, len);
813                                arr_group.mean().map(|flt| NumCast::from(flt).unwrap())
814                            },
815                        }
816                    })
817                }
818            },
819        }
820    }
821
822    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series
823    where
824        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
825    {
826        let ca = &self.0.rechunk();
827        match groups {
828            GroupsType::Idx(groups) => {
829                let ca = ca.rechunk();
830                let arr = ca.downcast_iter().next().unwrap();
831                let no_nulls = arr.null_count() == 0;
832                agg_helper_idx_on_all::<T, _>(groups, |idx| {
833                    debug_assert!(idx.len() <= ca.len());
834                    if idx.is_empty() {
835                        return None;
836                    }
837                    let out = if no_nulls {
838                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
839                    } else {
840                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
841                    };
842                    out.map(|flt| NumCast::from(flt).unwrap())
843                })
844            },
845            GroupsType::Slice { groups, .. } => {
846                if _use_rolling_kernels(groups, self.chunks()) {
847                    let arr = self.downcast_iter().next().unwrap();
848                    let values = arr.values().as_slice();
849                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
850                    let arr = match arr.validity() {
851                        None => _rolling_apply_agg_window_no_nulls::<VarWindow<_>, _, _>(
852                            values,
853                            offset_iter,
854                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
855                        ),
856                        Some(validity) => {
857                            _rolling_apply_agg_window_nulls::<rolling::nulls::VarWindow<_>, _, _>(
858                                values,
859                                validity,
860                                offset_iter,
861                                Some(RollingFnParams::Var(RollingVarParams { ddof })),
862                            )
863                        },
864                    };
865                    ChunkedArray::from(arr).into_series()
866                } else {
867                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
868                        debug_assert!(len <= self.len() as IdxSize);
869                        match len {
870                            0 => None,
871                            1 => {
872                                if ddof == 0 {
873                                    NumCast::from(0)
874                                } else {
875                                    None
876                                }
877                            },
878                            _ => {
879                                let arr_group = _slice_from_offsets(self, first, len);
880                                arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap())
881                            },
882                        }
883                    })
884                }
885            },
886        }
887    }
888    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series
889    where
890        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
891    {
892        let ca = &self.0.rechunk();
893        match groups {
894            GroupsType::Idx(groups) => {
895                let arr = ca.downcast_iter().next().unwrap();
896                let no_nulls = arr.null_count() == 0;
897                agg_helper_idx_on_all::<T, _>(groups, |idx| {
898                    debug_assert!(idx.len() <= ca.len());
899                    if idx.is_empty() {
900                        return None;
901                    }
902                    let out = if no_nulls {
903                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
904                    } else {
905                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
906                    };
907                    out.map(|flt| NumCast::from(flt.sqrt()).unwrap())
908                })
909            },
910            GroupsType::Slice { groups, .. } => {
911                if _use_rolling_kernels(groups, self.chunks()) {
912                    let arr = ca.downcast_iter().next().unwrap();
913                    let values = arr.values().as_slice();
914                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
915                    let arr = match arr.validity() {
916                        None => _rolling_apply_agg_window_no_nulls::<VarWindow<_>, _, _>(
917                            values,
918                            offset_iter,
919                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
920                        ),
921                        Some(validity) => {
922                            _rolling_apply_agg_window_nulls::<rolling::nulls::VarWindow<_>, _, _>(
923                                values,
924                                validity,
925                                offset_iter,
926                                Some(RollingFnParams::Var(RollingVarParams { ddof })),
927                            )
928                        },
929                    };
930
931                    let mut ca = ChunkedArray::<T>::from(arr);
932                    ca.apply_mut(|v| v.powf(NumCast::from(0.5).unwrap()));
933                    ca.into_series()
934                } else {
935                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
936                        debug_assert!(len <= self.len() as IdxSize);
937                        match len {
938                            0 => None,
939                            1 => {
940                                if ddof == 0 {
941                                    NumCast::from(0)
942                                } else {
943                                    None
944                                }
945                            },
946                            _ => {
947                                let arr_group = _slice_from_offsets(self, first, len);
948                                arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap())
949                            },
950                        }
951                    })
952                }
953            },
954        }
955    }
956}
957
// Quantile and median group aggregations for `Float32Chunked`; both methods forward to the
// shared generic implementations instantiated with `Float32Type`.
impl Float32Chunked {
    /// Aggregates the given `quantile` of each group using `method`, by delegating to
    /// `agg_quantile_generic` instantiated with `Float32Type`.
    ///
    /// # Safety
    /// NOTE(review): the contract is inherited from `agg_quantile_generic`; presumably
    /// `groups` must only reference positions inside `self` — confirm there.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
    }
    /// Aggregates the median of each group, by delegating to `agg_median_generic`
    /// instantiated with `Float32Type`.
    ///
    /// # Safety
    /// NOTE(review): the contract is inherited from `agg_median_generic`; presumably
    /// `groups` must only reference positions inside `self` — confirm there.
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float32Type>(self, groups)
    }
}
// Quantile and median group aggregations for `Float64Chunked`; both methods forward to the
// shared generic implementations instantiated with `Float64Type`.
impl Float64Chunked {
    /// Aggregates the given `quantile` of each group using `method`, by delegating to
    /// `agg_quantile_generic` instantiated with `Float64Type`.
    ///
    /// # Safety
    /// NOTE(review): the contract is inherited from `agg_quantile_generic`; presumably
    /// `groups` must only reference positions inside `self` — confirm there.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    /// Aggregates the median of each group, by delegating to `agg_median_generic`
    /// instantiated with `Float64Type`.
    ///
    /// # Safety
    /// NOTE(review): the contract is inherited from `agg_median_generic`; presumably
    /// `groups` must only reference positions inside `self` — confirm there.
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}
984
985impl<T> ChunkedArray<T>
986where
987    T: PolarsIntegerType,
988    ChunkedArray<T>: IntoSeries + ChunkAgg<T::Native> + ChunkVar,
989    T::Native: NumericNative + Ord,
990{
991    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
992        match groups {
993            GroupsType::Idx(groups) => {
994                let ca = self.rechunk();
995                let arr = ca.downcast_get(0).unwrap();
996                _agg_helper_idx::<Float64Type, _>(groups, |(first, idx)| {
997                    // this can fail due to a bug in lazy code.
998                    // here users can create filters in aggregations
999                    // and thereby creating shorter columns than the original group tuples.
1000                    // the group tuples are modified, but if that's done incorrect there can be out of bounds
1001                    // access
1002                    debug_assert!(idx.len() <= self.len());
1003                    if idx.is_empty() {
1004                        None
1005                    } else if idx.len() == 1 {
1006                        self.get(first as usize).map(|sum| sum.to_f64().unwrap())
1007                    } else {
1008                        match (self.has_nulls(), self.chunks.len()) {
1009                            (false, 1) => {
1010                                take_agg_no_null_primitive_iter_unchecked::<_, f64, _, _>(
1011                                    arr,
1012                                    idx2usize(idx),
1013                                    |a, b| a + b,
1014                                )
1015                                .map(|sum| sum / idx.len() as f64)
1016                            },
1017                            (_, 1) => {
1018                                {
1019                                    take_agg_primitive_iter_unchecked_count_nulls::<
1020                                        T::Native,
1021                                        f64,
1022                                        _,
1023                                        _,
1024                                    >(
1025                                        arr, idx2usize(idx), |a, b| a + b, 0.0, idx.len() as IdxSize
1026                                    )
1027                                }
1028                                .map(|(sum, null_count)| {
1029                                    sum / (idx.len() as f64 - null_count as f64)
1030                                })
1031                            },
1032                            _ => {
1033                                let take = { self.take_unchecked(idx) };
1034                                take.mean()
1035                            },
1036                        }
1037                    }
1038                })
1039            },
1040            GroupsType::Slice {
1041                groups: groups_slice,
1042                ..
1043            } => {
1044                if _use_rolling_kernels(groups_slice, self.chunks()) {
1045                    let ca = self
1046                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
1047                        .unwrap();
1048                    ca.agg_mean(groups)
1049                } else {
1050                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
1051                        debug_assert!(first + len <= self.len() as IdxSize);
1052                        match len {
1053                            0 => None,
1054                            1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
1055                            _ => {
1056                                let arr_group = _slice_from_offsets(self, first, len);
1057                                arr_group.mean()
1058                            },
1059                        }
1060                    })
1061                }
1062            },
1063        }
1064    }
1065
1066    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series {
1067        match groups {
1068            GroupsType::Idx(groups) => {
1069                let ca_self = self.rechunk();
1070                let arr = ca_self.downcast_iter().next().unwrap();
1071                let no_nulls = arr.null_count() == 0;
1072                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
1073                    debug_assert!(idx.len() <= arr.len());
1074                    if idx.is_empty() {
1075                        return None;
1076                    }
1077                    if no_nulls {
1078                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
1079                    } else {
1080                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
1081                    }
1082                })
1083            },
1084            GroupsType::Slice {
1085                groups: groups_slice,
1086                ..
1087            } => {
1088                if _use_rolling_kernels(groups_slice, self.chunks()) {
1089                    let ca = self
1090                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
1091                        .unwrap();
1092                    ca.agg_var(groups, ddof)
1093                } else {
1094                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
1095                        debug_assert!(first + len <= self.len() as IdxSize);
1096                        match len {
1097                            0 => None,
1098                            1 => {
1099                                if ddof == 0 {
1100                                    NumCast::from(0)
1101                                } else {
1102                                    None
1103                                }
1104                            },
1105                            _ => {
1106                                let arr_group = _slice_from_offsets(self, first, len);
1107                                arr_group.var(ddof)
1108                            },
1109                        }
1110                    })
1111                }
1112            },
1113        }
1114    }
1115    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series {
1116        match groups {
1117            GroupsType::Idx(groups) => {
1118                let ca_self = self.rechunk();
1119                let arr = ca_self.downcast_iter().next().unwrap();
1120                let no_nulls = arr.null_count() == 0;
1121                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
1122                    debug_assert!(idx.len() <= self.len());
1123                    if idx.is_empty() {
1124                        return None;
1125                    }
1126                    let out = if no_nulls {
1127                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
1128                    } else {
1129                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
1130                    };
1131                    out.map(|v| v.sqrt())
1132                })
1133            },
1134            GroupsType::Slice {
1135                groups: groups_slice,
1136                ..
1137            } => {
1138                if _use_rolling_kernels(groups_slice, self.chunks()) {
1139                    let ca = self
1140                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
1141                        .unwrap();
1142                    ca.agg_std(groups, ddof)
1143                } else {
1144                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
1145                        debug_assert!(first + len <= self.len() as IdxSize);
1146                        match len {
1147                            0 => None,
1148                            1 => {
1149                                if ddof == 0 {
1150                                    NumCast::from(0)
1151                                } else {
1152                                    None
1153                                }
1154                            },
1155                            _ => {
1156                                let arr_group = _slice_from_offsets(self, first, len);
1157                                arr_group.std(ddof)
1158                            },
1159                        }
1160                    })
1161                }
1162            },
1163        }
1164    }
1165
    /// Aggregates the given `quantile` of each group using `method`, by delegating to
    /// `agg_quantile_generic` instantiated with `Float64Type`.
    ///
    /// # Safety
    /// NOTE(review): the contract is inherited from `agg_quantile_generic`; presumably
    /// `groups` must only reference positions inside `self` — confirm there.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    /// Aggregates the median of each group, by delegating to `agg_median_generic`
    /// instantiated with `Float64Type`.
    ///
    /// # Safety
    /// NOTE(review): the contract is inherited from `agg_median_generic`; presumably
    /// `groups` must only reference positions inside `self` — confirm there.
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
1177}