polars_core/frame/group_by/aggregations/
mod.rs

1mod agg_list;
2mod boolean;
3mod dispatch;
4mod string;
5
6use std::borrow::Cow;
7
8pub use agg_list::*;
9use arrow::bitmap::{Bitmap, MutableBitmap};
10use arrow::legacy::kernels::take_agg::*;
11use arrow::legacy::trusted_len::TrustedLenPush;
12use arrow::types::NativeType;
13use num_traits::pow::Pow;
14use num_traits::{Bounded, Float, Num, NumCast, ToPrimitive, Zero};
15use polars_compute::rolling::no_nulls::{
16    MaxWindow, MinWindow, MomentWindow, QuantileWindow, RollingAggWindowNoNulls,
17};
18use polars_compute::rolling::nulls::{RollingAggWindowNulls, VarianceMoment};
19use polars_compute::rolling::quantile_filter::SealedRolling;
20use polars_compute::rolling::{
21    self, ArgMaxWindow, ArgMinWindow, MeanWindow, QuantileMethod, RollingFnParams,
22    RollingQuantileParams, RollingVarParams, SumWindow, quantile_filter,
23};
24use polars_utils::arg_min_max::ArgMinMax;
25use polars_utils::float::IsFloat;
26#[cfg(feature = "dtype-f16")]
27use polars_utils::float16::pf16;
28use polars_utils::idx_vec::IdxVec;
29use polars_utils::kahan_sum::KahanSum;
30use polars_utils::min_max::MinMax;
31use rayon::prelude::*;
32
33use crate::POOL;
34use crate::chunked_array::cast::CastOptions;
35#[cfg(feature = "object")]
36use crate::chunked_array::object::extension::create_extension;
37use crate::chunked_array::{arg_max_numeric, arg_min_numeric};
38#[cfg(feature = "object")]
39use crate::frame::group_by::GroupsIndicator;
40use crate::prelude::*;
41use crate::series::IsSorted;
42use crate::series::implementations::SeriesWrap;
43use crate::utils::NoNull;
44
45fn idx2usize(idx: &[IdxSize]) -> impl ExactSizeIterator<Item = usize> + '_ {
46    idx.iter().map(|i| *i as usize)
47}
48
49// if the windows overlap, we can use the rolling_<agg> kernels
50// they maintain state, which saves a lot of compute by not naively traversing all elements every
51// window
52//
53// if the windows don't overlap, we should not use these kernels as they are single threaded, so
54// we miss out on easy parallelization.
55pub fn _use_rolling_kernels(
56    groups: &GroupsSlice,
57    overlapping: bool,
58    monotonic: bool,
59    chunks: &[ArrayRef],
60) -> bool {
61    match groups.len() {
62        0 | 1 => false,
63        _ => overlapping && monotonic && chunks.len() == 1,
64    }
65}
66
// Use an aggregation window that maintains the state.
/// Slides a stateful rolling aggregation window `Agg` over the groups yielded
/// by `offsets` (`(start, len)` pairs, one slot per group), for input that
/// carries a validity bitmap.
///
/// Returns a `PrimitiveArray<Out>` with one entry per group; groups for which
/// `Agg::get_agg` yields `None` are marked null in the output validity.
pub fn _rolling_apply_agg_window_nulls<Agg, T, O, Out>(
    values: &[T],
    validity: &Bitmap,
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<Out>
where
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    Agg: RollingAggWindowNulls<T, Out>,
    T: IsFloat + NativeType,
    Out: NativeType,
{
    if values.is_empty() {
        // No input values at all: return an empty array without validity.
        let out: Vec<Out> = vec![];
        return PrimitiveArray::new(Out::PRIMITIVE.into(), out.into(), None);
    }

    // This iterator's length can be trusted (TrustedLen bound);
    // it represents the number of groups in the group_by operation.
    let output_len = offsets.size_hint().0;
    // start with a dummy index, will be overwritten on first iteration.
    let mut agg_window = Agg::new(values, validity, 0, 0, params, None);

    // Output validity; starts all-valid, shadows the *input* `validity` above
    // (which the window already captured in `Agg::new`).
    let mut validity = MutableBitmap::with_capacity(output_len);
    validity.extend_constant(output_len, true);

    let out = offsets
        .enumerate()
        .map(|(idx, (start, len))| {
            let end = start + len;

            // SAFETY:
            // we are in bounds
            unsafe { agg_window.update(start as usize, end as usize) };
            match agg_window.get_agg(idx) {
                Some(val) => val,
                None => {
                    // No aggregate for this group: mark the slot null and fill
                    // with a placeholder value.
                    // SAFETY: we are in bounds
                    unsafe { validity.set_unchecked(idx, false) };
                    Out::default()
                },
            }
        })
        .collect_trusted::<Vec<_>>();

    PrimitiveArray::new(Out::PRIMITIVE.into(), out.into(), Some(validity.into()))
}
115
// Use an aggregation window that maintains the state.
/// Slides a stateful rolling aggregation window `Agg` over the groups yielded
/// by `offsets` (`(start, len)` pairs, one slot per group), for input without
/// a validity bitmap. Collects the per-group results of `Agg::get_agg` into a
/// `PrimitiveArray<Out>`.
pub fn _rolling_apply_agg_window_no_nulls<Agg, T, O, Out>(
    values: &[T],
    offsets: O,
    params: Option<RollingFnParams>,
) -> PrimitiveArray<Out>
where
    // items (offset, len) -> so offsets are offset, offset + len
    Agg: RollingAggWindowNoNulls<T, Out>,
    O: Iterator<Item = (IdxSize, IdxSize)> + TrustedLen,
    T: IsFloat + NativeType,
    Out: NativeType,
{
    if values.is_empty() {
        // No input values at all: return an empty array without validity.
        let out: Vec<Out> = vec![];
        return PrimitiveArray::new(Out::PRIMITIVE.into(), out.into(), None);
    }
    // start with a dummy index, will be overwritten on first iteration.
    let mut agg_window = Agg::new(values, 0, 0, params, None);

    offsets
        .enumerate()
        .map(|(idx, (start, len))| {
            let end = start + len;

            // SAFETY: we are in bounds.
            unsafe { agg_window.update(start as usize, end as usize) };
            agg_window.get_agg(idx)
        })
        .collect::<PrimitiveArray<Out>>()
}
147
148pub fn _slice_from_offsets<T>(ca: &ChunkedArray<T>, first: IdxSize, len: IdxSize) -> ChunkedArray<T>
149where
150    T: PolarsDataType,
151{
152    ca.slice(first as i64, len as usize)
153}
154
155/// Helper that combines the groups into a parallel iterator over `(first, all): (u32, &Vec<u32>)`.
156pub fn _agg_helper_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
157where
158    F: Fn((IdxSize, &IdxVec)) -> Option<T::Native> + Send + Sync,
159    T: PolarsNumericType,
160{
161    let ca: ChunkedArray<T> = POOL.install(|| groups.into_par_iter().map(f).collect());
162    ca.into_series()
163}
164
165/// Same helper as `_agg_helper_idx` but for aggregations that don't return an Option.
166pub fn _agg_helper_idx_no_null<T, F>(groups: &GroupsIdx, f: F) -> Series
167where
168    F: Fn((IdxSize, &IdxVec)) -> T::Native + Send + Sync,
169    T: PolarsNumericType,
170{
171    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.into_par_iter().map(f).collect());
172    ca.into_inner().into_series()
173}
174
175/// Helper that iterates on the `all: Vec<Vec<u32>` collection,
176/// this doesn't have traverse the `first: Vec<u32>` memory and is therefore faster.
177fn agg_helper_idx_on_all<T, F>(groups: &GroupsIdx, f: F) -> Series
178where
179    F: Fn(&IdxVec) -> Option<T::Native> + Send + Sync,
180    T: PolarsNumericType,
181{
182    let ca: ChunkedArray<T> = POOL.install(|| groups.all().into_par_iter().map(f).collect());
183    ca.into_series()
184}
185
186pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
187where
188    F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
189    T: PolarsNumericType,
190{
191    let ca: ChunkedArray<T> = POOL.install(|| groups.par_iter().copied().map(f).collect());
192    ca.into_series()
193}
194
195pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
196where
197    F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
198    T: PolarsNumericType,
199{
200    let ca: NoNull<ChunkedArray<T>> = POOL.install(|| groups.par_iter().copied().map(f).collect());
201    ca.into_inner().into_series()
202}
203
/// Intermediate helper trait so we can have a single generic implementation
/// This trait will ensure the specific dispatch works without complicating
/// the trait bounds.
trait QuantileDispatcher<K> {
    /// Compute the `quantile` of the consumed array using `method`.
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<K>>;

