polars_core/frame/group_by/aggregations/mod.rs

mod agg_list;
mod boolean;
mod dispatch;
mod string;

use std::borrow::Cow;

pub use agg_list::*;
use arrow::bitmap::{Bitmap, MutableBitmap};
use arrow::legacy::kernels::take_agg::*;
use arrow::legacy::trusted_len::TrustedLenPush;
use arrow::types::NativeType;
use num_traits::pow::Pow;
use num_traits::{Bounded, Float, Num, NumCast, ToPrimitive, Zero};
use polars_compute::rolling::no_nulls::{
    MaxWindow, MinWindow, MomentWindow, QuantileWindow, RollingAggWindowNoNulls,
};
use polars_compute::rolling::nulls::{RollingAggWindowNulls, VarianceMoment};
use polars_compute::rolling::quantile_filter::SealedRolling;
use polars_compute::rolling::{
    self, MeanWindow, QuantileMethod, RollingFnParams, RollingQuantileParams, RollingVarParams,
    SumWindow, quantile_filter,
};
use polars_utils::float::IsFloat;
use polars_utils::idx_vec::IdxVec;
use polars_utils::kahan_sum::KahanSum;
use polars_utils::min_max::MinMax;
use rayon::prelude::*;

use crate::chunked_array::cast::CastOptions;
#[cfg(feature = "object")]
use crate::chunked_array::object::extension::create_extension;
use crate::frame::group_by::GroupsIdx;
#[cfg(feature = "object")]
use crate::frame::group_by::GroupsIndicator;
use crate::prelude::*;
use crate::series::IsSorted;
use crate::series::implementations::SeriesWrap;
use crate::utils::NoNull;
use crate::{POOL, apply_method_physical_integer};

fn idx2usize(idx: &[IdxSize]) -> impl ExactSizeIterator<Item = usize> + '_ {
    idx.iter().map(|i| *i as usize)
}

// If the windows overlap, we can use the rolling_<agg> kernels. They maintain
// state, which saves a lot of compute by not naively traversing all elements in
// every window.
//
// If the windows don't overlap, we should not use these kernels: they are
// single-threaded, so we would miss out on easy parallelization.
pub fn _use_rolling_kernels(groups: &GroupsSlice, chunks: &[ArrayRef]) -> bool {
    match groups.len() {
        0 | 1 => false,
        _ => {
            let [first_offset, first_len] = groups[0];
            let second_offset = groups[1][0];

            second_offset >= first_offset // Prevent false positive from regular group-by that has out of order slices.
                                          // Rolling group-by is expected to have monotonically increasing slices.
                && second_offset < (first_offset + first_len)
                && chunks.len() == 1
        },
    }
}
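
// A hedged illustration of the heuristic above (values are hypothetical;
// assuming `GroupsSlice` is a `Vec<[IdxSize; 2]>` of `[offset, len]` pairs):
//
//     // Rolling group-by: the second window starts inside the first, so the
//     // stateful rolling kernels pay off (with a single chunk).
//     let rolling_groups: GroupsSlice = vec![[0, 4], [2, 4], [4, 4]];
//     // Regular group-by: disjoint slices; prefer the parallel per-group path.
//     let regular_groups: GroupsSlice = vec![[0, 4], [4, 4], [8, 4]];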

// Use an aggregation window that maintains the state.
pub fn _rolling_apply_agg_window_nulls<'a, Agg, T, O>(
    values: &'a [T],
    validity: &'a Bitmap,
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    Agg: RollingAggWindowNulls<'a, T>,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }

    // This iterator's length can be trusted; it represents the number of
    // groups in the group-by operation.
    let output_len = offsets.size_hint().0;
    // Start with a dummy index; it will be overwritten on the first iteration.
    // SAFETY:
    // we are in bounds
    let mut agg_window = unsafe { Agg::new(values, validity, 0, 0, params, None) };

    let mut validity = MutableBitmap::with_capacity(output_len);
    validity.extend_constant(output_len, true);

    let out = offsets
        .enumerate()
        .map(|(idx, (start, len))| {
            let end = start + len;

            // SAFETY:
            // we are in bounds

            let agg = if start == end {
                None
            } else {
                unsafe { agg_window.update(start as usize, end as usize) }
            };

            match agg {
                Some(val) => val,
                None => {
                    // SAFETY: we are in bounds
                    unsafe { validity.set_unchecked(idx, false) };
                    T::default()
                },
            }
        })
        .collect_trusted::<Vec<_>>();

    PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), Some(validity.into()))
}

// Use an aggregation window that maintains the state.
pub fn _rolling_apply_agg_window_no_nulls<'a, Agg, T, O>(
    values: &'a [T],
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    // Items are `(offset, len)`, so a window covers `offset..offset + len`.
    Agg: RollingAggWindowNoNulls<'a, T>,
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }
    // Start with a dummy index; it will be overwritten on the first iteration.
    let mut agg_window = Agg::new(values, 0, 0, params, None);

    offsets
        .map(|(start, len)| {
            let end = start + len;

            if start == end {
                None
            } else {
                // SAFETY: we are in bounds.
                unsafe { agg_window.update(start as usize, end as usize) }
            }
        })
        .collect::<PrimitiveArray<T>>()
}
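
// A hedged usage sketch (hypothetical data, not a test): compute a rolling
// minimum over three overlapping windows, reusing window state instead of
// rescanning every window. Assumes the default `IdxSize = u32` and an offsets
// iterator that satisfies the `TrustedLen` bound.
//
//     let values: Vec<f64> = vec![3.0, 1.0, 4.0, 1.0, 5.0];
//     let offsets = vec![(0u32, 3u32), (1, 3), (2, 3)].into_iter();
//     let out = _rolling_apply_agg_window_no_nulls::<MinWindow<f64>, _, _>(
//         &values, offsets, None,
//     );
//     // One value per window: [1.0, 1.0, 1.0].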

pub fn _slice_from_offsets<T>(ca: &ChunkedArray<T>, first: IdxSize, len: IdxSize) -> ChunkedArray<T>
where
    T: PolarsDataType,
{
    ca.slice(first as i64, len as usize)
}

/// Helper that combines the groups into a parallel iterator over
/// `(first, idx): (IdxSize, &IdxVec)`.
pub fn _agg_helper_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn((IdxSize, &IdxVec)) -> Option<T::Native> + Send + Sync,
    T: PolarsNumericType,
{
    let ca: ChunkedArray<T> = POOL.install(|| groups.into_par_iter().map(f).collect());
    ca.into_series()
}
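
// A hedged sketch (hypothetical aggregation, for illustration only): emit each
// group's first row index as the aggregate, or None for empty groups.
//
//     let s = _agg_helper_idx::<UInt32Type, _>(&groups, |(first, idx)| {
//         if idx.is_empty() { None } else { Some(first) }
//     });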

/// Same helper as `_agg_helper_idx` but for aggregations that don't return an Option.
pub fn _agg_helper_idx_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn((IdxSize, &IdxVec)) -> T::Native + Send + Sync,
    T: PolarsNumericType,
{
    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.into_par_iter().map(f).collect());
    ca.into_inner().into_series()
}

/// Helper that iterates over the `all: Vec<IdxVec>` collection only;
/// this doesn't have to traverse the `first: Vec<IdxSize>` memory and is therefore faster.
fn agg_helper_idx_on_all<T, F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn(&IdxVec) -> Option<T::Native> + Send + Sync,
    T: PolarsNumericType,
{
    let ca: ChunkedArray<T> = POOL.install(|| groups.all().into_par_iter().map(f).collect());
    ca.into_series()
}

/// Same as `agg_helper_idx_on_all` but for aggregations that don't return an Option.
fn agg_helper_idx_on_all_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn(&IdxVec) -> T::Native + Send + Sync,
    T: PolarsNumericType,
{
    let ca: NoNull<ChunkedArray<T>> =
        POOL.install(|| groups.all().into_par_iter().map(f).collect());
    ca.into_inner().into_series()
}

pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
where
    F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
    T: PolarsNumericType,
{
    let ca: ChunkedArray<T> = POOL.install(|| groups.par_iter().copied().map(f).collect());
    ca.into_series()
}

pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
where
    F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
    T: PolarsNumericType,
{
    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.par_iter().copied().map(f).collect());
    ca.into_inner().into_series()
}

/// Intermediate helper trait so we can have a single generic implementation.
/// This trait ensures the specific dispatch works without complicating
/// the trait bounds.
trait QuantileDispatcher<K> {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<K>>;

    fn _median(self) -> Option<K>;
}

impl<T> QuantileDispatcher<f64> for ChunkedArray<T>
where
    T: PolarsIntegerType,
    T::Native: Ord,
{
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}

impl QuantileDispatcher<f32> for Float32Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f32>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f32> {
        self.median_faster()
    }
}
impl QuantileDispatcher<f64> for Float64Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}
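
// The net effect of this dispatch, summarized (not additional code): integer
// inputs resolve their quantiles through f64, while floats keep their own
// precision:
//
//     Int32Chunked   -> QuantileDispatcher<f64> -> Float64 result
//     Float32Chunked -> QuantileDispatcher<f32> -> Float32 result
//     Float64Chunked -> QuantileDispatcher<f64> -> Float64 result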

unsafe fn agg_quantile_generic<T, K>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    quantile: f64,
    method: QuantileMethod,
) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
{
    let invalid_quantile = !(0.0..=1.0).contains(&quantile);
    if invalid_quantile {
        return Series::full_null(ca.name().clone(), groups.len(), ca.dtype());
    }
    match groups {
        GroupsType::Idx(groups) => {
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                // Unwrapping is safe: the quantile was validated above.
                take._quantile(quantile, method).unwrap_unchecked()
            })
        },
        GroupsType::Slice { groups, .. } => {
            if _use_rolling_kernels(groups, ca.chunks()) {
                // This cast is a no-op for floats.
                let s = ca
                    .cast_with_options(&K::get_static_dtype(), CastOptions::Overflowing)
                    .unwrap();
                let ca: &ChunkedArray<K> = s.as_ref().as_ref();
                let arr = ca.downcast_iter().next().unwrap();
                let values = arr.values().as_slice();
                let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                let arr = match arr.validity() {
                    None => _rolling_apply_agg_window_no_nulls::<QuantileWindow<_>, _, _>(
                        values,
                        offset_iter,
                        Some(RollingFnParams::Quantile(RollingQuantileParams {
                            prob: quantile,
                            method,
                        })),
                    ),
                    Some(validity) => {
                        _rolling_apply_agg_window_nulls::<rolling::nulls::QuantileWindow<_>, _, _>(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Quantile(RollingQuantileParams {
                                prob: quantile,
                                method,
                            })),
                        )
                    },
                };
                // The rolling kernels work on the physical dtype; this is not
                // yet the float output type we need.
                ChunkedArray::<K>::with_chunk(PlSmallStr::EMPTY, arr).into_series()
            } else {
                _agg_helper_slice::<K, _>(groups, |[first, len]| {
                    debug_assert!(first + len <= ca.len() as IdxSize);
                    match len {
                        0 => None,
                        1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                        _ => {
                            let arr_group = _slice_from_offsets(ca, first, len);
                            // Unwrapping is safe: the quantile was validated above.
                            arr_group
                                ._quantile(quantile, method)
                                .unwrap_unchecked()
                                .map(|flt| NumCast::from(flt).unwrap_unchecked())
                        },
                    }
                })
            }
        },
    }
}

unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + SealedRolling,
{
    match groups {
        GroupsType::Idx(groups) => {
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                take._median()
            })
        },
        GroupsType::Slice { .. } => {
            agg_quantile_generic::<T, K>(ca, groups, 0.5, QuantileMethod::Linear)
        },
    }
}

/// # Safety
///
/// No bounds checks on `groups`.
#[cfg(feature = "bitwise")]
unsafe fn bitwise_agg<T: PolarsNumericType>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    f: fn(&ChunkedArray<T>) -> Option<T::Native>,
) -> Series
where
    ChunkedArray<T>: ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native>,
{
    // Prevent a rechunk for every individual group.
    let s = if groups.len() > 1 {
        ca.rechunk()
    } else {
        Cow::Borrowed(ca)
    };

    match groups {
        GroupsType::Idx(groups) => agg_helper_idx_on_all::<T, _>(groups, |idx| {
            debug_assert!(idx.len() <= s.len());
            if idx.is_empty() {
                None
            } else {
                let take = unsafe { s.take_unchecked(idx) };
                f(&take)
            }
        }),
        GroupsType::Slice { groups, .. } => _agg_helper_slice::<T, _>(groups, |[first, len]| {
            debug_assert!(len <= s.len() as IdxSize);
            if len == 0 {
                None
            } else {
                let take = _slice_from_offsets(&s, first, len);
                f(&take)
            }
        }),
    }
}

#[cfg(feature = "bitwise")]
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native>,
{
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) }
    }

    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) }
    }

    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) }
    }
}
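
// A hedged usage sketch (hypothetical data and groups): the caller must uphold
// the safety contract that every index in `groups` is in bounds for `ca`.
//
//     let ca = UInt8Chunked::new("flags".into(), &[0b0110u8, 0b0011, 0b1111]);
//     // SAFETY: `groups` stems from a group-by over a frame containing `ca`.
//     let anded = unsafe { ca.agg_and(&groups) };
//     let ored = unsafe { ca.agg_or(&groups) };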

impl<T> ChunkedArray<T>
where
    T: PolarsNumericType + Sync,
    T::Native: NativeType + PartialOrd + Num + NumCast + Zero + Bounded + std::iter::Sum<T::Native>,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
        // Fast paths: in a sorted column without nulls, each group's minimum
        // is its first (ascending) or last (descending) element.
        match (self.is_sorted_flag(), self.null_count()) {
            (IsSorted::Ascending, 0) => {
                return self.clone().into_series().agg_first(groups);
            },
            (IsSorted::Descending, 0) => {
                return self.clone().into_series().agg_last(groups);
            },
            _ => {},
        }
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.min_ignore_nan(b))
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.min_ignore_nan(b))
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MinWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::min(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
        // Fast paths: in a sorted column without nulls, each group's maximum
        // is its last (ascending) or first (descending) element.
        match (self.is_sorted_flag(), self.null_count()) {
            (IsSorted::Ascending, 0) => {
                return self.clone().into_series().agg_last(groups);
            },
            (IsSorted::Descending, 0) => {
                return self.clone().into_series().agg_first(groups);
            },
            _ => {},
        }

        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.max_ignore_nan(b))
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.max_ignore_nan(b))
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MaxWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::max(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx_no_null::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        T::Native::zero()
                    } else if idx.len() == 1 {
                        arr.get(first as usize).unwrap_or(T::Native::zero())
                    } else if no_nulls {
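                        // Floats are summed with Kahan (compensated) summation:
                        // a running error term compensates for precision lost
                        // when adding small values to a large accumulator.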
                        if T::Native::is_float() {
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(KahanSum::default(), |k, x| k + x)
                                .sum()
                        } else {
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(T::Native::zero(), |a, b| a + b)
                        }
                    } else if T::Native::is_float() {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .fold(KahanSum::default(), |k, x| k + x)
                            .sum()
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .fold(T::Native::zero(), |a, b| a + b)
                    }
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            SumWindow<T::Native, T::Native>,
                            _,
                            _,
                        >(values, offset_iter, None),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            SumWindow<T::Native, T::Native>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    _agg_helper_slice_no_null::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => T::Native::zero(),
                            1 => self.get(first as usize).unwrap_or(T::Native::zero()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.sum().unwrap_or(T::Native::zero())
                            },
                        }
                    })
                }
            },
        }
    }
}
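
// A hedged usage sketch (hypothetical groups, assuming the `GroupsType::Slice`
// variant carries `groups` and a `rolling` flag): disjoint slices take the
// parallel per-group path, while overlapping, monotonically increasing slices
// hit the stateful rolling kernels via `_use_rolling_kernels`.
//
//     let ca = Float64Chunked::new("x".into(), &[1.0, 2.0, 3.0, 4.0]);
//     let groups = GroupsType::Slice {
//         groups: vec![[0, 2], [2, 2]], // disjoint -> per-group path
//         rolling: false,
//     };
//     // SAFETY: the slices above are in bounds for `ca`.
//     let sums = unsafe { ca.agg_sum(&groups) }; // sums per group: [3.0, 7.0]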

impl<T> SeriesWrap<ChunkedArray<T>>
where
    T: PolarsFloatType,
    ChunkedArray<T>: ChunkVar
        + VarAggSeries
        + ChunkQuantile<T::Native>
        + QuantileAggSeries
        + ChunkAgg<T::Native>,
    T::Native: Pow<T::Native, Output = T::Native>,
{
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    // This can fail due to a bug in the lazy code: users can
                    // create filters in aggregations, producing columns shorter
                    // than the original group tuples. The group tuples are
                    // modified accordingly, but if that is done incorrectly
                    // there can be out-of-bounds access.
                    debug_assert!(idx.len() <= self.len());
                    let out = if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else if no_nulls {
                        Some(
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(KahanSum::default(), |a, b| {
                                    a + b.to_f64().unwrap_unchecked()
                                })
                                .sum()
                                / idx.len() as f64,
                        )
                    } else {
                        take_agg_primitive_iter_unchecked_count_nulls(
                            arr,
                            idx2usize(idx),
                            KahanSum::default(),
                            |a, b| a + b.to_f64().unwrap_unchecked(),
                            idx.len() as IdxSize,
                        )
                        .map(|(sum, null_count)| sum.sum() / (idx.len() as f64 - null_count as f64))
                    };
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MeanWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<MeanWindow<_>, _, _>(
                            values,
                            validity,
                            offset_iter,
                            None,
                        ),
                    };
                    ChunkedArray::<T>::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean().map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let ca = ca.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                    };
                    ChunkedArray::<T>::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|flt| NumCast::from(flt.sqrt()).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = ca.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MomentWindow<_, rolling::nulls::VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                    };

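                    // The rolling kernel produced variances; take the square
                    // root in place to turn them into standard deviations.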
                    let mut ca = ChunkedArray::<T>::from(arr);
                    ca.apply_mut(|v| v.powf(NumCast::from(0.5).unwrap()));
                    ca.into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
}

impl Float32Chunked {
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
    }
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float32Type>(self, groups)
    }
}
impl Float64Chunked {
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}

impl<T> ChunkedArray<T>
where
    T: PolarsIntegerType,
    ChunkedArray<T>: ChunkAgg<T::Native> + ChunkVar,
    T::Native: NumericNative + Ord,
{
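    // NOTE: integer group-by means, variances, and quantiles are computed in
    // and returned as Float64.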
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_get(0).unwrap();
                _agg_helper_idx::<Float64Type, _>(groups, |(first, idx)| {
                    // This can fail due to a bug in the lazy code: users can
                    // create filters in aggregations, producing columns shorter
                    // than the original group tuples. The group tuples are
                    // modified accordingly, but if that is done incorrectly
                    // there can be out-of-bounds access.
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        self.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else {
                        match (self.has_nulls(), self.chunks.len()) {
                            (false, 1) => Some(
                                take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                    .fold(KahanSum::default(), |a, b| a + b.to_f64().unwrap())
                                    .sum()
                                    / idx.len() as f64,
                            ),
                            (_, 1) => {
                                take_agg_primitive_iter_unchecked_count_nulls(
                                    arr,
                                    idx2usize(idx),
                                    KahanSum::default(),
                                    |a, b| a + b.to_f64().unwrap(),
                                    idx.len() as IdxSize,
                                )
                            }
                            .map(|(sum, null_count)| {
                                sum.sum() / (idx.len() as f64 - null_count as f64)
                            }),
                            _ => {
                                let take = { self.take_unchecked(idx) };
                                take.mean()
                            },
                        }
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_mean(groups)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean()
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        return None;
                    }
                    if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_var(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof)
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|v| v.sqrt())
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_std(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof)
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}