polars_core/frame/group_by/aggregations/
mod.rs

1mod agg_list;
2mod boolean;
3mod dispatch;
4mod string;
5
6use std::borrow::Cow;
7
8pub use agg_list::*;
9use arrow::bitmap::{Bitmap, MutableBitmap};
10use arrow::legacy::kernels::take_agg::*;
11use arrow::legacy::trusted_len::TrustedLenPush;
12use arrow::types::NativeType;
13use num_traits::pow::Pow;
14use num_traits::{Bounded, Float, Num, NumCast, ToPrimitive, Zero};
15use polars_compute::rolling::no_nulls::{
16    MaxWindow, MeanWindow, MinWindow, MomentWindow, QuantileWindow, RollingAggWindowNoNulls,
17    SumWindow,
18};
19use polars_compute::rolling::nulls::{RollingAggWindowNulls, VarianceMoment};
20use polars_compute::rolling::quantile_filter::SealedRolling;
21use polars_compute::rolling::{
22    self, QuantileMethod, RollingFnParams, RollingQuantileParams, RollingVarParams, quantile_filter,
23};
24use polars_utils::float::IsFloat;
25use polars_utils::idx_vec::IdxVec;
26use polars_utils::min_max::MinMax;
27use rayon::prelude::*;
28
29use crate::chunked_array::cast::CastOptions;
30#[cfg(feature = "object")]
31use crate::chunked_array::object::extension::create_extension;
32use crate::frame::group_by::GroupsIdx;
33#[cfg(feature = "object")]
34use crate::frame::group_by::GroupsIndicator;
35use crate::prelude::*;
36use crate::series::IsSorted;
37use crate::series::implementations::SeriesWrap;
38use crate::utils::NoNull;
39use crate::{POOL, apply_method_physical_integer};
40
41fn idx2usize(idx: &[IdxSize]) -> impl ExactSizeIterator<Item = usize> + '_ {
42    idx.iter().map(|i| *i as usize)
43}
44
45// if the windows overlap, we can use the rolling_<agg> kernels
46// they maintain state, which saves a lot of compute by not naively traversing all elements every
47// window
48//
49// if the windows don't overlap, we should not use these kernels as they are single threaded, so
50// we miss out on easy parallelization.
51pub fn _use_rolling_kernels(groups: &GroupsSlice, chunks: &[ArrayRef]) -> bool {
52    match groups.len() {
53        0 | 1 => false,
54        _ => {
55            let [first_offset, first_len] = groups[0];
56            let second_offset = groups[1][0];
57
58            second_offset >= first_offset // Prevent false positive from regular group-by that has out of order slices.
59                                          // Rolling group-by is expected to have monotonically increasing slices.
60                && second_offset < (first_offset + first_len)
61                && chunks.len() == 1
62        },
63    }
64}
65
66// Use an aggregation window that maintains the state
/// Apply a stateful rolling aggregation window `Agg` over `values` that may
/// contain nulls (described by `validity`).
///
/// `offsets` yields one `(start, len)` pair per group; because consecutive
/// windows overlap, the window keeps running state instead of re-scanning
/// every element per group. One output value is produced per offset; empty
/// groups and groups whose aggregation yields `None` become null.
pub fn _rolling_apply_agg_window_nulls<'a, Agg, T, O>(
    values: &'a [T],
    validity: &'a Bitmap,
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    Agg: RollingAggWindowNulls<'a, T>,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        // Nothing to aggregate: return an empty array of the right dtype.
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }

    // This iterator's length can be trusted (TrustedLen bound);
    // it represents the number of groups in the group_by operation.
    let output_len = offsets.size_hint().0;
    // Start with a dummy (0, 0) window; it is overwritten on the first
    // `update` call below.
    // SAFETY:
    // we are in bounds
    let mut agg_window = unsafe { Agg::new(values, validity, 0, 0, params, None) };

    // Output validity: assume all-valid, then clear bits for null results.
    let mut validity = MutableBitmap::with_capacity(output_len);
    validity.extend_constant(output_len, true);

    let out = offsets
        .enumerate()
        .map(|(idx, (start, len))| {
            let end = start + len;

            // SAFETY:
            // we are in bounds

            // Empty window -> no value; otherwise advance the stateful window.
            let agg = if start == end {
                None
            } else {
                unsafe { agg_window.update(start as usize, end as usize) }
            };

            match agg {
                Some(val) => val,
                None => {
                    // Null result: record it in the validity mask and emit a
                    // placeholder value (masked out by the bitmap).
                    // SAFETY: we are in bounds
                    unsafe { validity.set_unchecked(idx, false) };
                    T::default()
                },
            }
        })
        .collect_trusted::<Vec<_>>();

    PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), Some(validity.into()))
}
121
122// Use an aggregation window that maintains the state.
/// Apply a stateful rolling aggregation window `Agg` over `values` that have
/// no validity mask.
///
/// `offsets` yields one `(start, len)` pair per group; the window maintains
/// running state across the overlapping windows. Empty groups produce null
/// output values.
pub fn _rolling_apply_agg_window_no_nulls<'a, Agg, T, O>(
    values: &'a [T],
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    // items (offset, len) -> so offsets are offset, offset + len
    Agg: RollingAggWindowNoNulls<'a, T>,
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        // Nothing to aggregate: return an empty array of the right dtype.
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }
    // Start with a dummy (0, 0) window; it is overwritten on the first
    // `update` call below.
    let mut agg_window = Agg::new(values, 0, 0, params, None);

    offsets
        .map(|(start, len)| {
            let end = start + len;

            // Empty window -> null; otherwise advance the stateful window.
            if start == end {
                None
            } else {
                // SAFETY: we are in bounds.
                unsafe { agg_window.update(start as usize, end as usize) }
            }
        })
        .collect::<PrimitiveArray<T>>()
}
154
155pub fn _slice_from_offsets<T>(ca: &ChunkedArray<T>, first: IdxSize, len: IdxSize) -> ChunkedArray<T>
156where
157    T: PolarsDataType,
158{
159    ca.slice(first as i64, len as usize)
160}
161
162/// Helper that combines the groups into a parallel iterator over `(first, all): (u32, &Vec<u32>)`.
163pub fn _agg_helper_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
164where
165    F: Fn((IdxSize, &IdxVec)) -> Option<T::Native> + Send + Sync,
166    T: PolarsNumericType,
167{
168    let ca: ChunkedArray<T> = POOL.install(|| groups.into_par_iter().map(f).collect());
169    ca.into_series()
170}
171
172/// Same helper as `_agg_helper_idx` but for aggregations that don't return an Option.
173pub fn _agg_helper_idx_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
174where
175    F: Fn((IdxSize, &IdxVec)) -> T::Native + Send + Sync,
176    T: PolarsNumericType,
177{
178    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.into_par_iter().map(f).collect());
179    ca.into_inner().into_series()
180}
181
182/// Helper that iterates on the `all: Vec<Vec<u32>` collection,
183/// this doesn't have traverse the `first: Vec<u32>` memory and is therefore faster.
184fn agg_helper_idx_on_all<T, F>(groups: &GroupsIdx, f: F) -> Series
185where
186    F: Fn(&IdxVec) -> Option<T::Native> + Send + Sync,
187    T: PolarsNumericType,
188{
189    let ca: ChunkedArray<T> = POOL.install(|| groups.all().into_par_iter().map(f).collect());
190    ca.into_series()
191}
192
193/// Same as `agg_helper_idx_on_all` but for aggregations that don't return an Option.
194fn agg_helper_idx_on_all_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
195where
196    F: Fn(&IdxVec) -> T::Native + Send + Sync,
197    T: PolarsNumericType,
198{
199    let ca: NoNull<ChunkedArray<T>> =
200        POOL.install(|| groups.all().into_par_iter().map(f).collect());
201    ca.into_inner().into_series()
202}
203
204pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
205where
206    F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
207    T: PolarsNumericType,
208{
209    let ca: ChunkedArray<T> = POOL.install(|| groups.par_iter().copied().map(f).collect());
210    ca.into_series()
211}
212
213pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
214where
215    F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
216    T: PolarsNumericType,
217{
218    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.par_iter().copied().map(f).collect());
219    ca.into_inner().into_series()
220}
221
/// Intermediate helper trait so we can have a single generic implementation
/// This trait will ensure the specific dispatch works without complicating
/// the trait bounds.
trait QuantileDispatcher<K> {
    /// Consume `self` and compute the given quantile with `method`.
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<K>>;

