polars_core/frame/group_by/aggregations/
mod.rs

1mod agg_list;
2mod boolean;
3mod dispatch;
4mod string;
5
6use std::borrow::Cow;
7
8pub use agg_list::*;
9use arrow::bitmap::{Bitmap, MutableBitmap};
10use arrow::legacy::kernels::take_agg::*;
11use arrow::legacy::trusted_len::TrustedLenPush;
12use arrow::types::NativeType;
13use num_traits::pow::Pow;
14use num_traits::{Bounded, Float, Num, NumCast, ToPrimitive, Zero};
15use polars_compute::rolling::no_nulls::{
16    MaxWindow, MeanWindow, MinWindow, MomentWindow, QuantileWindow, RollingAggWindowNoNulls,
17    SumWindow,
18};
19use polars_compute::rolling::nulls::{RollingAggWindowNulls, VarianceMoment};
20use polars_compute::rolling::quantile_filter::SealedRolling;
21use polars_compute::rolling::{
22    self, QuantileMethod, RollingFnParams, RollingQuantileParams, RollingVarParams, quantile_filter,
23};
24use polars_utils::float::IsFloat;
25use polars_utils::idx_vec::IdxVec;
26use polars_utils::min_max::MinMax;
27use rayon::prelude::*;
28
29use crate::chunked_array::cast::CastOptions;
30#[cfg(feature = "object")]
31use crate::chunked_array::object::extension::create_extension;
32use crate::frame::group_by::GroupsIdx;
33#[cfg(feature = "object")]
34use crate::frame::group_by::GroupsIndicator;
35use crate::prelude::*;
36use crate::series::IsSorted;
37use crate::series::implementations::SeriesWrap;
38use crate::utils::NoNull;
39use crate::{POOL, apply_method_physical_integer};
40
41fn idx2usize(idx: &[IdxSize]) -> impl ExactSizeIterator<Item = usize> + '_ {
42    idx.iter().map(|i| *i as usize)
43}
44
45// if the windows overlap, we can use the rolling_<agg> kernels
46// they maintain state, which saves a lot of compute by not naively traversing all elements every
47// window
48//
49// if the windows don't overlap, we should not use these kernels as they are single threaded, so
50// we miss out on easy parallelization.
51pub fn _use_rolling_kernels(groups: &GroupsSlice, chunks: &[ArrayRef]) -> bool {
52    match groups.len() {
53        0 | 1 => false,
54        _ => {
55            let [first_offset, first_len] = groups[0];
56            let second_offset = groups[1][0];
57
58            second_offset >= first_offset // Prevent false positive from regular group-by that has out of order slices.
59                                          // Rolling group-by is expected to have monotonically increasing slices.
60                && second_offset < (first_offset + first_len)
61                && chunks.len() == 1
62        },
63    }
64}
65
// Use an aggregation window that maintains the state
/// Apply a stateful rolling aggregation window `Agg` over `values` that carry
/// a `validity` bitmap.
///
/// `offsets` yields one `(start, len)` pair per group; because it is
/// `TrustedLen`, its `size_hint` lower bound is the exact number of groups.
/// An empty window (`start == end`) or a window whose aggregation yields
/// `None` produces a null in the output array.
pub fn _rolling_apply_agg_window_nulls<'a, Agg, T, O>(
    values: &'a [T],
    validity: &'a Bitmap,
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    Agg: RollingAggWindowNulls<'a, T>,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        // No input values: return an empty array of the right primitive type.
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }

    // This iterators length can be trusted
    // these represent the number of groups in the group_by operation
    let output_len = offsets.size_hint().0;
    // start with a dummy index, will be overwritten on first iteration.
    // SAFETY:
    // we are in bounds
    let mut agg_window = unsafe { Agg::new(values, validity, 0, 0, params, None) };

    // Output validity: start all-valid, unset bits for null aggregations below.
    let mut validity = MutableBitmap::with_capacity(output_len);
    validity.extend_constant(output_len, true);

    let out = offsets
        .enumerate()
        .map(|(idx, (start, len))| {
            let end = start + len;

            // SAFETY:
            // we are in bounds

            let agg = if start == end {
                // Empty window -> null result for this group.
                None
            } else {
                unsafe { agg_window.update(start as usize, end as usize) }
            };

            match agg {
                Some(val) => val,
                None => {
                    // SAFETY: we are in bounds
                    unsafe { validity.set_unchecked(idx, false) };
                    // Placeholder value; masked out by the validity bitmap.
                    T::default()
                },
            }
        })
        .collect_trusted::<Vec<_>>();

    PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), Some(validity.into()))
}
121
// Use an aggregation window that maintains the state.
/// Apply a stateful rolling aggregation window `Agg` over `values` without a
/// validity bitmap.
///
/// `offsets` yields one `(start, len)` pair per group. An empty window
/// (`start == end`) or an aggregation returning `None` produces a null in the
/// collected output array.
pub fn _rolling_apply_agg_window_no_nulls<'a, Agg, T, O>(
    values: &'a [T],
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<T>
where
    // items (offset, len) -> so offsets are offset, offset + len
    Agg: RollingAggWindowNoNulls<'a, T>,
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    T: IsFloat + NativeType,
{
    if values.is_empty() {
        // No input values: return an empty array of the right primitive type.
        let out: Vec<T> = vec![];
        return PrimitiveArray::new(T::PRIMITIVE.into(), out.into(), None);
    }
    // start with a dummy index, will be overwritten on first iteration.
    let mut agg_window = Agg::new(values, 0, 0, params, None);

    offsets
        .map(|(start, len)| {
            let end = start + len;

            if start == end {
                // Empty window -> null result for this group.
                None
            } else {
                // SAFETY: we are in bounds.
                unsafe { agg_window.update(start as usize, end as usize) }
            }
        })
        .collect::<PrimitiveArray<T>>()
}
154
155pub fn _slice_from_offsets<T>(ca: &ChunkedArray<T>, first: IdxSize, len: IdxSize) -> ChunkedArray<T>
156where
157    T: PolarsDataType,
158{
159    ca.slice(first as i64, len as usize)
160}
161
162/// Helper that combines the groups into a parallel iterator over `(first, all): (u32, &Vec<u32>)`.
163pub fn _agg_helper_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
164where
165    F: Fn((IdxSize, &IdxVec)) -> Option<T::Native> + Send + Sync,
166    T: PolarsNumericType,
167    ChunkedArray<T>: IntoSeries,
168{
169    let ca: ChunkedArray<T> = POOL.install(|| groups.into_par_iter().map(f).collect());
170    ca.into_series()
171}
172
173/// Same helper as `_agg_helper_idx` but for aggregations that don't return an Option.
174pub fn _agg_helper_idx_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
175where
176    F: Fn((IdxSize, &IdxVec)) -> T::Native + Send + Sync,
177    T: PolarsNumericType,
178    ChunkedArray<T>: IntoSeries,
179{
180    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.into_par_iter().map(f).collect());
181    ca.into_inner().into_series()
182}
183
184/// Helper that iterates on the `all: Vec<Vec<u32>` collection,
185/// this doesn't have traverse the `first: Vec<u32>` memory and is therefore faster.
186fn agg_helper_idx_on_all<T, F>(groups: &GroupsIdx, f: F) -> Series
187where
188    F: Fn(&IdxVec) -> Option<T::Native> + Send + Sync,
189    T: PolarsNumericType,
190    ChunkedArray<T>: IntoSeries,
191{
192    let ca: ChunkedArray<T> = POOL.install(|| groups.all().into_par_iter().map(f).collect());
193    ca.into_series()
194}
195
196/// Same as `agg_helper_idx_on_all` but for aggregations that don't return an Option.
197fn agg_helper_idx_on_all_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
198where
199    F: Fn(&IdxVec) -> T::Native + Send + Sync,
200    T: PolarsNumericType,
201    ChunkedArray<T>: IntoSeries,
202{
203    let ca: NoNull<ChunkedArray<T>> =
204        POOL.install(|| groups.all().into_par_iter().map(f).collect());
205    ca.into_inner().into_series()
206}
207
208pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
209where
210    F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
211    T: PolarsNumericType,
212    ChunkedArray<T>: IntoSeries,
213{
214    let ca: ChunkedArray<T> = POOL.install(|| groups.par_iter().copied().map(f).collect());
215    ca.into_series()
216}
217
218pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
219where
220    F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
221    T: PolarsNumericType,
222    ChunkedArray<T>: IntoSeries,
223{
224    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.par_iter().copied().map(f).collect());
225    ca.into_inner().into_series()
226}
227
/// Intermediate helper trait so we can have a single generic implementation
/// This trait will ensure the specific dispatch works without complicating
/// the trait bounds.
trait QuantileDispatcher<K> {
    /// Quantile of the array using interpolation `method`.
    /// `quantile` is expected to be within `[0.0, 1.0]` (checked by the caller).
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<K>>;