    /// Compute the median of the consumed array.
    fn _median(self) -> Option<K>;
}
212
/// Integer inputs compute their quantile/median as `f64`.
impl<T> QuantileDispatcher<f64> for ChunkedArray<T>
where
    T: PolarsIntegerType,
    T::Native: Ord,
{
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        // Delegate to the owned fast path (takes `self` by value).
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}
225
/// `f16` inputs keep their own precision for the result.
#[cfg(feature = "dtype-f16")]
impl QuantileDispatcher<pf16> for Float16Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<pf16>> {
        // Delegate to the owned fast path (takes `self` by value).
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<pf16> {
        self.median_faster()
    }
}
235
/// `f32` inputs keep their own precision for the result.
impl QuantileDispatcher<f32> for Float32Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f32>> {
        // Delegate to the owned fast path (takes `self` by value).
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f32> {
        self.median_faster()
    }
}
/// `f64` inputs keep their own precision for the result.
impl QuantileDispatcher<f64> for Float64Chunked {
    fn _quantile(self, quantile: f64, method: QuantileMethod) -> PolarsResult<Option<f64>> {
        // Delegate to the owned fast path (takes `self` by value).
        self.quantile_faster(quantile, method)
    }
    fn _median(self) -> Option<f64> {
        self.median_faster()
    }
}
252
/// Grouped quantile aggregation, generic over the input type `T` and the
/// float output type `K` (dispatched through [`QuantileDispatcher`]).
///
/// # Safety
///
/// No bounds checks on `groups`; all indices/offsets must be in bounds for
/// `ca`.
unsafe fn agg_quantile_generic<T, K>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    quantile: f64,
    method: QuantileMethod,
) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
{
    // A quantile outside [0, 1] produces one null per group.
    let invalid_quantile = !(0.0..=1.0).contains(&quantile);
    if invalid_quantile {
        // NOTE(review): this null series keeps the *input* dtype, while the
        // paths below produce the `K` output dtype — confirm callers expect
        // that difference.
        return Series::full_null(ca.name().clone(), groups.len(), ca.dtype());
    }
    match groups {
        GroupsType::Idx(groups) => {
            // Rechunk once so every per-group gather hits one contiguous array.
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                // checked with invalid quantile check
                take._quantile(quantile, method).unwrap_unchecked()
            })
        },
        GroupsType::Slice {
            groups,
            overlapping,
            monotonic,
        } => {
            if _use_rolling_kernels(groups, *overlapping, *monotonic, ca.chunks()) {
                // this cast is a no-op for floats
                let s = ca
                    .cast_with_options(&K::get_static_dtype(), CastOptions::Overflowing)
                    .unwrap();
                let ca: &ChunkedArray<K> = s.as_ref().as_ref();
                let arr = ca.downcast_iter().next().unwrap();
                let values = arr.values().as_slice();
                let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                let arr = match arr.validity() {
                    None => _rolling_apply_agg_window_no_nulls::<QuantileWindow<_>, _, _, _>(
                        values,
                        offset_iter,
                        Some(RollingFnParams::Quantile(RollingQuantileParams {
                            prob: quantile,
                            method,
                        })),
                    ),
                    Some(validity) => {
                        _rolling_apply_agg_window_nulls::<rolling::nulls::QuantileWindow<_>, _, _, _>(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Quantile(RollingQuantileParams {
                                prob: quantile,
                                method,
                            })),
                        )
                    },
                };
                // The rolling kernels works on the dtype, this is not yet the
                // float output type we need.
                ChunkedArray::<K>::with_chunk(PlSmallStr::EMPTY, arr).into_series()
            } else {
                // Non-overlapping groups: slice per group and aggregate in
                // parallel.
                _agg_helper_slice::<K, _>(groups, |[first, len]| {
                    debug_assert!(first + len <= ca.len() as IdxSize);
                    match len {
                        0 => None,
                        1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                        _ => {
                            let arr_group = _slice_from_offsets(ca, first, len);
                            // unwrap checked with invalid quantile check
                            arr_group
                                ._quantile(quantile, method)
                                .unwrap_unchecked()
                                .map(|flt| NumCast::from(flt).unwrap_unchecked())
                        },
                    }
                })
            }
        },
    }
}
340
/// Grouped median aggregation, generic over the input type `T` and the float
/// output type `K`.
///
/// Index groups take the dedicated `_median` path; slice groups reuse the
/// 0.5-quantile with linear interpolation.
///
/// # Safety
///
/// No bounds checks on `groups`; all indices/offsets must be in bounds for
/// `ca`.
unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + SealedRolling,
{
    match groups {
        GroupsType::Idx(groups) => {
            // Rechunk once so every per-group gather hits one contiguous array.
            let ca = ca.rechunk();
            agg_helper_idx_on_all::<K, _>(groups, |idx| {
                debug_assert!(idx.len() <= ca.len());
                if idx.is_empty() {
                    return None;
                }
                let take = { ca.take_unchecked(idx) };
                take._median()
            })
        },
        GroupsType::Slice { .. } => {
            agg_quantile_generic::<T, K>(ca, groups, 0.5, QuantileMethod::Linear)
        },
    }
}
365
/// Shared driver for the grouped bitwise reductions (`agg_and`/`agg_or`/
/// `agg_xor`); `f` is the per-group reduction applied to each gathered/sliced
/// group.
///
/// # Safety
///
/// No bounds checks on `groups`.
#[cfg(feature = "bitwise")]
unsafe fn bitwise_agg<T: PolarsNumericType>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    f: fn(&ChunkedArray<T>) -> Option<T::Native>,
) -> Series
where
    ChunkedArray<T>: ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native>,
{
    // Prevent a rechunk for every individual group.

    let s = if groups.len() > 1 {
        ca.rechunk()
    } else {
        Cow::Borrowed(ca)
    };

    match groups {
        GroupsType::Idx(groups) => agg_helper_idx_on_all::<T, _>(groups, |idx| {
            debug_assert!(idx.len() <= s.len());
            if idx.is_empty() {
                None
            } else {
                // SAFETY: caller guarantees the group indices are in bounds.
                let take = unsafe { s.take_unchecked(idx) };
                f(&take)
            }
        }),
        GroupsType::Slice { groups, .. } => _agg_helper_slice::<T, _>(groups, |[first, len]| {
            debug_assert!(len <= s.len() as IdxSize);
            if len == 0 {
                None
            } else {
                let take = _slice_from_offsets(&s, first, len);
                f(&take)
            }
        }),
    }
}
407
#[cfg(feature = "bitwise")]
impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkTakeUnchecked<[IdxSize]> + ChunkBitwiseReduce<Physical = T::Native>,
{
    /// Grouped bitwise AND reduction.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_and(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::and_reduce) }
    }

    /// Grouped bitwise OR reduction.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_or(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::or_reduce) }
    }

    /// Grouped bitwise XOR reduction.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`.
    pub(crate) unsafe fn agg_xor(&self, groups: &GroupsType) -> Series {
        unsafe { bitwise_agg(self, groups, ChunkBitwiseReduce::xor_reduce) }
    }
}
435
436impl<T> ChunkedArray<T>
437where
438    T: PolarsNumericType + Sync,
439    T::Native: NativeType + PartialOrd + Num + NumCast + Zero + Bounded + std::iter::Sum<T::Native>,
440    ChunkedArray<T>: ChunkAgg<T::Native>,
441{
    /// Aggregate the minimum value per group.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`; all indices/offsets must be in bounds.
    pub(crate) unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
        // faster paths
        if groups.is_sorted_flag() {
            match self.is_sorted_flag() {
                IsSorted::Ascending => {
                    // Ascending data: each group's minimum is its first non-null.
                    return self.clone().into_series().agg_first_non_null(groups);
                },
                IsSorted::Descending => {
                    // Descending data: each group's minimum is its last non-null.
                    return self.clone().into_series().agg_last_non_null(groups);
                },
                _ => {},
            }
        }

        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the gathers below index one contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.min_ignore_nan(b))
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.min_ignore_nan(b))
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping, monotonic windows over one chunk: reuse
                    // rolling-window state instead of re-scanning each group.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MinWindow<_>, _, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => {
                            _rolling_apply_agg_window_nulls::<rolling::nulls::MinWindow<_>, _, _, _>(
                                values,
                                validity,
                                offset_iter,
                                None,
                            )
                        },
                    };
                    Self::from(arr).into_series()
                } else {
                    // Non-overlapping groups: slice per group, aggregate in parallel.
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::min(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }
517
    /// Position of the minimum value *within* each group.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`; all indices/offsets must be in bounds.
    pub(crate) unsafe fn agg_arg_min(&self, groups: &GroupsType) -> Series
    where
        for<'b> &'b [T::Native]: ArgMinMax,
    {
        // Fast paths for sorted data: the minimum sits at the first/last
        // non-null position of each group.
        if groups.is_sorted_flag() {
            match self.is_sorted_flag() {
                IsSorted::Ascending => {
                    return self.clone().into_series().agg_arg_first_non_null(groups);
                },
                IsSorted::Descending => {
                    return self.clone().into_series().agg_arg_last_non_null(groups);
                },
                _ => {},
            }
        }

        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the gathers below index one contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = !arr.has_nulls();

                agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
                    if idx.is_empty() {
                        return None;
                    }

                    if no_nulls {
                        // No validity to consult: plain linear scan, seeded
                        // with the group's first element.
                        let first_i = idx[0] as usize;
                        let mut best_pos: IdxSize = 0;
                        let mut best_val: T::Native = unsafe { arr.value_unchecked(first_i) };

                        for (pos, &i) in idx.iter().enumerate().skip(1) {
                            let v = unsafe { arr.value_unchecked(i as usize) };
                            if v.nan_max_lt(&best_val) {
                                best_val = v;
                                best_pos = pos as IdxSize;
                            }
                        }
                        Some(best_pos)
                    } else {
                        // Seed with the first non-null element; `?` returns
                        // None when the whole group is null.
                        let (start_pos, mut best_val) = idx
                            .iter()
                            .enumerate()
                            .find_map(|(pos, &i)| arr.get(i as usize).map(|v| (pos, v)))?;

                        let mut best_pos: IdxSize = start_pos as IdxSize;

                        for (pos, &i) in idx.iter().enumerate().skip(start_pos + 1) {
                            if let Some(v) = arr.get(i as usize) {
                                if v.nan_max_lt(&best_val) {
                                    best_val = v;
                                    best_pos = pos as IdxSize;
                                }
                            }
                        }

                        Some(best_pos)
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping, monotonic windows over one chunk: reuse
                    // rolling-window state instead of re-scanning each group.
                    let arr = self.downcast_as_array();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let idx_arr = match arr.validity() {
                        None => {
                            _rolling_apply_agg_window_no_nulls::<ArgMinWindow<_>, _, _, IdxSize>(
                                values,
                                offset_iter,
                                None,
                            )
                        },
                        Some(validity) => {
                            _rolling_apply_agg_window_nulls::<ArgMinWindow<_>, _, _, IdxSize>(
                                values,
                                validity,
                                offset_iter,
                                None,
                            )
                        },
                    };

                    IdxCa::from(idx_arr).into_series()
                } else {
                    _agg_helper_slice::<IdxType, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            // A single-element group: its minimum is at position 0.
                            1 => Some(0 as IdxSize),
                            _ => {
                                let group_ca = _slice_from_offsets(self, first, len);
                                let pos_in_group: Option<usize> = arg_min_numeric(&group_ca);
                                pos_in_group.map(|p| p as IdxSize)
                            },
                        }
                    })
                }
            },
        }
    }