    /// Consume `self` and compute the median.
    fn _median(self) -> Option<K>;
}
230
// Integer chunked arrays dispatch to the `_faster` kernels, producing `f64`.
impl<T> QuantileDispatcher<f64> for ChunkedArray<T>
where
    T: PolarsIntegerType,
    T::Native: Ord,
{
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}
243
// `f32` input keeps its precision: the quantile/median output stays `f32`.
impl QuantileDispatcher<f32> for Float32Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f32>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f32> {
        self.median_faster()
    }
}
// `f64` input: quantile/median output stays `f64`.
impl QuantileDispatcher<f64> for Float64Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}
260
/// Compute a quantile per group, producing output of float dtype `K`.
///
/// Dispatches between the indexed-groups path (gather + `_quantile`) and the
/// sliced-groups path (rolling kernels when windows overlap, otherwise a
/// parallel per-slice computation).
///
/// # Safety
/// No bounds checks on the group indices/offsets; they must be in bounds for
/// `ca`.
unsafe fn agg_quantile_generic<T, K>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    quantile: f64,
    method: QuantileMethod,
) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
{
    // A quantile outside [0, 1] is undefined: emit one null per group. This
    // also guarantees the `unwrap_unchecked` calls below cannot observe Err.
    let invalid_quantile = !(0.0..=1.0).contains(&quantile);
    if invalid_quantile {
        return Series::full_null(ca.name().clone(), groups.len(), ca.dtype());
    }
    match groups {
        GroupsType::Idx(groups) => {
            // Rechunk once so each per-group gather works on a single array.
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                // checked with invalid quantile check
                take._quantile(quantile, method).unwrap_unchecked()
            })
        },
        GroupsType::Slice { groups, .. } => {
            if _use_rolling_kernels(groups, ca.chunks()) {
                // this cast is a no-op for floats
                let s = ca
                    .cast_with_options(&K::get_static_dtype(), CastOptions::Overflowing)
                    .unwrap();
                let ca: &ChunkedArray<K> = s.as_ref().as_ref();
                let arr = ca.downcast_iter().next().unwrap();
                let values = arr.values().as_slice();
                let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                // Choose the nulls/no-nulls kernel based on the validity mask.
                let arr = match arr.validity() {
                    None => _rolling_apply_agg_window_no_nulls::<QuantileWindow<_>, _, _>(
                        values,
                        offset_iter,
                        Some(RollingFnParams::Quantile(RollingQuantileParams {
                            prob: quantile,
                            method,
                        })),
                    ),
                    Some(validity) => {
                        _rolling_apply_agg_window_nulls::<rolling::nulls::QuantileWindow<_>, _, _>(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Quantile(RollingQuantileParams {
                                prob: quantile,
                                method,
                            })),
                        )
                    },
                };
                // The rolling kernels works on the dtype, this is not yet the
                // float output type we need.
                ChunkedArray::<K>::with_chunk(PlSmallStr::EMPTY, arr).into_series()
            } else {
                // Non-overlapping slices: compute each group independently in
                // parallel.
                _agg_helper_slice::<K, _>(groups, |[first, len]| {
                    debug_assert!(first + len <= ca.len() as IdxSize);
                    match len {
                        0 => None,
                        1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                        _ => {
                            let arr_group = _slice_from_offsets(ca, first, len);
                            // unwrap checked with invalid quantile check
                            arr_group
                                ._quantile(quantile, method)
                                .unwrap_unchecked()
                                .map(|flt| NumCast::from(flt).unwrap_unchecked())
                        },
                    }
                })
            }
        },
    }
}
344
/// Compute the median per group, producing output of float dtype `K`.
///
/// # Safety
/// No bounds checks on the group indices/offsets; they must be in bounds for
/// `ca`.
unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + SealedRolling,
{
    match groups {
        GroupsType::Idx(groups) => {
            // Rechunk once so each per-group gather works on a single array.
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                take._median()
            })
        },
        GroupsType::Slice { .. } => {
            // The median is the 0.5 quantile; reuse the sliced quantile path.
            agg_quantile_generic::<T, K>(ca, groups, 0.5, QuantileMethod::Linear)
        },
    }
}
369
/// Shared driver for the bitwise group-by reductions (`and`/`or`/`xor`);
/// `f` is the per-group reduction applied to the gathered/sliced values.
///
/// # Safety
///
/// No bounds checks on `groups`.
#[cfg(feature = "bitwise")]
unsafe fn bitwise_agg<T: PolarsNumericType>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    f: fn(&ChunkedArray<T>) -> Option<T::Native>,
) -> Series
where
    ChunkedArray<T>: ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native>,
{
    // Prevent a rechunk for every individual group.

    let s = if groups.len() > 1 {
        ca.rechunk()
    } else {
        Cow::Borrowed(ca)
    };

    match groups {
        GroupsType::Idx(groups) => agg_helper_idx_on_all::<T, _>(groups, |idx| {
            debug_assert!(idx.len() <= s.len());
            if idx.is_empty() {
                None
            } else {
                // SAFETY: caller guarantees the indices are in bounds.
                let take = unsafe { s.take_unchecked(idx) };
                f(&take)
            }
        }),
        GroupsType::Slice { groups, .. } => _agg_helper_slice::<T, _>(groups, |[first, len]| {
            debug_assert!(len <= s.len() as IdxSize);
            if len == 0 {
                None
            } else {
                let take = _slice_from_offsets(&s, first, len);
                f(&take)
            }
        }),
    }
}
411
#[cfg(feature = "bitwise")]
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native>,
{
    /// Bitwise AND reduction per group.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) }
    }

    /// Bitwise OR reduction per group.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) }
    }

    /// Bitwise XOR reduction per group.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) }
    }
}
439
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType + Sync,
    T::Native: NativeType + PartialOrd + Num + NumCast + Zero + Bounded + std::iter::Sum<T::Native>,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    /// Minimum value per group.
    ///
    /// # Safety
    /// No bounds checks on `groups`; indices/offsets must be in bounds for
    /// `self`.
    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
        // faster paths
        match (self.is_sorted_flag(), self.null_count()) {
            // Ascending, no nulls: each group's first element is its minimum.
            (IsSorted::Ascending, 0) => {
                return self.clone().into_series().agg_first(groups);
            },
            // Descending, no nulls: each group's last element is its minimum.
            (IsSorted::Descending, 0) => {
                return self.clone().into_series().agg_last(groups);
            },
            _ => {},
        }
        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the take kernels operate on a single array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        // Single-element group: its only value is the minimum.
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a.min_ignore_nan(b),
                        )
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| {
                            a.min_ignore_nan(b)
                        })
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    // Overlapping windows: use the stateful rolling kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MinWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    // Non-overlapping slices: compute each group in parallel.
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::min(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

    /// Maximum value per group; mirrors `agg_min` with the sorted fast paths
    /// swapped.
    ///
    /// # Safety
    /// No bounds checks on `groups`; indices/offsets must be in bounds for
    /// `self`.
    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
        // faster paths
        match (self.is_sorted_flag(), self.null_count()) {
            // Ascending, no nulls: each group's last element is its maximum.
            (IsSorted::Ascending, 0) => {
                return self.clone().into_series().agg_last(groups);
            },
            // Descending, no nulls: each group's first element is its maximum.
            (IsSorted::Descending, 0) => {
                return self.clone().into_series().agg_first(groups);
            },
            _ => {},
        }

        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the take kernels operate on a single array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        // Single-element group: its only value is the maximum.
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a.max_ignore_nan(b),
                        )
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| {
                            a.max_ignore_nan(b)
                        })
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    // Overlapping windows: use the stateful rolling kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MaxWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    // Non-overlapping slices: compute each group in parallel.
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::max(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

    /// Sum per group. Empty groups and all-null groups yield zero, not null
    /// (no-null helper variants are used throughout).
    ///
    /// # Safety
    /// No bounds checks on `groups`; indices/offsets must be in bounds for
    /// `self`.
    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the take kernels operate on a single array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx_no_null::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        // Empty group sums to zero.
                        T::Native::zero()
                    } else if idx.len() == 1 {
                        arr.get(first as usize).unwrap_or(T::Native::zero())
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b)
                            .unwrap_or(T::Native::zero())
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b)
                            .unwrap_or(T::Native::zero())
                    }
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    // Overlapping windows: use the stateful rolling kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            SumWindow<T::Native, T::Native>,
                            _,
                            _,
                        >(values, offset_iter, None),
                        Some(validity) => {
                            _rolling_apply_agg_window_nulls::<
                                rolling::nulls::SumWindow<T::Native, T::Native>,
                                _,
                                _,
                            >(values, validity, offset_iter, None)
                        },
                    };
                    Self::from(arr).into_series()
                } else {
                    // Non-overlapping slices: compute each group in parallel.
                    _agg_helper_slice_no_null::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => T::Native::zero(),
                            1 => self.get(first as usize).unwrap_or(T::Native::zero()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.sum().unwrap_or(T::Native::zero())
                            },
                        }
                    })
                }
            },
        }
    }
}
655
656impl<T> SeriesWrap<ChunkedArray<T>>
657where
658    T: PolarsFloatType,
659    ChunkedArray<T>: ChunkVar
660        + VarAggSeries
661        + ChunkQuantile<T::Native>
662        + QuantileAggSeries
663        + ChunkAgg<T::Native>,
664    T::Native: Pow<T::Native, Output = T::Native>,
665{
    /// Mean per group, accumulated in `f64` and cast back to `T::Native`.
    ///
    /// # Safety
    /// No bounds checks on `groups`; indices/offsets must be in bounds for
    /// `self`.
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the take kernels operate on a single array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    // this can fail due to a bug in lazy code.
                    // here users can create filters in aggregations
                    // and thereby creating shorter columns than the original group tuples.
                    // the group tuples are modified, but if that's done incorrect there can be out of bounds
                    // access
                    debug_assert!(idx.len() <= self.len());
                    let out = if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        // Single-element group: the mean is the element itself.
                        arr.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else if no_nulls {
                        // Sum in the native type, then divide by the group size.
                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a + b,
                        )
                        .unwrap()
                        .to_f64()
                        .map(|sum| sum / idx.len() as f64)
                    } else {
                        // Sum while counting nulls, then divide by the number
                        // of non-null elements.
                        take_agg_primitive_iter_unchecked_count_nulls::<T::Native, _, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a + b,
                            T::Native::zero(),
                            idx.len() as IdxSize,
                        )
                        .map(|(sum, null_count)| {
                            sum.to_f64()
                                .map(|sum| sum / (idx.len() as f64 - null_count as f64))
                                .unwrap()
                        })
                    };
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    // Overlapping windows: use the stateful rolling kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MeanWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MeanWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    ChunkedArray::<T>::from(arr).into_series()
                } else {
                    // Non-overlapping slices: compute each group in parallel.
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean().map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
745
    /// Aggregate the variance (with `ddof` delta degrees of freedom) of each group.
    ///
    /// # Safety
    /// The group indices/offsets in `groups` must be in bounds of this array;
    /// out-of-bounds groups lead to UB in the unchecked take kernels below.
    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        // Work on one contiguous chunk so the unchecked kernels can index a
        // single array directly.
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                // NOTE(review): `ca` was already rechunked above; this second
                // rechunk is redundant (a no-op copy of the single chunk).
                let ca = ca.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    // Fast path avoids per-element validity checks when the
                    // array has no nulls.
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    // Kernels compute in a wider float; cast back to T::Native.
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    // Overlapping windows: use the stateful rolling variance
                    // kernels so each window is not recomputed from scratch.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                    };
                    ChunkedArray::<T>::from(arr).into_series()
                } else {
                    // Non-overlapping windows: slice each group out and
                    // aggregate; this path parallelizes across groups.
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                // Single observation: population variance
                                // (ddof == 0) is 0; with ddof >= 1 it is
                                // undefined, so return null.
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
    /// Aggregate the standard deviation (with `ddof` delta degrees of freedom)
    /// of each group. Computed as the square root of the group variance.
    ///
    /// # Safety
    /// The group indices/offsets in `groups` must be in bounds of this array;
    /// out-of-bounds groups lead to UB in the unchecked take kernels below.
    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        // Work on one contiguous chunk so the unchecked kernels can index a
        // single array directly.
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    // Same variance kernels as `agg_var`; std is derived by
                    // taking the square root of the result.
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|flt| NumCast::from(flt.sqrt()).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    // Overlapping windows: compute the rolling variance first,
                    // then convert to std in place below.
                    let arr = ca.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        // NOTE(review): `rolling::nulls::VarianceMoment` is the
                        // same type as the imported `VarianceMoment`; `agg_var`
                        // spells it unqualified — cosmetic inconsistency only.
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MomentWindow<_, rolling::nulls::VarianceMoment>,
                            _,
                            _,
                        >(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                    };

                    // std = var^0.5, applied element-wise in place.
                    let mut ca = ChunkedArray::<T>::from(arr);
                    ca.apply_mut(|v| v.powf(NumCast::from(0.5).unwrap()));
                    ca.into_series()
                } else {
                    // Non-overlapping windows: slice each group out and
                    // aggregate; this path parallelizes across groups.
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                // Single observation: population std
                                // (ddof == 0) is 0; with ddof >= 1 it is
                                // undefined, so return null.
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
892}
893
// Quantile/median group-by aggregations for f32; both delegate to the generic
// implementations with `Float32Type` as the output type.
impl Float32Chunked {
    /// Aggregate the `quantile` of each group, using `method` to interpolate
    /// between observations.
    ///
    /// # Safety
    /// Group indices/offsets must be in bounds of this array.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
    }
    /// Aggregate the median of each group.
    ///
    /// # Safety
    /// Group indices/offsets must be in bounds of this array.
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float32Type>(self, groups)
    }
}
// Quantile/median group-by aggregations for f64; both delegate to the generic
// implementations with `Float64Type` as the output type.
impl Float64Chunked {
    /// Aggregate the `quantile` of each group, using `method` to interpolate
    /// between observations.
    ///
    /// # Safety
    /// Group indices/offsets must be in bounds of this array.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    /// Aggregate the median of each group.
    ///
    /// # Safety
    /// Group indices/offsets must be in bounds of this array.
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}
920
// Group-by aggregations for integer ChunkedArrays. mean/var/std produce
// Float64 results (the helpers below are instantiated with `Float64Type`).
impl<T> ChunkedArray<T>
where
    T: PolarsIntegerType,
    ChunkedArray<T>: ChunkAgg<T::Native> + ChunkVar,
    T::Native: NumericNative + Ord,
{
    /// Aggregate the mean of each group as f64.
    ///
    /// # Safety
    /// Group indices/offsets must be in bounds of this array.
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so `arr` is the single backing array the unchecked
                // take kernels can index directly.
                let ca = self.rechunk();
                let arr = ca.downcast_get(0).unwrap();
                _agg_helper_idx::<Float64Type, _>(groups, |(first, idx)| {
                    // this can fail due to a bug in lazy code.
                    // here users can create filters in aggregations
                    // and thereby creating shorter columns than the original group tuples.
                    // the group tuples are modified, but if that's done incorrect there can be out of bounds
                    // access
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        // Single element: the mean is the element itself.
                        self.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else {
                        // Dispatch on (nullability, chunk count of `self`):
                        // single-chunk paths use the fast unchecked sum kernels.
                        match (self.has_nulls(), self.chunks.len()) {
                            (false, 1) => {
                                take_agg_no_null_primitive_iter_unchecked::<_, f64, _, _>(
                                    arr,
                                    idx2usize(idx),
                                    |a, b| a + b,
                                )
                                .map(|sum| sum / idx.len() as f64)
                            },
                            (_, 1) => {
                                {
                                    take_agg_primitive_iter_unchecked_count_nulls::<
                                        T::Native,
                                        f64,
                                        _,
                                        _,
                                    >(
                                        arr, idx2usize(idx), |a, b| a + b, 0.0, idx.len() as IdxSize
                                    )
                                }
                                // Nulls don't contribute to the sum, so divide
                                // by the count of valid values only.
                                .map(|(sum, null_count)| {
                                    sum / (idx.len() as f64 - null_count as f64)
                                })
                            },
                            _ => {
                                // Multi-chunk fallback: gather the group, then
                                // use the generic mean.
                                let take = { self.take_unchecked(idx) };
                                take.mean()
                            },
                        }
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    // Overlapping windows: cast to f64 and reuse the float
                    // implementation, which dispatches to rolling kernels.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_mean(groups)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean()
                            },
                        }
                    })
                }
            },
        }
    }

    /// Aggregate the variance (with `ddof` delta degrees of freedom) of each
    /// group as f64.
    ///
    /// # Safety
    /// Group indices/offsets must be in bounds of this array.
    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the unchecked kernels can index one array.
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        return None;
                    }
                    // Kernels already yield f64, so no cast is needed here
                    // (unlike the float implementations above).
                    if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    // Overlapping windows: cast to f64 and reuse the float
                    // implementation's rolling-kernel path.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_var(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                // Single observation: population variance
                                // (ddof == 0) is 0; otherwise undefined (null).
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof)
                            },
                        }
                    })
                }
            },
        }
    }
    /// Aggregate the standard deviation (with `ddof` delta degrees of freedom)
    /// of each group as f64. Computed as sqrt of the group variance.
    ///
    /// # Safety
    /// Group indices/offsets must be in bounds of this array.
    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the unchecked kernels can index one array.
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        return None;
                    }
                    // Same variance kernels as `agg_var`; std is the sqrt.
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|v| v.sqrt())
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    // Overlapping windows: cast to f64 and reuse the float
                    // implementation's rolling-kernel path.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_std(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                // Single observation: population std
                                // (ddof == 0) is 0; otherwise undefined (null).
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof)
                            },
                        }
                    })
                }
            },
        }
    }

    /// Aggregate the `quantile` of each group as f64, using `method` to
    /// interpolate between observations.
    ///
    /// # Safety
    /// Group indices/offsets must be in bounds of this array.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    /// Aggregate the median of each group as f64.
    ///
    /// # Safety
    /// Group indices/offsets must be in bounds of this array.
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}