    /// Median of the array.
    fn _median(self) -> Option<K>;
}
236
/// Integer chunked arrays dispatch to the `f64`-returning "faster" kernels.
impl<T> QuantileDispatcher<f64> for ChunkedArray<T>
where
    T: PolarsIntegerType,
    T::Native: Ord,
    ChunkedArray<T>: IntoSeries,
{
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}
250
/// `f32` arrays keep `f32` quantile/median output.
impl QuantileDispatcher<f32> for Float32Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f32>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f32> {
        self.median_faster()
    }
}
/// `f64` arrays keep `f64` quantile/median output.
impl QuantileDispatcher<f64> for Float64Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}
267
/// Per-group quantile implementation shared by all numeric input types.
///
/// `K` is the float output type of the aggregation. A quantile outside
/// `[0.0, 1.0]` yields an all-null series.
///
/// # Safety
/// No bounds checks on `groups`; all indices/offsets must be in bounds for `ca`.
unsafe fn agg_quantile_generic<T, K>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    quantile: f64,
    method: QuantileMethod,
) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    ChunkedArray<K>: IntoSeries,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
{
    let invalid_quantile = !(0.0..=1.0).contains(&quantile);
    if invalid_quantile {
        // NOTE(review): this null fallback uses the input dtype while the
        // aggregated output below has `K`'s dtype — confirm this is intended.
        return Series::full_null(ca.name().clone(), groups.len(), ca.dtype());
    }
    match groups {
        GroupsType::Idx(groups) => {
            // Single chunk so the gather below sees one contiguous array.
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                // checked with invalid quantile check
                take._quantile(quantile, method).unwrap_unchecked()
            })
        },
        GroupsType::Slice { groups, .. } => {
            if _use_rolling_kernels(groups, ca.chunks()) {
                // this cast is a no-op for floats
                let s = ca
                    .cast_with_options(&K::get_dtype(), CastOptions::Overflowing)
                    .unwrap();
                let ca: &ChunkedArray<K> = s.as_ref().as_ref();
                // Rolling kernel path requires a single chunk (ensured by
                // `_use_rolling_kernels`).
                let arr = ca.downcast_iter().next().unwrap();
                let values = arr.values().as_slice();
                let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                // Pick the null-aware kernel only when a validity bitmap exists.
                let arr = match arr.validity() {
                    None => _rolling_apply_agg_window_no_nulls::<QuantileWindow<_>, _, _>(
                        values,
                        offset_iter,
                        Some(RollingFnParams::Quantile(RollingQuantileParams {
                            prob: quantile,
                            method,
                        })),
                    ),
                    Some(validity) => {
                        _rolling_apply_agg_window_nulls::<rolling::nulls::QuantileWindow<_>, _, _>(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Quantile(RollingQuantileParams {
                                prob: quantile,
                                method,
                            })),
                        )
                    },
                };
                // The rolling kernels works on the dtype, this is not yet the
                // float output type we need.
                ChunkedArray::from(arr).into_series()
            } else {
                // Non-overlapping slices: aggregate each slice independently in
                // parallel.
                _agg_helper_slice::<K, _>(groups, |[first, len]| {
                    debug_assert!(first + len <= ca.len() as IdxSize);
                    match len {
                        0 => None,
                        1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                        _ => {
                            let arr_group = _slice_from_offsets(ca, first, len);
                            // unwrap checked with invalid quantile check
                            arr_group
                                ._quantile(quantile, method)
                                .unwrap_unchecked()
                                .map(|flt| NumCast::from(flt).unwrap_unchecked())
                        },
                    }
                })
            }
        },
    }
}
352
/// Per-group median implementation shared by all numeric input types.
///
/// # Safety
/// No bounds checks on `groups`; all indices must be in bounds for `ca`.
unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    ChunkedArray<K>: IntoSeries,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + SealedRolling,
{
    match groups {
        GroupsType::Idx(groups) => {
            // Single chunk so the gather below sees one contiguous array.
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                take._median()
            })
        },
        GroupsType::Slice { .. } => {
            // Slice groups reuse the quantile path: median == quantile 0.5
            // with linear interpolation.
            agg_quantile_generic::<T, K>(ca, groups, 0.5, QuantileMethod::Linear)
        },
    }
}
378
/// Shared per-group driver for the bitwise reductions (`and`/`or`/`xor`).
/// `f` performs the actual reduction over one group's values.
///
/// # Safety
///
/// No bounds checks on `groups`.
#[cfg(feature = "bitwise")]
unsafe fn bitwise_agg<T: PolarsNumericType>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    f: fn(&ChunkedArray<T>) -> Option<T::Native>,
) -> Series
where
    ChunkedArray<T>:
        ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native> + IntoSeries,
{
    // Prevent a rechunk for every individual group.

    let s = if groups.len() > 1 {
        ca.rechunk()
    } else {
        Cow::Borrowed(ca)
    };

    match groups {
        GroupsType::Idx(groups) => agg_helper_idx_on_all::<T, _>(groups, |idx| {
            debug_assert!(idx.len() <= s.len());
            if idx.is_empty() {
                // Empty group -> null result.
                None
            } else {
                // SAFETY: caller guarantees indices are in bounds.
                let take = unsafe { s.take_unchecked(idx) };
                f(&take)
            }
        }),
        GroupsType::Slice { groups, .. } => _agg_helper_slice::<T, _>(groups, |[first, len]| {
            debug_assert!(len <= s.len() as IdxSize);
            if len == 0 {
                // Empty group -> null result.
                None
            } else {
                let take = _slice_from_offsets(&s, first, len);
                f(&take)
            }
        }),
    }
}
421
#[cfg(feature = "bitwise")]
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    ChunkedArray<T>:
        ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native> + IntoSeries,
{
    /// Per-group bitwise AND reduction.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) }
    }

    /// Per-group bitwise OR reduction.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) }
    }

    /// Per-group bitwise XOR reduction.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) }
    }
}
450
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType + Sync,
    T::Native: NativeType + PartialOrd + Num + NumCast + Zero + Bounded + std::iter::Sum<T::Native>,
    ChunkedArray<T>: IntoSeries + ChunkAgg<T::Native>,
{
    /// Per-group minimum.
    ///
    /// # Safety
    /// No bounds checks on `groups`; all indices/offsets must be in bounds.
    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
        // faster paths
        // For sorted data without nulls the min of each group is its first
        // (ascending) or last (descending) element.
        match (self.is_sorted_flag(), self.null_count()) {
            (IsSorted::Ascending, 0) => {
                return self.clone().into_series().agg_first(groups);
            },
            (IsSorted::Descending, 0) => {
                return self.clone().into_series().agg_last(groups);
            },
            _ => {},
        }
        match groups {
            GroupsType::Idx(groups) => {
                // Single chunk so the gather kernels see one contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        // Empty group -> null.
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a.min_ignore_nan(b),
                        )
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| {
                            a.min_ignore_nan(b)
                        })
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    // Overlapping windows: use the stateful rolling kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MinWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    // Non-overlapping slices: aggregate each slice in parallel.
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::min(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

    /// Per-group maximum.
    ///
    /// # Safety
    /// No bounds checks on `groups`; all indices/offsets must be in bounds.
    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
        // faster paths
        // Mirror of `agg_min`: max of a sorted, null-free column is its last
        // (ascending) or first (descending) element.
        match (self.is_sorted_flag(), self.null_count()) {
            (IsSorted::Ascending, 0) => {
                return self.clone().into_series().agg_last(groups);
            },
            (IsSorted::Descending, 0) => {
                return self.clone().into_series().agg_first(groups);
            },
            _ => {},
        }

        match groups {
            GroupsType::Idx(groups) => {
                // Single chunk so the gather kernels see one contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        // Empty group -> null.
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a.max_ignore_nan(b),
                        )
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| {
                            a.max_ignore_nan(b)
                        })
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    // Overlapping windows: use the stateful rolling kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MaxWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    // Non-overlapping slices: aggregate each slice in parallel.
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::max(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }

    /// Per-group sum. Empty groups sum to zero (not null).
    ///
    /// # Safety
    /// No bounds checks on `groups`; all indices/offsets must be in bounds.
    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Single chunk so the gather kernels see one contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx_no_null::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        T::Native::zero()
                    } else if idx.len() == 1 {
                        // A single null also sums to zero.
                        arr.get(first as usize).unwrap_or(T::Native::zero())
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b)
                            .unwrap_or(T::Native::zero())
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx), |a, b| a + b)
                            .unwrap_or(T::Native::zero())
                    }
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    // Overlapping windows: use the stateful rolling kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<SumWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::SumWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    Self::from(arr).into_series()
                } else {
                    // Non-overlapping slices: aggregate each slice in parallel.
                    _agg_helper_slice_no_null::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => T::Native::zero(),
                            1 => self.get(first as usize).unwrap_or(T::Native::zero()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.sum().unwrap_or(T::Native::zero())
                            },
                        }
                    })
                }
            },
        }
    }
}
666
667impl<T> SeriesWrap<ChunkedArray<T>>
668where
669    T: PolarsFloatType,
670    ChunkedArray<T>: IntoSeries
671        + ChunkVar
672        + VarAggSeries
673        + ChunkQuantile<T::Native>
674        + QuantileAggSeries
675        + ChunkAgg<T::Native>,
676    T::Native: Pow<T::Native, Output = T::Native>,
677{
    /// Per-group mean for float chunked arrays. Empty groups yield null.
    ///
    /// # Safety
    /// No bounds checks on `groups`; all indices/offsets must be in bounds.
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Single chunk so the gather kernels see one contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    // this can fail due to a bug in lazy code.
                    // here users can create filters in aggregations
                    // and thereby creating shorter columns than the original group tuples.
                    // the group tuples are modified, but if that's done incorrect there can be out of bounds
                    // access
                    debug_assert!(idx.len() <= self.len());
                    let out = if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else if no_nulls {
                        // Sum in the native type, then divide by the group size.
                        take_agg_no_null_primitive_iter_unchecked::<_, T::Native, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a + b,
                        )
                        .unwrap()
                        .to_f64()
                        .map(|sum| sum / idx.len() as f64)
                    } else {
                        // Null-aware path: count nulls while summing so the
                        // divisor only counts valid values.
                        take_agg_primitive_iter_unchecked_count_nulls::<T::Native, _, _, _>(
                            arr,
                            idx2usize(idx),
                            |a, b| a + b,
                            T::Native::zero(),
                            idx.len() as IdxSize,
                        )
                        .map(|(sum, null_count)| {
                            sum.to_f64()
                                .map(|sum| sum / (idx.len() as f64 - null_count as f64))
                                .unwrap()
                        })
                    };
                    // Cast the f64 mean back to the array's float native type.
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice { groups, .. } => {
                if _use_rolling_kernels(groups, self.chunks()) {
                    // Overlapping windows: use the stateful rolling kernels.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MeanWindow<_>, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MeanWindow<_>,
                            _,
                            _,
                        >(
                            values, validity, offset_iter, None
                        ),
                    };
                    ChunkedArray::from(arr).into_series()
                } else {
                    // Non-overlapping slices: aggregate each slice in parallel.
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean().map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
757
758    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series
759    where
760        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
761    {
762        let ca = &self.0.rechunk();
763        match groups {
764            GroupsType::Idx(groups) => {
765                let ca = ca.rechunk();
766                let arr = ca.downcast_iter().next().unwrap();
767                let no_nulls = arr.null_count() == 0;
768                agg_helper_idx_on_all::<T, _>(groups, |idx| {
769                    debug_assert!(idx.len() <= ca.len());
770                    if idx.is_empty() {
771                        return None;
772                    }
773                    let out = if no_nulls {
774                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
775                    } else {
776                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
777                    };
778                    out.map(|flt| NumCast::from(flt).unwrap())
779                })
780            },
781            GroupsType::Slice { groups, .. } => {
782                if _use_rolling_kernels(groups, self.chunks()) {
783                    let arr = self.downcast_iter().next().unwrap();
784                    let values = arr.values().as_slice();
785                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
786                    let arr = match arr.validity() {
787                        None => _rolling_apply_agg_window_no_nulls::<
788                            MomentWindow<_, VarianceMoment>,
789                            _,
790                            _,
791                        >(
792                            values,
793                            offset_iter,
794                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
795                        ),
796                        Some(validity) => _rolling_apply_agg_window_nulls::<
797                            rolling::nulls::MomentWindow<_, VarianceMoment>,
798                            _,
799                            _,
800                        >(
801                            values,
802                            validity,
803                            offset_iter,
804                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
805                        ),
806                    };
807                    ChunkedArray::from(arr).into_series()
808                } else {
809                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
810                        debug_assert!(len <= self.len() as IdxSize);
811                        match len {
812                            0 => None,
813                            1 => {
814                                if ddof == 0 {
815                                    NumCast::from(0)
816                                } else {
817                                    None
818                                }
819                            },
820                            _ => {
821                                let arr_group = _slice_from_offsets(self, first, len);
822                                arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap())
823                            },
824                        }
825                    })
826                }
827            },
828        }
829    }
830    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series
831    where
832        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
833    {
834        let ca = &self.0.rechunk();
835        match groups {
836            GroupsType::Idx(groups) => {
837                let arr = ca.downcast_iter().next().unwrap();
838                let no_nulls = arr.null_count() == 0;
839                agg_helper_idx_on_all::<T, _>(groups, |idx| {
840                    debug_assert!(idx.len() <= ca.len());
841                    if idx.is_empty() {
842                        return None;
843                    }
844                    let out = if no_nulls {
845                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
846                    } else {
847                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
848                    };
849                    out.map(|flt| NumCast::from(flt.sqrt()).unwrap())
850                })
851            },
852            GroupsType::Slice { groups, .. } => {
853                if _use_rolling_kernels(groups, self.chunks()) {
854                    let arr = ca.downcast_iter().next().unwrap();
855                    let values = arr.values().as_slice();
856                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
857                    let arr = match arr.validity() {
858                        None => _rolling_apply_agg_window_no_nulls::<
859                            MomentWindow<_, VarianceMoment>,
860                            _,
861                            _,
862                        >(
863                            values,
864                            offset_iter,
865                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
866                        ),
867                        Some(validity) => _rolling_apply_agg_window_nulls::<
868                            rolling::nulls::MomentWindow<_, rolling::nulls::VarianceMoment>,
869                            _,
870                            _,
871                        >(
872                            values,
873                            validity,
874                            offset_iter,
875                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
876                        ),
877                    };
878
879                    let mut ca = ChunkedArray::<T>::from(arr);
880                    ca.apply_mut(|v| v.powf(NumCast::from(0.5).unwrap()));
881                    ca.into_series()
882                } else {
883                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
884                        debug_assert!(len <= self.len() as IdxSize);
885                        match len {
886                            0 => None,
887                            1 => {
888                                if ddof == 0 {
889                                    NumCast::from(0)
890                                } else {
891                                    None
892                                }
893                            },
894                            _ => {
895                                let arr_group = _slice_from_offsets(self, first, len);
896                                arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap())
897                            },
898                        }
899                    })
900                }
901            },
902        }
903    }
904}
905
impl Float32Chunked {
    /// Aggregate the quantile per group using the given interpolation `method`.
    ///
    /// Delegates to the generic quantile implementation with `Float32` output.
    ///
    /// # Safety
    /// The group indices/offsets must be in bounds for this array.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
    }
    /// Aggregate the median per group.
    ///
    /// Delegates to the generic median implementation with `Float32` output.
    ///
    /// # Safety
    /// The group indices/offsets must be in bounds for this array.
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float32Type>(self, groups)
    }
}
impl Float64Chunked {
    /// Aggregate the quantile per group using the given interpolation `method`.
    ///
    /// Delegates to the generic quantile implementation with `Float64` output.
    ///
    /// # Safety
    /// The group indices/offsets must be in bounds for this array.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    /// Aggregate the median per group.
    ///
    /// Delegates to the generic median implementation with `Float64` output.
    ///
    /// # Safety
    /// The group indices/offsets must be in bounds for this array.
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}
932
// Group-by aggregations for integer chunked arrays. Mean/var/std accumulate in
// and return `f64` (`Float64Type` helpers below), so integer inputs don't
// overflow or truncate during the aggregation.
impl<T> ChunkedArray<T>
where
    T: PolarsIntegerType,
    ChunkedArray<T>: IntoSeries + ChunkAgg<T::Native> + ChunkVar,
    T::Native: NumericNative + Ord,
{
    /// Aggregate the mean per group, returning a `Float64` Series.
    ///
    /// # Safety
    /// The group indices/offsets must be in bounds for this array.
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the unchecked take kernels see one contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_get(0).unwrap();
                _agg_helper_idx::<Float64Type, _>(groups, |(first, idx)| {
                    // this can fail due to a bug in lazy code.
                    // here users can create filters in aggregations
                    // and thereby creating shorter columns than the original group tuples.
                    // the group tuples are modified, but if that's done incorrect there can be out of bounds
                    // access
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        // Single-element group: the mean is the element itself.
                        self.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else {
                        // Dispatch on (has nulls, chunk count) to the fastest kernel.
                        match (self.has_nulls(), self.chunks.len()) {
                            (false, 1) => {
                                take_agg_no_null_primitive_iter_unchecked::<_, f64, _, _>(
                                    arr,
                                    idx2usize(idx),
                                    |a, b| a + b,
                                )
                                .map(|sum| sum / idx.len() as f64)
                            },
                            (_, 1) => {
                                {
                                    take_agg_primitive_iter_unchecked_count_nulls::<
                                        T::Native,
                                        f64,
                                        _,
                                        _,
                                    >(
                                        arr, idx2usize(idx), |a, b| a + b, 0.0, idx.len() as IdxSize
                                    )
                                }
                                // Divide by the number of valid (non-null) values.
                                .map(|(sum, null_count)| {
                                    sum / (idx.len() as f64 - null_count as f64)
                                })
                            },
                            _ => {
                                // Multi-chunk fallback: gather the group, then aggregate.
                                let take = { self.take_unchecked(idx) };
                                take.mean()
                            },
                        }
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    // Overlapping windows: cast to f64 and delegate to the float
                    // implementation, which uses the stateful rolling kernels.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_mean(groups)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean()
                            },
                        }
                    })
                }
            },
        }
    }

    /// Aggregate the variance per group (with `ddof` delta degrees of
    /// freedom), returning a `Float64` Series.
    ///
    /// # Safety
    /// The group indices/offsets must be in bounds for this array.
    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the unchecked take kernels see one contiguous array.
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        return None;
                    }
                    if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    // Overlapping windows: cast to f64 and use the float rolling path.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_var(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                // One observation: population variance (ddof == 0)
                                // is 0; sample variance is undefined (null).
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof)
                            },
                        }
                    })
                }
            },
        }
    }
    /// Aggregate the standard deviation per group — `sqrt` of the variance,
    /// with `ddof` delta degrees of freedom — returning a `Float64` Series.
    ///
    /// # Safety
    /// The group indices/offsets must be in bounds for this array.
    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the unchecked take kernels see one contiguous array.
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    // std = sqrt(var)
                    out.map(|v| v.sqrt())
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                ..
            } => {
                if _use_rolling_kernels(groups_slice, self.chunks()) {
                    // Overlapping windows: cast to f64 and use the float rolling path.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_std(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                // One observation: population std (ddof == 0) is 0;
                                // sample std is undefined (null).
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof)
                            },
                        }
                    })
                }
            },
        }
    }

    /// Aggregate the quantile per group using the given interpolation
    /// `method`; computed via the generic f64 implementation.
    ///
    /// # Safety
    /// The group indices/offsets must be in bounds for this array.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    /// Aggregate the median per group; computed via the generic f64
    /// implementation.
    ///
    /// # Safety
    /// The group indices/offsets must be in bounds for this array.
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}