624
    /// Aggregate the maximum value per group.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`; all indices/offsets must be in bounds.
    pub(crate) unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
        // faster paths
        if groups.is_sorted_flag() {
            match self.is_sorted_flag() {
                // Ascending data: each group's maximum is its last non-null;
                // descending: its first non-null.
                IsSorted::Ascending => return self.clone().into_series().agg_last_non_null(groups),
                IsSorted::Descending => {
                    return self.clone().into_series().agg_first_non_null(groups);
                },
                _ => {},
            }
        }

        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the gathers below index one contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        arr.get(first as usize)
                    } else if no_nulls {
                        take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.max_ignore_nan(b))
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .reduce(|a, b| a.max_ignore_nan(b))
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping, monotonic windows over one chunk: reuse
                    // rolling-window state instead of re-scanning each group.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MaxWindow<_>, _, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => {
                            _rolling_apply_agg_window_nulls::<rolling::nulls::MaxWindow<_>, _, _, _>(
                                values,
                                validity,
                                offset_iter,
                                None,
                            )
                        },
                    };
                    Self::from(arr).into_series()
                } else {
                    // Non-overlapping groups: slice per group, aggregate in parallel.
                    _agg_helper_slice::<T, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                ChunkAgg::max(&arr_group)
                            },
                        }
                    })
                }
            },
        }
    }
