polars_core/frame/group_by/aggregations/mod.rs

mod agg_list;
mod boolean;
mod dispatch;
mod string;

use std::borrow::Cow;

pub use agg_list::*;
use arrow::bitmap::{Bitmap, MutableBitmap};
use arrow::legacy::kernels::take_agg::*;
use arrow::legacy::trusted_len::TrustedLenPush;
use arrow::types::NativeType;
use num_traits::pow::Pow;
use num_traits::{Bounded, Float, Num, NumCast, ToPrimitive, Zero};
use polars_compute::rolling::no_nulls::{
    MaxWindow, MinWindow, MomentWindow, QuantileWindow, RollingAggWindowNoNulls,
};
use polars_compute::rolling::nulls::{RollingAggWindowNulls, VarianceMoment};
use polars_compute::rolling::quantile_filter::SealedRolling;
use polars_compute::rolling::{
    self, MeanWindow, QuantileMethod, RollingFnParams, RollingQuantileParams, RollingVarParams,
    SumWindow, quantile_filter,
};
use polars_utils::float::IsFloat;
use polars_utils::idx_vec::IdxVec;
use polars_utils::kahan_sum::KahanSum;
use polars_utils::min_max::MinMax;
use rayon::prelude::*;

use crate::POOL;
use crate::chunked_array::cast::CastOptions;
#[cfg(feature = "object")]
use crate::chunked_array::object::extension::create_extension;
use crate::frame::group_by::GroupsIdx;
#[cfg(feature = "object")]
use crate::frame::group_by::GroupsIndicator;
use crate::prelude::*;
use crate::series::IsSorted;
use crate::series::implementations::SeriesWrap;
use crate::utils::NoNull;

fn idx2usize(idx: &[IdxSize]) -> impl ExactSizeIterator<Item = usize> + '_ {
    idx.iter().map(|i| *i as usize)
}

// If the windows overlap, we can use the rolling_<agg> kernels: they maintain
// state, which saves a lot of compute by not naively traversing all elements in
// every window.
//
// If the windows don't overlap, we should not use these kernels, as they are
// single-threaded and we would miss out on easy parallelization.
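/// # Example
///
/// A minimal sketch of the check (hypothetical slice groups; `chunks` is
/// assumed to hold a single chunk):
///
/// ```ignore
/// // The second window starts at 2, inside the first window [0, 4):
/// // overlapping, so the stateful rolling kernels apply.
/// assert!(_use_rolling_kernels(&vec![[0, 4], [2, 3]], chunks));
/// // The second window starts exactly where the first ends: no overlap.
/// assert!(!_use_rolling_kernels(&vec![[0, 4], [4, 4]], chunks));
/// ```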
pub fn _use_rolling_kernels(groups: &GroupsSlice, chunks: &[ArrayRef]) -> bool {
    match groups.len() {
        0 | 1 => false,
        _ => {
            let [first_offset, first_len] = groups[0];
            let second_offset = groups[1][0];

            // Prevent a false positive from a regular group-by that has
            // out-of-order slices; a rolling group-by is expected to have
            // monotonically increasing slices.
            second_offset >= first_offset
                && second_offset < (first_offset + first_len)
                && chunks.len() == 1
        },
    }
}

// Use an aggregation window that maintains the state.
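/// # Example
///
/// A minimal sketch (assuming `SumWindow<f64, f64>`, used by `agg_sum` below,
/// as the window type; `offsets` yields `(start, len)` pairs):
///
/// ```ignore
/// let values = [1.0f64, 2.0, 3.0];
/// let validity = Bitmap::from([true, false, true]);
/// // Overlapping windows [0, 2) and [1, 3); the null at index 1 is skipped.
/// let offsets = vec![(0 as IdxSize, 2 as IdxSize), (1, 2)].into_iter();
/// let arr = _rolling_apply_agg_window_nulls::<SumWindow<f64, f64>, _, _>(
///     &values, &validity, offsets, None,
/// );
/// assert_eq!(arr.value(0), 1.0);
/// assert_eq!(arr.value(1), 3.0);
/// ```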
pub fn _rolling_apply_agg_window_nulls<'a, Agg, T, O>(
    values: &'a [T],
    validity: &'a Bitmap,
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    Agg: RollingAggWindowNulls<'a, T>,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }

    // This iterator's length can be trusted; it represents the number of groups
    // in the group_by operation.
    let output_len = offsets.size_hint().0;
    // Start with a dummy window; it is overwritten on the first iteration.
    // SAFETY:
    // we are in bounds
    let mut agg_window = unsafe { Agg::new(values, validity, 0, 0, params, None) };

    let mut validity = MutableBitmap::with_capacity(output_len);
    validity.extend_constant(output_len, true);

    let out = offsets
        .enumerate()
        .map(|(idx, (start, len))| {
            let end = start + len;

            // SAFETY:
            // we are in bounds
            let agg = if start == end {
                None
            } else {
                unsafe { agg_window.update(start as usize, end as usize) }
            };

            match agg {
                Some(val) => val,
                None => {
                    // SAFETY: we are in bounds
                    unsafe { validity.set_unchecked(idx, false) };
                    T::default()
                },
            }
        })
        .collect_trusted::<Vec<_>>();

    PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), Some(validity.into()))
}

// Use an aggregation window that maintains the state.
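/// # Example
///
/// A minimal sketch (again assuming `SumWindow<f64, f64>` as the window type):
///
/// ```ignore
/// let values = [1.0f64, 2.0, 3.0, 4.0];
/// // Overlapping windows [0, 3) and [1, 4), as produced by a rolling group-by.
/// let offsets = vec![(0 as IdxSize, 3 as IdxSize), (1, 3)].into_iter();
/// let arr = _rolling_apply_agg_window_no_nulls::<SumWindow<f64, f64>, _, _>(
///     &values, offsets, None,
/// );
/// assert_eq!(arr.values().as_slice(), &[6.0, 9.0]);
/// ```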
pub fn _rolling_apply_agg_window_no_nulls<'a, Agg, T, O>(
    values: &'a [T],
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    // Items are (offset, len), so a window spans offset..offset + len.
    Agg: RollingAggWindowNoNulls<'a, T>,
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }
    // Start with a dummy window; it is overwritten on the first iteration.
    let mut agg_window = Agg::new(values, 0, 0, params, None);

    offsets
        .map(|(start, len)| {
            let end = start + len;

            if start == end {
                None
            } else {
                // SAFETY: we are in bounds.
                unsafe { agg_window.update(start as usize, end as usize) }
            }
        })
        .collect::<PrimitiveArray<T>>()
}

pub fn _slice_from_offsets<T>(ca: &ChunkedArray<T>, first: IdxSize, len: IdxSize) -> ChunkedArray<T>
where
    T: PolarsDataType,
{
    ca.slice(first as i64, len as usize)
}

/// Helper that combines the groups into a parallel iterator over `(first, all): (IdxSize, &IdxVec)`.
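///
/// # Example
///
/// A minimal sketch (hypothetical `groups` and `ca`; mirrors the closures used
/// by the aggregations below):
///
/// ```ignore
/// // First value of each group, as an Option so empty groups become null.
/// let firsts = _agg_helper_idx::<Float64Type, _>(&groups, |(first, idx)| {
///     if idx.is_empty() { None } else { ca.get(first as usize) }
/// });
/// ```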
pub fn _agg_helper_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn((IdxSize, &IdxVec)) -> Option<T::Native> + Send + Sync,
    T: PolarsNumericType,
{
    let ca: ChunkedArray<T> = POOL.install(|| groups.into_par_iter().map(f).collect());
    ca.into_series()
}

/// Same helper as `_agg_helper_idx` but for aggregations that don't return an `Option`.
pub fn _agg_helper_idx_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn((IdxSize, &IdxVec)) -> T::Native + Send + Sync,
    T: PolarsNumericType,
{
    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.into_par_iter().map(f).collect());
    ca.into_inner().into_series()
}

/// Helper that iterates on the `all: Vec<IdxVec>` collection; this doesn't have
/// to traverse the `first: Vec<IdxSize>` memory and is therefore faster.
fn agg_helper_idx_on_all<T, F>(groups: &GroupsIdx, f: F) -> Series
where
    F: Fn(&IdxVec) -> Option<T::Native> + Send + Sync,
    T: PolarsNumericType,
{
    let ca: ChunkedArray<T> = POOL.install(|| groups.all().into_par_iter().map(f).collect());
    ca.into_series()
}
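/// Helper that runs `f` in parallel over `[first, len]` slice groups.
///
/// A minimal sketch (hypothetical `groups_slice`; `IdxType` is assumed to be
/// the index dtype used elsewhere in this crate):
///
/// ```ignore
/// // Group lengths, with empty groups mapped to null.
/// let lens = _agg_helper_slice::<IdxType, _>(&groups_slice, |[_first, len]| {
///     (len > 0).then_some(len)
/// });
/// ```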
pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
where
    F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
    T: PolarsNumericType,
{
    let ca: ChunkedArray<T> = POOL.install(|| groups.par_iter().copied().map(f).collect());
    ca.into_series()
}

pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
where
    F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
    T: PolarsNumericType,
{
    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.par_iter().copied().map(f).collect());
    ca.into_inner().into_series()
}

/// Intermediate helper trait so we can have a single generic implementation.
/// This trait ensures the specific dispatch works without complicating
/// the trait bounds.
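///
/// A sketch of the resulting dispatch: integer chunked arrays compute their
/// quantile as `f64`, while `Float32Chunked` stays in `f32`:
///
/// ```ignore
/// let q64: Option<f64> = int_ca._quantile(0.5, QuantileMethod::Linear)?;
/// let q32: Option<f32> = f32_ca._quantile(0.5, QuantileMethod::Linear)?;
/// ```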
trait QuantileDispatcher<K> {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<K>>;

    fn _median(self) -> Option<K>;
}

impl<T> QuantileDispatcher<f64> for ChunkedArray<T>
where
    T: PolarsIntegerType,
    T::Native: Ord,
{
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}

impl QuantileDispatcher<f32> for Float32Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f32>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f32> {
        self.median_faster()
    }
}

impl QuantileDispatcher<f64> for Float64Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}

unsafe fn agg_quantile_generic<T, K>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    quantile: f64,
    method: QuantileMethod,
) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
{
    let invalid_quantile = !(0.0..=1.0).contains(&quantile);
    if invalid_quantile {
        return Series::full_null(ca.name().clone(), groups.len(), ca.dtype());
    }
    match groups {
        GroupsType::Idx(groups) => {
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                // Unwrap is safe: the quantile was validated above.
                take._quantile(quantile, method).unwrap_unchecked()
            })
        },
        GroupsType::Slice { groups, .. } => {
            if _use_rolling_kernels(groups, ca.chunks()) {
                // This cast is a no-op for floats.
                let s = ca
                    .cast_with_options(&K::get_static_dtype(), CastOptions::Overflowing)
                    .unwrap();
                let ca: &ChunkedArray<K> = s.as_ref().as_ref();
                let arr = ca.downcast_iter().next().unwrap();
                let values = arr.values().as_slice();
                let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                let arr = match arr.validity() {
                    None => _rolling_apply_agg_window_no_nulls::<QuantileWindow<_>, _, _>(
                        values,
                        offset_iter,
                        Some(RollingFnParams::Quantile(RollingQuantileParams {
                            prob: quantile,
                            method,
                        })),
                    ),
                    Some(validity) => {
                        _rolling_apply_agg_window_nulls::<rolling::nulls::QuantileWindow<_>, _, _>(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Quantile(RollingQuantileParams {
                                prob: quantile,
                                method,
                            })),
                        )
                    },
                };
                // The rolling kernel works on the input dtype; this is not yet
                // the float output type we need.
                ChunkedArray::<K>::with_chunk(PlSmallStr::EMPTY, arr).into_series()
            } else {
                _agg_helper_slice::<K, _>(groups, |[first, len]| {
                    debug_assert!(first + len <= ca.len() as IdxSize);
                    match len {
                        0 => None,
                        1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                        _ => {
                            let arr_group = _slice_from_offsets(ca, first, len);
                            // Unwrap is safe: the quantile was validated above.
                            arr_group
                                ._quantile(quantile, method)
                                .unwrap_unchecked()
                                .map(|flt| NumCast::from(flt).unwrap_unchecked())
                        },
                    }
                })
            }
        },
    }
}

unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + SealedRolling,
{
    match groups {
        GroupsType::Idx(groups) => {
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                take._median()
            })
        },
        GroupsType::Slice { .. } => {
            agg_quantile_generic::<T, K>(ca, groups, 0.5, QuantileMethod::Linear)
        },
    }
}

/// # Safety
///
/// No bounds checks on `groups`.
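///
/// A sketch of how a reduction is plugged in (mirrors `agg_and` below):
///
/// ```ignore
/// let anded = unsafe { bitwise_agg(&ca, &groups, ChunkBitwiseReduce::and_reduce) };
/// ```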
#[cfg(feature = "bitwise")]
unsafe fn bitwise_agg<T: PolarsNumericType>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    f: fn(&ChunkedArray<T>) -> Option<T::Native>,
) -> Series
where
    ChunkedArray<T>: ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native>,
{
    // Prevent a rechunk for every individual group.
    let s = if groups.len() > 1 {
        ca.rechunk()
    } else {
        Cow::Borrowed(ca)
    };

    match groups {
        GroupsType::Idx(groups) => agg_helper_idx_on_all::<T, _>(groups, |idx| {
            debug_assert!(idx.len() <= s.len());
            if idx.is_empty() {
                None
            } else {
                let take = unsafe { s.take_unchecked(idx) };
                f(&take)
            }
        }),
        GroupsType::Slice { groups, .. } => _agg_helper_slice::<T, _>(groups, |[first, len]| {
            debug_assert!(len <= s.len() as IdxSize);
            if len == 0 {
                None
            } else {
                let take = _slice_from_offsets(&s, first, len);
                f(&take)
            }
        }),
    }
}

#[cfg(feature = "bitwise")]
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native>,
{
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) }
    }

    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) }
    }

    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) }
    }
}

impl<T> ChunkedArray<T>
where
    T: PolarsNumericType + Sync,
    T::Native: NativeType + PartialOrd + Num + NumCast + Zero + Bounded + std::iter::Sum<T::Native>,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
        // Faster paths for sorted input without nulls.
        match (self.is_sorted_flag(), self.null_count()) {
            (IsSorted::Ascending, 0) => {
                return self.clone().into_series().agg_first(groups);
            },
            (IsSorted::Descending, 0) => {
                return self.clone().into_series().agg_last(groups);
            },
            _ => {},
        }
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.min_ignore_nan(b))
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.min_ignore_nan(b))
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MinWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::min(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
        // Faster paths for sorted input without nulls.
        match (self.is_sorted_flag(), self.null_count()) {
            (IsSorted::Ascending, 0) => {
                return self.clone().into_series().agg_last(groups);
            },
            (IsSorted::Descending, 0) => {
                return self.clone().into_series().agg_first(groups);
            },
            _ => {},
        }

        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.max_ignore_nan(b))
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.max_ignore_nan(b))
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MaxWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::max(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx_no_null::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        T::Native::zero()
                    } else if idx.len() == 1 {
                        arr.get(first as usize).unwrap_or(T::Native::zero())
                    } else if no_nulls {
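                        // Floats are summed with Kahan compensation to limit
                        // accumulated rounding error; integers use a plain fold.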
                        if T::Native::is_float() {
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(KahanSum::default(), |k, x| k + x)
                                .sum()
                        } else {
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(T::Native::zero(), |a, b| a + b)
                        }
                    } else if T::Native::is_float() {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .fold(KahanSum::default(), |k, x| k + x)
                            .sum()
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .fold(T::Native::zero(), |a, b| a + b)
                    }
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            SumWindow<T::Native, T::Native>,
                            _,
                            _,
                        >(values, offset_iter, None),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            SumWindow<T::Native, T::Native>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    _agg_helper_slice_no_null::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => T::Native::zero(),
                            1 => self.get(first as usize).unwrap_or(T::Native::zero()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.sum().unwrap_or(T::Native::zero())
                            },
                        }
                    })
                }
            },
        }
    }
}

impl<T> SeriesWrap<ChunkedArray<T>>
where
    T: PolarsFloatType,
    ChunkedArray<T>: ChunkVar
        + VarAggSeries
        + ChunkQuantile<T::Native>
        + QuantileAggSeries
        + ChunkAgg<T::Native>,
    T::Native: Pow<T::Native, Output = T::Native>,
{
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    // This can fail due to a bug in lazy code: users can create
                    // filters in aggregations and thereby produce columns shorter
                    // than the original group tuples. The group tuples are
                    // modified, and if that is done incorrectly there can be
                    // out-of-bounds access.
                    debug_assert!(idx.len() <= self.len());
                    let out = if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else if no_nulls {
                        Some(
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(KahanSum::default(), |a, b| {
                                    a + b.to_f64().unwrap_unchecked()
                                })
                                .sum()
                                / idx.len() as f64,
                        )
                    } else {
                        take_agg_primitive_iter_unchecked_count_nulls(
                            arr,
                            idx2usize(idx),
                            KahanSum::default(),
                            |a, b| a + b.to_f64().unwrap_unchecked(),
                            idx.len() as IdxSize,
                        )
                        .map(|(sum, null_count)| sum.sum() / (idx.len() as f64 - null_count as f64))
                    };
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MeanWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<MeanWindow<_>, _, _>(
                            values,
                            validity,
                            offset_iter,
                            None,
                        ),
                    };
                    ChunkedArray::<T>::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean().map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }

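    /// `ddof` is assumed to follow the usual convention (matching `ChunkVar`):
    /// with `n` non-null values and mean `m`, the result is
    /// `sum((x - m)^2) / (n - ddof)`, so `ddof = 1` gives the sample variance
    /// and `ddof = 0` the population variance.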
    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let ca = ca.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                    };
                    ChunkedArray::<T>::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|flt| NumCast::from(flt.sqrt()).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    let arr = ca.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MomentWindow<_, rolling::nulls::VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                    };

                    // The rolling kernel computed the variance; take the
                    // element-wise square root to get the standard deviation.
                    let mut ca = ChunkedArray::<T>::from(arr);
                    ca.apply_mut(|v| v.powf(NumCast::from(0.5).unwrap()));
                    ca.into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
}

impl Float32Chunked {
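    /// A sketch of the dispatch (hypothetical `ca: Float32Chunked` and
    /// `groups`): quantiles of an `f32` column stay `f32`.
    ///
    /// ```ignore
    /// let s = unsafe { ca.agg_quantile(&groups, 0.9, QuantileMethod::Linear) };
    /// assert_eq!(s.dtype(), &DataType::Float32);
    /// ```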
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
    }

    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float32Type>(self, groups)
    }
}

impl Float64Chunked {
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }

    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}

impl<T> ChunkedArray<T>
where
    T: PolarsIntegerType,
    ChunkedArray<T>: ChunkAgg<T::Native> + ChunkVar,
    T::Native: NumericNative + Ord,
{
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca = self.rechunk();
                let arr = ca.downcast_get(0).unwrap();
                _agg_helper_idx::<Float64Type, _>(groups, |(first, idx)| {
                    // This can fail due to a bug in lazy code: users can create
                    // filters in aggregations and thereby produce columns shorter
                    // than the original group tuples. The group tuples are
                    // modified, and if that is done incorrectly there can be
                    // out-of-bounds access.
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        self.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else {
                        match (self.has_nulls(), self.chunks.len()) {
                            (false, 1) => Some(
                                take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                    .fold(KahanSum::default(), |a, b| a + b.to_f64().unwrap())
                                    .sum()
                                    / idx.len() as f64,
                            ),
                            (_, 1) => {
                                take_agg_primitive_iter_unchecked_count_nulls(
                                    arr,
                                    idx2usize(idx),
                                    KahanSum::default(),
                                    |a, b| a + b.to_f64().unwrap(),
                                    idx.len() as IdxSize,
                                )
                            }
                            .map(|(sum, null_count)| {
                                sum.sum() / (idx.len() as f64 - null_count as f64)
                            }),
                            _ => {
                                let take = { self.take_unchecked(idx) };
                                take.mean()
                            },
                        }
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_mean(groups)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean()
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        return None;
                    }
                    if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_var(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof)
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|v| v.sqrt())
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_std(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof)
                            },
                        }
                    })
                }
            },
        }
    }

    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }

    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}