Skip to main content

polars_core/frame/group_by/aggregations/dispatch.rs

1use arrow::bitmap::bitmask::BitMask;
2use polars_compute::unique::{AmortizedUnique, amortized_unique_from_dtype};
3
4use super::*;
5use crate::prelude::row_encode::encode_rows_unordered;
6
// `agg_n_unique` sends every group with more rows than this to
// `Series::n_unique` (radix sort + scan for primitives, a hash set for binary
// data) instead of the reusable `AmortizedUnique` hash set. For such large
// groups the sort-based path is the cheaper one. Keeping them away from the
// amortized set also bounds how large that set can grow, which avoids the
// O(capacity) `.clear()` storm reported in polars#27655.
const N_UNIQUE_SORT_FALLBACK_THRESHOLD: usize = 1 << 14;
13
14// implemented on the series because we don't need types
15impl Series {
16    unsafe fn restore_logical(&self, out: Series) -> Series {
17        if self.dtype().is_logical() && !out.dtype().is_logical() {
18            out.from_physical_unchecked(self.dtype()).unwrap()
19        } else {
20            out
21        }
22    }
23
24    #[doc(hidden)]
25    pub unsafe fn agg_valid_count(&self, groups: &GroupsType) -> Series {
26        // Prevent a rechunk for every individual group.
27        let valid = self.rechunk_validity();
28
29        match groups {
30            GroupsType::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idxs| {
31                debug_assert!(idxs.len() <= self.len());
32                if let Some(v) = &valid {
33                    let mut count = 0;
34                    for idx in idxs.iter() {
35                        count += unsafe { v.get_bit_unchecked(*idx as usize) as IdxSize };
36                    }
37                    Some(count)
38                } else {
39                    Some(self.len() as IdxSize)
40                }
41            }),
42            GroupsType::Slice { groups, .. } => {
43                _agg_helper_slice::<IdxType, _>(groups, |[first, len]| {
44                    debug_assert!(len <= self.len() as IdxSize);
45                    if let Some(v) = &valid {
46                        let m = BitMask::from_bitmap(v).sliced(first as usize, len as usize);
47                        Some(m.set_bits() as IdxSize)
48                    } else {
49                        Some(self.len() as IdxSize)
50                    }
51                })
52            },
53        }
54    }
55
56    #[doc(hidden)]
57    pub unsafe fn agg_first(&self, groups: &GroupsType) -> Series {
58        // Prevent a rechunk for every individual group.
59        let s = if groups.len() > 1 {
60            self.rechunk()
61        } else {
62            self.clone()
63        };
64
65        let mut out = match groups {
66            GroupsType::Idx(groups) => {
67                let indices = groups
68                    .iter()
69                    .map(
70                        |(first, idx)| {
71                            if idx.is_empty() { None } else { Some(first) }
72                        },
73                    )
74                    .collect_ca(PlSmallStr::EMPTY);
75                // SAFETY: groups are always in bounds.
76                s.take_unchecked(&indices)
77            },
78            GroupsType::Slice { groups, .. } => {
79                let indices = groups
80                    .iter()
81                    .map(|&[first, len]| if len == 0 { None } else { Some(first) })
82                    .collect_ca(PlSmallStr::EMPTY);
83                // SAFETY: groups are always in bounds.
84                s.take_unchecked(&indices)
85            },
86        };
87        if groups.is_sorted_by_first_idx() {
88            out.set_sorted_flag(s.is_sorted_flag())
89        }
90        s.restore_logical(out)
91    }
92
93    #[doc(hidden)]
94    pub unsafe fn agg_first_non_null(&self, groups: &GroupsType) -> Series {
95        if !self.has_nulls() {
96            return self.agg_first(groups);
97        }
98
99        // Prevent a rechunk for every individual group.
100        let s = if groups.len() > 1 {
101            self.rechunk()
102        } else {
103            self.clone()
104        };
105
106        let validity = s.rechunk_validity().unwrap();
107        let indices = match groups {
108            GroupsType::Idx(groups) => {
109                groups
110                    .iter()
111                    .map(|(_, idx)| {
112                        let mut this_idx = None;
113                        for &ii in idx.iter() {
114                            // SAFETY: null_values has no null values
115                            if validity.get_bit_unchecked(ii as usize) {
116                                this_idx = Some(ii);
117                                break;
118                            }
119                        }
120                        this_idx
121                    })
122                    .collect_ca(PlSmallStr::EMPTY)
123            },
124            GroupsType::Slice { groups, .. } => {
125                let mask = BitMask::from_bitmap(&validity);
126                groups
127                    .iter()
128                    .map(|&[first, len]| {
129                        // SAFETY: group slice is valid.
130                        let validity = mask.sliced_unchecked(first as usize, len as usize);
131                        let leading_zeros = validity.leading_zeros() as IdxSize;
132                        if leading_zeros == len {
133                            // All values are null, we have no first non-null.
134                            None
135                        } else {
136                            Some(first + leading_zeros)
137                        }
138                    })
139                    .collect_ca(PlSmallStr::EMPTY)
140            },
141        };
142        // SAFETY: groups are always in bounds.
143        let mut out = s.take_unchecked(&indices);
144        if matches!(groups, GroupsType::Slice { .. }) && !groups.is_overlapping() {
145            out.set_sorted_flag(s.is_sorted_flag())
146        }
147        s.restore_logical(out)
148    }
149
150    #[doc(hidden)]
151    pub unsafe fn agg_arg_first(&self, groups: &GroupsType) -> Series {
152        let out: IdxCa = match groups {
153            GroupsType::Idx(groups) => groups
154                .iter()
155                .map(|(_, idx)| {
156                    if idx.is_empty() {
157                        None
158                    } else {
159                        Some(0 as IdxSize)
160                    }
161                })
162                .collect_ca(PlSmallStr::EMPTY),
163
164            GroupsType::Slice { groups, .. } => groups
165                .iter()
166                .map(|&[_first, len]| if len == 0 { None } else { Some(0 as IdxSize) })
167                .collect_ca(PlSmallStr::EMPTY),
168        };
169        out.into_series()
170    }
171
172    #[doc(hidden)]
173    pub unsafe fn agg_arg_first_non_null(&self, groups: &GroupsType) -> Series {
174        if !self.has_nulls() {
175            return self.agg_arg_first(groups);
176        }
177
178        let validity = self.rechunk_validity().unwrap();
179
180        let out: IdxCa = match groups {
181            GroupsType::Idx(groups) => groups
182                .iter()
183                .map(|(_, idx)| {
184                    let mut pos: Option<IdxSize> = None;
185                    for (p, &ii) in idx.iter().enumerate() {
186                        if validity.get_bit_unchecked(ii as usize) {
187                            pos = Some(p as IdxSize);
188                            break;
189                        }
190                    }
191                    pos
192                })
193                .collect_ca(PlSmallStr::EMPTY),
194
195            GroupsType::Slice { groups, .. } => {
196                let mask = BitMask::from_bitmap(&validity);
197                groups
198                    .iter()
199                    .map(|&[first, len]| {
200                        if len == 0 {
201                            return None;
202                        }
203                        let v = mask.sliced_unchecked(first as usize, len as usize);
204                        let lz = v.leading_zeros() as IdxSize;
205                        if lz == len { None } else { Some(lz) }
206                    })
207                    .collect_ca(PlSmallStr::EMPTY)
208            },
209        };
210
211        out.into_series()
212    }
213
214    #[doc(hidden)]
215    pub unsafe fn agg_arg_last(&self, groups: &GroupsType) -> Series {
216        let out: IdxCa = match groups {
217            GroupsType::Idx(groups) => groups
218                .all()
219                .iter()
220                .map(|idx| {
221                    if idx.is_empty() {
222                        None
223                    } else {
224                        Some((idx.len() - 1) as IdxSize)
225                    }
226                })
227                .collect_ca(PlSmallStr::EMPTY),
228
229            GroupsType::Slice { groups, .. } => groups
230                .iter()
231                .map(|&[_first, len]| {
232                    if len == 0 {
233                        None
234                    } else {
235                        Some((len - 1) as IdxSize)
236                    }
237                })
238                .collect_ca(PlSmallStr::EMPTY),
239        };
240
241        out.into_series()
242    }
243
244    #[doc(hidden)]
245    pub unsafe fn agg_arg_last_non_null(&self, groups: &GroupsType) -> Series {
246        if !self.has_nulls() {
247            return self.agg_arg_last(groups);
248        }
249
250        let validity = self.rechunk_validity().unwrap();
251
252        let out: IdxCa = match groups {
253            GroupsType::Idx(groups) => groups
254                .iter()
255                .map(|(_, idx)| {
256                    for (p, &ii) in idx.iter().enumerate().rev() {
257                        if validity.get_bit_unchecked(ii as usize) {
258                            return Some(p as IdxSize);
259                        }
260                    }
261                    None
262                })
263                .collect_ca(PlSmallStr::EMPTY),
264
265            GroupsType::Slice { groups, .. } => {
266                let mask = BitMask::from_bitmap(&validity);
267                groups
268                    .iter()
269                    .map(|&[first, len]| {
270                        if len == 0 {
271                            return None;
272                        }
273                        let v = mask.sliced_unchecked(first as usize, len as usize);
274                        let tz = v.trailing_zeros() as IdxSize;
275                        if tz == len { None } else { Some(len - tz - 1) }
276                    })
277                    .collect_ca(PlSmallStr::EMPTY)
278            },
279        };
280
281        out.into_series()
282    }
283
284    #[doc(hidden)]
285    pub unsafe fn agg_n_unique(&self, groups: &GroupsType) -> Series {
286        let values = self.to_physical_repr();
287        let dtype = values.dtype();
288        let values = if dtype.contains_objects() {
289            panic!("{}", polars_err!(opq = unique, dtype));
290        } else if let Some(ca) = values.try_str() {
291            ca.as_binary().into_column()
292        } else if dtype.is_nested() {
293            encode_rows_unordered(&[values.into_owned().into_column()])
294                .unwrap()
295                .into_column()
296        } else {
297            values.into_owned().into_column()
298        };
299
300        // Keep the Column for the sort-fallback path. Big groups go through
301        // `Series::n_unique`, bypassing the amortized hashset.
302        let col = values.clone();
303        let values = values.rechunk_to_arrow(CompatLevel::newest());
304        let values = values.as_ref();
305        let state = amortized_unique_from_dtype(values.dtype());
306
307        struct CloneWrapper(Box<dyn AmortizedUnique>);
308        impl Clone for CloneWrapper {
309            fn clone(&self) -> Self {
310                Self(self.0.new_empty())
311            }
312        }
313
314        // SAFETY for the `.unwrap()` on `Series::n_unique()` below: we've
315        // already filtered out dtypes that can fail (objects panic above;
316        // nested types are row-encoded to binary). All remaining dtypes
317        // have a `ChunkUnique` impl that infallibly returns `Ok(_)`.
318        RAYON
319            .install(|| match groups {
320                GroupsType::Idx(idx) => idx
321                    .all()
322                    .into_par_iter()
323                    .map_with(CloneWrapper(state), |state, idxs| unsafe {
324                        let idxs = idxs.as_slice();
325                        if idxs.len() > N_UNIQUE_SORT_FALLBACK_THRESHOLD {
326                            col.take_slice_unchecked(idxs).n_unique().unwrap() as IdxSize
327                        } else {
328                            state.0.n_unique_idx(values, idxs)
329                        }
330                    })
331                    .collect::<NoNull<IdxCa>>(),
332                GroupsType::Slice {
333                    groups,
334                    overlapping: _,
335                    monotonic: _,
336                } => groups
337                    .into_par_iter()
338                    .map_with(CloneWrapper(state), |state, &[start, len]| {
339                        let len_us = len as usize;
340                        if len_us > N_UNIQUE_SORT_FALLBACK_THRESHOLD {
341                            col.slice(start as i64, len_us).n_unique().unwrap() as IdxSize
342                        } else {
343                            state.0.n_unique_slice(values, start, len)
344                        }
345                    })
346                    .collect::<NoNull<IdxCa>>(),
347            })
348            .into_inner()
349            .into_series()
350    }
351
352    #[doc(hidden)]
353    pub unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
354        // Prevent a rechunk for every individual group.
355        let s = if groups.len() > 1 {
356            self.rechunk()
357        } else {
358            self.clone()
359        };
360
361        use DataType::*;
362        match s.dtype() {
363            Boolean => s.cast(&Float64).unwrap().agg_mean(groups),
364            Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_mean(groups),
365            Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_mean(groups),
366            dt if dt.is_primitive_numeric() => apply_method_physical_integer!(s, agg_mean, groups),
367            #[cfg(feature = "dtype-decimal")]
368            Decimal(_, _) => self.cast(&Float64).unwrap().agg_mean(groups),
369            #[cfg(feature = "dtype-datetime")]
370            dt @ Datetime(_, _) => self
371                .to_physical_repr()
372                .agg_mean(groups)
373                .cast(&Int64)
374                .unwrap()
375                .cast(dt)
376                .unwrap(),
377            #[cfg(feature = "dtype-duration")]
378            dt @ Duration(_) => self
379                .to_physical_repr()
380                .agg_mean(groups)
381                .cast(&Int64)
382                .unwrap()
383                .cast(dt)
384                .unwrap(),
385            #[cfg(feature = "dtype-time")]
386            Time => self
387                .to_physical_repr()
388                .agg_mean(groups)
389                .cast(&Int64)
390                .unwrap()
391                .cast(&Time)
392                .unwrap(),
393            #[cfg(feature = "dtype-date")]
394            Date => (self
395                .to_physical_repr()
396                .agg_mean(groups)
397                .cast(&Float64)
398                .unwrap()
399                * (US_IN_DAY as f64))
400                .cast(&Datetime(TimeUnit::Microseconds, None))
401                .unwrap(),
402            _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
403        }
404    }
405
406    #[doc(hidden)]
407    pub unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
408        // Prevent a rechunk for every individual group.
409        let s = if groups.len() > 1 {
410            self.rechunk()
411        } else {
412            self.clone()
413        };
414
415        use DataType::*;
416        match s.dtype() {
417            Boolean => s.cast(&Float64).unwrap().agg_median(groups),
418            Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_median(groups),
419            Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_median(groups),
420            dt if dt.is_primitive_numeric() => {
421                apply_method_physical_integer!(s, agg_median, groups)
422            },
423            #[cfg(feature = "dtype-decimal")]
424            Decimal(_, _) => self.cast(&Float64).unwrap().agg_median(groups),
425            #[cfg(feature = "dtype-datetime")]
426            dt @ Datetime(_, _) => self
427                .to_physical_repr()
428                .agg_median(groups)
429                .cast(&Int64)
430                .unwrap()
431                .cast(dt)
432                .unwrap(),
433            #[cfg(feature = "dtype-duration")]
434            dt @ Duration(_) => self
435                .to_physical_repr()
436                .agg_median(groups)
437                .cast(&Int64)
438                .unwrap()
439                .cast(dt)
440                .unwrap(),
441            #[cfg(feature = "dtype-time")]
442            Time => self
443                .to_physical_repr()
444                .agg_median(groups)
445                .cast(&Int64)
446                .unwrap()
447                .cast(&Time)
448                .unwrap(),
449            #[cfg(feature = "dtype-date")]
450            Date => (self
451                .to_physical_repr()
452                .agg_median(groups)
453                .cast(&Float64)
454                .unwrap()
455                * (US_IN_DAY as f64))
456                .cast(&Datetime(TimeUnit::Microseconds, None))
457                .unwrap(),
458            _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
459        }
460    }
461
462    #[doc(hidden)]
463    pub unsafe fn agg_quantile(
464        &self,
465        groups: &GroupsType,
466        quantile: f64,
467        method: QuantileMethod,
468    ) -> Series {
469        // Prevent a rechunk for every individual group.
470        let s = if groups.len() > 1 {
471            self.rechunk()
472        } else {
473            self.clone()
474        };
475
476        use DataType::*;
477        match s.dtype() {
478            Float32 => s.f32().unwrap().agg_quantile(groups, quantile, method),
479            Float64 => s.f64().unwrap().agg_quantile(groups, quantile, method),
480            #[cfg(feature = "dtype-decimal")]
481            Decimal(_, _) => s
482                .cast(&DataType::Float64)
483                .unwrap()
484                .agg_quantile(groups, quantile, method),
485            #[cfg(feature = "dtype-datetime")]
486            Datetime(tu, tz) => self
487                .to_physical_repr()
488                .agg_quantile(groups, quantile, method)
489                .cast(&Int64)
490                .unwrap()
491                .into_datetime(*tu, tz.clone()),
492            #[cfg(feature = "dtype-duration")]
493            Duration(tu) => self
494                .to_physical_repr()
495                .agg_quantile(groups, quantile, method)
496                .cast(&Int64)
497                .unwrap()
498                .into_duration(*tu),
499            #[cfg(feature = "dtype-time")]
500            Time => self
501                .to_physical_repr()
502                .agg_quantile(groups, quantile, method)
503                .cast(&Int64)
504                .unwrap()
505                .into_time(),
506            #[cfg(feature = "dtype-date")]
507            Date => (self
508                .to_physical_repr()
509                .agg_quantile(groups, quantile, method)
510                .cast(&Float64)
511                .unwrap()
512                * (US_IN_DAY as f64))
513                .cast(&DataType::Int64)
514                .unwrap()
515                .into_datetime(TimeUnit::Microseconds, None),
516            dt if dt.is_primitive_numeric() => {
517                apply_method_physical_integer!(s, agg_quantile, groups, quantile, method)
518            },
519            _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
520        }
521    }
522
523    #[doc(hidden)]
524    pub unsafe fn agg_last(&self, groups: &GroupsType) -> Series {
525        // Prevent a rechunk for every individual group.
526        let s = if groups.len() > 1 {
527            self.rechunk()
528        } else {
529            self.clone()
530        };
531
532        let out = match groups {
533            GroupsType::Idx(groups) => {
534                let indices = groups
535                    .all()
536                    .iter()
537                    .map(|idx| {
538                        if idx.is_empty() {
539                            None
540                        } else {
541                            Some(idx[idx.len() - 1])
542                        }
543                    })
544                    .collect_ca(PlSmallStr::EMPTY);
545                s.take_unchecked(&indices)
546            },
547            GroupsType::Slice { groups, .. } => {
548                let indices = groups
549                    .iter()
550                    .map(|&[first, len]| {
551                        if len == 0 {
552                            None
553                        } else {
554                            Some(first + len - 1)
555                        }
556                    })
557                    .collect_ca(PlSmallStr::EMPTY);
558                s.take_unchecked(&indices)
559            },
560        };
561        s.restore_logical(out)
562    }
563
564    #[doc(hidden)]
565    pub unsafe fn agg_last_non_null(&self, groups: &GroupsType) -> Series {
566        if !self.has_nulls() {
567            return self.agg_last(groups);
568        }
569
570        // Prevent a rechunk for every individual group.
571        let s = if groups.len() > 1 {
572            self.rechunk()
573        } else {
574            self.clone()
575        };
576
577        let validity = s.rechunk_validity().unwrap();
578        let indices = match groups {
579            GroupsType::Idx(groups) => {
580                groups
581                    .iter()
582                    .map(|(_, idx)| {
583                        // We may or may not find a valid value.
584                        let mut opt_idx = None;
585                        for &ii in idx.iter().rev() {
586                            // SAFETY: index is always in range.
587                            if validity.get_bit_unchecked(ii as usize) {
588                                opt_idx = Some(ii);
589                                break;
590                            }
591                        }
592                        opt_idx
593                    })
594                    .collect_ca(PlSmallStr::EMPTY)
595            },
596            GroupsType::Slice { groups, .. } => {
597                let mask = BitMask::from_bitmap(&validity);
598                groups
599                    .iter()
600                    .map(|&[first, len]| {
601                        // SAFETY: group slice is valid.
602                        let validity = mask.sliced_unchecked(first as usize, len as usize);
603                        let trailing_zeros = validity.trailing_zeros() as IdxSize;
604                        if trailing_zeros == len {
605                            // All values are null, we have no last non-null.
606                            None
607                        } else {
608                            Some(first + len - trailing_zeros - 1)
609                        }
610                    })
611                    .collect_ca(PlSmallStr::EMPTY)
612            },
613        };
614        // SAFETY: groups are always in bounds.
615        let mut out = s.take_unchecked(&indices);
616        if groups.is_monotonic() {
617            out.set_sorted_flag(s.is_sorted_flag())
618        }
619        s.restore_logical(out)
620    }
621}