698
    /// Position of the maximum value *within* each group.
    ///
    /// # Safety
    ///
    /// No bounds checks on `groups`; all indices/offsets must be in bounds.
    pub(crate) unsafe fn agg_arg_max(&self, groups: &GroupsType) -> Series
    where
        for<'b> &'b [T::Native]: ArgMinMax,
    {
        // Fast paths for sorted data: the maximum sits at the last/first
        // non-null position of each group.
        if groups.is_sorted_flag() {
            match self.is_sorted_flag() {
                IsSorted::Ascending => {
                    return self.clone().into_series().agg_arg_last_non_null(groups);
                },
                IsSorted::Descending => {
                    return self.clone().into_series().agg_arg_first_non_null(groups);
                },
                _ => {},
            }
        }
        match groups {
            GroupsType::Idx(groups) => {
                // Rechunk so the gathers below index one contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_as_array();
                let no_nulls = arr.null_count() == 0;

                agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
                    if idx.is_empty() {
                        return None;
                    }

                    if no_nulls {
                        // No validity to consult: plain linear scan, seeded
                        // with the group's first element.
                        let first_i = idx[0] as usize;
                        let mut best_pos: IdxSize = 0;
                        let mut best_val: T::Native = unsafe { arr.value_unchecked(first_i) };

                        for (pos, &i) in idx.iter().enumerate().skip(1) {
                            let v = unsafe { arr.value_unchecked(i as usize) };

                            if v.nan_min_gt(&best_val) {
                                best_val = v;
                                best_pos = pos as IdxSize;
                            }
                        }

                        Some(best_pos)
                    } else {
                        // Seed with the first non-null element; `?` returns
                        // None when the whole group is null.
                        let (start_pos, mut best_val) = idx
                            .iter()
                            .enumerate()
                            .find_map(|(pos, &i)| arr.get(i as usize).map(|v| (pos, v)))?;

                        let mut best_pos: IdxSize = start_pos as IdxSize;

                        for (pos, &i) in idx.iter().enumerate().skip(start_pos + 1) {
                            if let Some(v) = arr.get(i as usize) {
                                if v.nan_min_gt(&best_val) {
                                    best_val = v;
                                    best_pos = pos as IdxSize;
                                }
                            }
                        }

                        Some(best_pos)
                    }
                })
            },

            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping, monotonic windows over one chunk: reuse
                    // rolling-window state instead of re-scanning each group.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups_slice.iter().map(|[first, len]| (*first, *len));
                    let idx_arr = match arr.validity() {
                        None => {
                            _rolling_apply_agg_window_no_nulls::<ArgMaxWindow<_>, _, _, IdxSize>(
                                values,
                                offset_iter,
                                None,
                            )
                        },
                        Some(validity) => {
                            _rolling_apply_agg_window_nulls::<ArgMaxWindow<_>, _, _, IdxSize>(
                                values,
                                validity,
                                offset_iter,
                                None,
                            )
                        },
                    };
                    IdxCa::from(idx_arr).into_series()
                } else {
                    _agg_helper_slice::<IdxType, _>(groups_slice, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            // A single-element group: its maximum is at position 0.
                            1 => Some(0 as IdxSize),
                            _ => {
                                let group_ca = _slice_from_offsets(self, first, len);
                                let pos_in_group: Option<usize> = arg_max_numeric(&group_ca);
                                pos_in_group.map(|p| p as IdxSize)
                            },
                        }
                    })
                }
            },
        }
    }
    /// Sum each group of `self`, resolved by `groups`.
    ///
    /// The output is non-nullable: empty groups (and all-null single-element
    /// groups) sum to `0`. Float sums use Kahan-compensated accumulation
    /// (`KahanSum`) to limit rounding error; integer sums use a plain fold.
    ///
    /// # Safety
    /// Caller must ensure the indices/offsets in `groups` are valid for
    /// `self`; this is only `debug_assert`ed here, and the take kernels
    /// index unchecked.
    pub(crate) unsafe fn agg_sum(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Collapse to a single chunk so the take kernels can index
                // into one contiguous array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx_no_null::<T, _>(groups, |(first, idx)| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        // Empty group sums to zero, not null.
                        T::Native::zero()
                    } else if idx.len() == 1 {
                        // Single-element group: a null element also yields zero.
                        arr.get(first as usize).unwrap_or(T::Native::zero())
                    } else if no_nulls {
                        if T::Native::is_float() {
                            // Compensated summation for floats.
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(KahanSum::default(), |k, x| k + x)
                                .sum()
                        } else {
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(T::Native::zero(), |a, b| a + b)
                        }
                    } else if T::Native::is_float() {
                        // Null-aware take kernel + compensated summation.
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .fold(KahanSum::default(), |k, x| k + x)
                            .sum()
                    } else {
                        take_agg_primitive_iter_unchecked(arr, idx2usize(idx))
                            .fold(T::Native::zero(), |a, b| a + b)
                    }
                })
            },
            GroupsType::Slice {
                groups,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping-window fast path: the rolling SumWindow keeps
                    // running state instead of re-summing every window (see the
                    // note at the top of this module).
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            SumWindow<T::Native, T::Native>,
                            _,
                            _,
                            _,
                        >(values, offset_iter, None),
                        Some(validity) => {
                            _rolling_apply_agg_window_nulls::<
                                SumWindow<T::Native, T::Native>,
                                _,
                                _,
                                _,
                            >(values, validity, offset_iter, None)
                        },
                    };
                    Self::from(arr).into_series()
                } else {
                    _agg_helper_slice_no_null::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => T::Native::zero(),
                            1 => self.get(first as usize).unwrap_or(T::Native::zero()),
                            _ => {
                                // General case: slice out the group and reuse
                                // the ChunkedArray sum kernel.
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.sum().unwrap_or(T::Native::zero())
                            },
                        }
                    })
                }
            },
        }
    }
879}
880
/// Float-only aggregations (mean/var/std) where the accumulator and the
/// output stay in the array's own float dtype `T`.
impl<T> SeriesWrap<ChunkedArray<T>>
where
    T: PolarsFloatType,
    ChunkedArray<T>: ChunkVar
        + VarAggSeries
        + ChunkQuantile<T::Native>
        + QuantileAggSeries
        + ChunkAgg<T::Native>,
    T::Native: Pow<T::Native, Output = T::Native>,
{
    /// Mean of each group. Empty groups yield null; sums are accumulated in
    /// `f64` with Kahan compensation, then cast back to `T::Native`.
    ///
    /// # Safety
    /// Group indices/offsets must be valid for `self` (only debug-checked).
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Single chunk so the take kernels can index one array.
                let ca = self.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                _agg_helper_idx::<T, _>(groups, |(first, idx)| {
                    // this can fail due to a bug in lazy code.
                    // here users can create filters in aggregations
                    // and thereby creating shorter columns than the original group tuples.
                    // the group tuples are modified, but if that's done incorrect there can be out of bounds
                    // access
                    debug_assert!(idx.len() <= self.len());
                    let out = if idx.is_empty() {
                        // Empty group -> null.
                        None
                    } else if idx.len() == 1 {
                        // Single element: the mean is the element itself.
                        arr.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else if no_nulls {
                        // No nulls: divide the compensated sum by the group size.
                        Some(
                            take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                .fold(KahanSum::default(), |a, b| {
                                    a + b.to_f64().unwrap_unchecked()
                                })
                                .sum()
                                / idx.len() as f64,
                        )
                    } else {
                        // Null-aware: the kernel also counts nulls so the
                        // divisor excludes them.
                        take_agg_primitive_iter_unchecked_count_nulls(
                            arr,
                            idx2usize(idx),
                            KahanSum::default(),
                            |a, b| a + b.to_f64().unwrap_unchecked(),
                            idx.len() as IdxSize,
                        )
                        .map(|(sum, null_count)| sum.sum() / (idx.len() as f64 - null_count as f64))
                    };
                    // Cast the f64 mean back to the array's native float type.
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice {
                groups,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups, *overlapping, *monotonic, self.chunks()) {
                    // Overlapping-window fast path with a stateful MeanWindow.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<MeanWindow<_>, _, _, _>(
                            values,
                            offset_iter,
                            None,
                        ),
                        Some(validity) => {
                            _rolling_apply_agg_window_nulls::<MeanWindow<_>, _, _, _>(
                                values,
                                validity,
                                offset_iter,
                                None,
                            )
                        },
                    };
                    ChunkedArray::<T>::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean().map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }

    /// Variance of each group with `ddof` delta degrees of freedom.
    ///
    /// Empty groups yield null; single-element groups yield `0` when
    /// `ddof == 0` and null otherwise.
    ///
    /// # Safety
    /// Group indices/offsets must be valid for `self` (only debug-checked).
    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                // NOTE(review): `ca` was already rechunked above; this second
                // rechunk looks redundant — confirm before removing.
                let ca = ca.rechunk();
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    out.map(|flt| NumCast::from(flt).unwrap())
                })
            },
            GroupsType::Slice {
                groups,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups, *overlapping, *monotonic, self.chunks()) {
                    // Rolling-window fast path; `ddof` is threaded through as
                    // RollingVarParams.
                    let arr = self.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                            _,
                        >(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                            _,
                        >(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                    };
                    ChunkedArray::<T>::from(arr).into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                // Population variance of one element is 0;
                                // sample variance (ddof > 0) is undefined.
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
    /// Standard deviation of each group: the square root of the same
    /// variance computation as [`Self::agg_var`], with identical empty- and
    /// single-element semantics.
    ///
    /// # Safety
    /// Group indices/offsets must be valid for `self` (only debug-checked).
    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series
    where
        <T as datatypes::PolarsNumericType>::Native: num_traits::Float,
    {
        let ca = &self.0.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let arr = ca.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<T, _>(groups, |idx| {
                    debug_assert!(idx.len() <= ca.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    // std = sqrt(var), taken before the cast back to T.
                    out.map(|flt| NumCast::from(flt.sqrt()).unwrap())
                })
            },
            GroupsType::Slice {
                groups,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups, *overlapping, *monotonic, self.chunks()) {
                    let arr = ca.downcast_iter().next().unwrap();
                    let values = arr.values().as_slice();
                    let offset_iter = groups.iter().map(|[first, len]| (*first, *len));
                    let arr = match arr.validity() {
                        None => _rolling_apply_agg_window_no_nulls::<
                            MomentWindow<_, VarianceMoment>,
                            _,
                            _,
                            _,
                        >(
                            values,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                        Some(validity) => _rolling_apply_agg_window_nulls::<
                            rolling::nulls::MomentWindow<_, rolling::nulls::VarianceMoment>,
                            _,
                            _,
                            _,
                        >(
                            values,
                            validity,
                            offset_iter,
                            Some(RollingFnParams::Var(RollingVarParams { ddof })),
                        ),
                    };

                    // The rolling kernel produced variances; take the square
                    // root in place (v^0.5) to get standard deviations.
                    let mut ca = ChunkedArray::<T>::from(arr);
                    ca.apply_mut(|v| v.powf(NumCast::from(0.5).unwrap()));
                    ca.into_series()
                } else {
                    _agg_helper_slice::<T, _>(groups, |[first, len]| {
                        debug_assert!(len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                // Same single-element convention as agg_var.
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap())
                            },
                        }
                    })
                }
            },
        }
    }
}
1131
impl Float32Chunked {
    /// Per-group quantile, delegating to the shared generic kernel with
    /// `Float32` output.
    ///
    /// # Safety
    /// `groups` is assumed valid for this array; the generic kernel uses
    /// unchecked group access.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
    }
    /// Per-group median, delegating to the shared generic kernel with
    /// `Float32` output.
    ///
    /// # Safety
    /// Same contract as [`Self::agg_quantile`].
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float32Type>(self, groups)
    }
}
impl Float64Chunked {
    /// Per-group quantile, delegating to the shared generic kernel with
    /// `Float64` output.
    ///
    /// # Safety
    /// `groups` is assumed valid for this array; the generic kernel uses
    /// unchecked group access.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    /// Per-group median, delegating to the shared generic kernel with
    /// `Float64` output.
    ///
    /// # Safety
    /// Same contract as [`Self::agg_quantile`].
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}
1158
/// Integer aggregations. Mean/var/std/quantile/median on integer arrays all
/// produce `Float64` output (the helpers below are instantiated with
/// `Float64Type`).
impl<T> ChunkedArray<T>
where
    T: PolarsIntegerType,
    ChunkedArray<T>: ChunkAgg<T::Native> + ChunkVar,
    T::Native: NumericNative + Ord,
{
    /// Mean of each group as `f64`. Empty groups yield null.
    ///
    /// # Safety
    /// Group indices/offsets must be valid for `self` (only debug-checked).
    pub(crate) unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                // Single chunk so the take kernels can index one array.
                let ca = self.rechunk();
                let arr = ca.downcast_get(0).unwrap();
                _agg_helper_idx::<Float64Type, _>(groups, |(first, idx)| {
                    // this can fail due to a bug in lazy code.
                    // here users can create filters in aggregations
                    // and thereby creating shorter columns than the original group tuples.
                    // the group tuples are modified, but if that's done incorrect there can be out of bounds
                    // access
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        None
                    } else if idx.len() == 1 {
                        // Single element: its value (as f64) is the mean.
                        self.get(first as usize).map(|sum| sum.to_f64().unwrap())
                    } else {
                        // Dispatch on (has nulls, chunk count) of the
                        // *original* array; only single-chunk arrays can use
                        // the take kernels directly.
                        match (self.has_nulls(), self.chunks.len()) {
                            (false, 1) => Some(
                                take_agg_no_null_primitive_iter_unchecked(arr, idx2usize(idx))
                                    .fold(KahanSum::default(), |a, b| a + b.to_f64().unwrap())
                                    .sum()
                                    / idx.len() as f64,
                            ),
                            (_, 1) => {
                                // Null-aware kernel counts nulls so the
                                // divisor excludes them.
                                take_agg_primitive_iter_unchecked_count_nulls(
                                    arr,
                                    idx2usize(idx),
                                    KahanSum::default(),
                                    |a, b| a + b.to_f64().unwrap(),
                                    idx.len() as IdxSize,
                                )
                            }
                            .map(|(sum, null_count)| {
                                sum.sum() / (idx.len() as f64 - null_count as f64)
                            }),
                            _ => {
                                // Multi-chunk fallback: gather then average.
                                let take = { self.take_unchecked(idx) };
                                take.mean()
                            },
                        }
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Rolling path: cast to f64 once and reuse the float
                    // implementation's rolling kernels.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_mean(groups)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => self.get(first as usize).map(|v| NumCast::from(v).unwrap()),
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.mean()
                            },
                        }
                    })
                }
            },
        }
    }

    /// Variance of each group as `f64` with `ddof` delta degrees of freedom.
    ///
    /// Empty groups yield null; single-element groups yield `0` when
    /// `ddof == 0` and null otherwise.
    ///
    /// # Safety
    /// Group indices/offsets must be valid for `self` (only debug-checked).
    pub(crate) unsafe fn agg_var(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= arr.len());
                    if idx.is_empty() {
                        return None;
                    }
                    if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    }
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Rolling path: cast to f64 and reuse the float agg_var.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_var(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                // Population variance of one element is 0;
                                // sample variance (ddof > 0) is undefined.
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.var(ddof)
                            },
                        }
                    })
                }
            },
        }
    }
    /// Standard deviation of each group as `f64`: square root of the same
    /// variance computation as [`Self::agg_var`], identical edge-case
    /// semantics.
    ///
    /// # Safety
    /// Group indices/offsets must be valid for `self` (only debug-checked).
    pub(crate) unsafe fn agg_std(&self, groups: &GroupsType, ddof: u8) -> Series {
        match groups {
            GroupsType::Idx(groups) => {
                let ca_self = self.rechunk();
                let arr = ca_self.downcast_iter().next().unwrap();
                let no_nulls = arr.null_count() == 0;
                agg_helper_idx_on_all::<Float64Type, _>(groups, |idx| {
                    debug_assert!(idx.len() <= self.len());
                    if idx.is_empty() {
                        return None;
                    }
                    let out = if no_nulls {
                        take_var_no_null_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    } else {
                        take_var_nulls_primitive_iter_unchecked(arr, idx2usize(idx), ddof)
                    };
                    // std = sqrt(var).
                    out.map(|v| v.sqrt())
                })
            },
            GroupsType::Slice {
                groups: groups_slice,
                overlapping,
                monotonic,
            } => {
                if _use_rolling_kernels(groups_slice, *overlapping, *monotonic, self.chunks()) {
                    // Rolling path: cast to f64 and reuse the float agg_std.
                    let ca = self
                        .cast_with_options(&DataType::Float64, CastOptions::Overflowing)
                        .unwrap();
                    ca.agg_std(groups, ddof)
                } else {
                    _agg_helper_slice::<Float64Type, _>(groups_slice, |[first, len]| {
                        debug_assert!(first + len <= self.len() as IdxSize);
                        match len {
                            0 => None,
                            1 => {
                                // Same single-element convention as agg_var.
                                if ddof == 0 {
                                    NumCast::from(0)
                                } else {
                                    None
                                }
                            },
                            _ => {
                                let arr_group = _slice_from_offsets(self, first, len);
                                arr_group.std(ddof)
                            },
                        }
                    })
                }
            },
        }
    }

    /// Per-group quantile as `f64`, delegating to the shared generic kernel.
    ///
    /// # Safety
    /// `groups` is assumed valid for this array; the generic kernel uses
    /// unchecked group access.
    pub(crate) unsafe fn agg_quantile(
        &self,
        groups: &GroupsType,
        quantile: f64,
        method: QuantileMethod,
    ) -> Series {
        agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
    }
    /// Per-group median as `f64`, delegating to the shared generic kernel.
    ///
    /// # Safety
    /// Same contract as [`Self::agg_quantile`].
    pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
        agg_median_generic::<_, Float64Type>(self, groups)
    }
}