polars_core/frame/group_by/aggregations/
dispatch.rs

1use arrow::bitmap::bitmask::BitMask;
2use polars_compute::unique::{AmortizedUnique, amortized_unique_from_dtype};
3
4use super::*;
5use crate::prelude::row_encode::encode_rows_unordered;
6
// Implemented directly on `Series` because these dispatchers do not need the concrete type.
8impl Series {
9    fn slice_from_offsets(&self, first: IdxSize, len: IdxSize) -> Self {
10        self.slice(first as i64, len as usize)
11    }
12
13    unsafe fn restore_logical(&self, out: Series) -> Series {
14        if self.dtype().is_logical() && !out.dtype().is_logical() {
15            out.from_physical_unchecked(self.dtype()).unwrap()
16        } else {
17            out
18        }
19    }
20
21    #[doc(hidden)]
22    pub unsafe fn agg_valid_count(&self, groups: &GroupsType) -> Series {
23        // Prevent a rechunk for every individual group.
24        let s = if groups.len() > 1 && self.null_count() > 0 {
25            self.rechunk()
26        } else {
27            self.clone()
28        };
29
30        match groups {
31            GroupsType::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
32                debug_assert!(idx.len() <= s.len());
33                if idx.is_empty() {
34                    None
35                } else if s.null_count() == 0 {
36                    Some(idx.len() as IdxSize)
37                } else {
38                    let take = unsafe { s.take_slice_unchecked(idx) };
39                    Some((take.len() - take.null_count()) as IdxSize)
40                }
41            }),
42            GroupsType::Slice { groups, .. } => {
43                _agg_helper_slice::<IdxType, _>(groups, |[first, len]| {
44                    debug_assert!(len <= s.len() as IdxSize);
45                    if len == 0 {
46                        None
47                    } else if s.null_count() == 0 {
48                        Some(len)
49                    } else {
50                        let take = s.slice_from_offsets(first, len);
51                        Some((take.len() - take.null_count()) as IdxSize)
52                    }
53                })
54            },
55        }
56    }
57
58    #[doc(hidden)]
59    pub unsafe fn agg_first(&self, groups: &GroupsType) -> Series {
60        // Prevent a rechunk for every individual group.
61        let s = if groups.len() > 1 {
62            self.rechunk()
63        } else {
64            self.clone()
65        };
66
67        let mut out = match groups {
68            GroupsType::Idx(groups) => {
69                let indices = groups
70                    .iter()
71                    .map(
72                        |(first, idx)| {
73                            if idx.is_empty() { None } else { Some(first) }
74                        },
75                    )
76                    .collect_ca(PlSmallStr::EMPTY);
77                // SAFETY: groups are always in bounds.
78                s.take_unchecked(&indices)
79            },
80            GroupsType::Slice { groups, .. } => {
81                let indices = groups
82                    .iter()
83                    .map(|&[first, len]| if len == 0 { None } else { Some(first) })
84                    .collect_ca(PlSmallStr::EMPTY);
85                // SAFETY: groups are always in bounds.
86                s.take_unchecked(&indices)
87            },
88        };
89        if groups.is_sorted_flag() {
90            out.set_sorted_flag(s.is_sorted_flag())
91        }
92        s.restore_logical(out)
93    }
94
95    #[doc(hidden)]
96    pub unsafe fn agg_first_non_null(&self, groups: &GroupsType) -> Series {
97        if !self.has_nulls() {
98            return self.agg_first(groups);
99        }
100
101        // Prevent a rechunk for every individual group.
102        let s = if groups.len() > 1 {
103            self.rechunk()
104        } else {
105            self.clone()
106        };
107
108        let validity = s.rechunk_validity().unwrap();
109        let indices = match groups {
110            GroupsType::Idx(groups) => {
111                groups
112                    .iter()
113                    .map(|(_, idx)| {
114                        let mut this_idx = None;
115                        for &ii in idx.iter() {
116                            // SAFETY: null_values has no null values
117                            if validity.get_bit_unchecked(ii as usize) {
118                                this_idx = Some(ii);
119                                break;
120                            }
121                        }
122                        this_idx
123                    })
124                    .collect_ca(PlSmallStr::EMPTY)
125            },
126            GroupsType::Slice { groups, .. } => {
127                let mask = BitMask::from_bitmap(&validity);
128                groups
129                    .iter()
130                    .map(|&[first, len]| {
131                        // SAFETY: group slice is valid.
132                        let validity = mask.sliced_unchecked(first as usize, len as usize);
133                        let leading_zeros = validity.leading_zeros() as IdxSize;
134                        if leading_zeros == len {
135                            // All values are null, we have no first non-null.
136                            None
137                        } else {
138                            Some(first + leading_zeros)
139                        }
140                    })
141                    .collect_ca(PlSmallStr::EMPTY)
142            },
143        };
144        // SAFETY: groups are always in bounds.
145        let mut out = s.take_unchecked(&indices);
146        if groups.is_sorted_flag() {
147            out.set_sorted_flag(s.is_sorted_flag())
148        }
149        s.restore_logical(out)
150    }
151
152    #[doc(hidden)]
153    pub unsafe fn agg_n_unique(&self, groups: &GroupsType) -> Series {
154        let values = self.to_physical_repr();
155        let dtype = values.dtype();
156        let values = if dtype.contains_objects() {
157            panic!("{}", polars_err!(opq = unique, dtype));
158        } else if let Some(ca) = values.try_str() {
159            ca.as_binary().into_column()
160        } else if dtype.is_nested() {
161            encode_rows_unordered(&[values.into_owned().into_column()])
162                .unwrap()
163                .into_column()
164        } else {
165            values.into_owned().into_column()
166        };
167
168        let values = values.rechunk_to_arrow(CompatLevel::newest());
169        let values = values.as_ref();
170        let state = amortized_unique_from_dtype(values.dtype());
171
172        struct CloneWrapper(Box<dyn AmortizedUnique>);
173        impl Clone for CloneWrapper {
174            fn clone(&self) -> Self {
175                Self(self.0.new_empty())
176            }
177        }
178
179        POOL.install(|| match groups {
180            GroupsType::Idx(idx) => idx
181                .all()
182                .into_par_iter()
183                .map_with(CloneWrapper(state), |state, idxs| unsafe {
184                    state.0.n_unique_idx(values, idxs.as_slice())
185                })
186                .collect::<NoNull<IdxCa>>(),
187            GroupsType::Slice {
188                groups,
189                overlapping: _,
190                monotonic: _,
191            } => groups
192                .into_par_iter()
193                .map_with(CloneWrapper(state), |state, [start, len]| {
194                    state.0.n_unique_slice(values, *start, *len)
195                })
196                .collect::<NoNull<IdxCa>>(),
197        })
198        .into_inner()
199        .into_series()
200    }
201
202    #[doc(hidden)]
203    pub unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
204        // Prevent a rechunk for every individual group.
205        let s = if groups.len() > 1 {
206            self.rechunk()
207        } else {
208            self.clone()
209        };
210
211        use DataType::*;
212        match s.dtype() {
213            Boolean => s.cast(&Float64).unwrap().agg_mean(groups),
214            Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_mean(groups),
215            Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_mean(groups),
216            dt if dt.is_primitive_numeric() => apply_method_physical_integer!(s, agg_mean, groups),
217            #[cfg(feature = "dtype-decimal")]
218            Decimal(_, _) => self.cast(&Float64).unwrap().agg_mean(groups),
219            #[cfg(feature = "dtype-datetime")]
220            dt @ Datetime(_, _) => self
221                .to_physical_repr()
222                .agg_mean(groups)
223                .cast(&Int64)
224                .unwrap()
225                .cast(dt)
226                .unwrap(),
227            #[cfg(feature = "dtype-duration")]
228            dt @ Duration(_) => self
229                .to_physical_repr()
230                .agg_mean(groups)
231                .cast(&Int64)
232                .unwrap()
233                .cast(dt)
234                .unwrap(),
235            #[cfg(feature = "dtype-time")]
236            Time => self
237                .to_physical_repr()
238                .agg_mean(groups)
239                .cast(&Int64)
240                .unwrap()
241                .cast(&Time)
242                .unwrap(),
243            #[cfg(feature = "dtype-date")]
244            Date => (self
245                .to_physical_repr()
246                .agg_mean(groups)
247                .cast(&Float64)
248                .unwrap()
249                * (US_IN_DAY as f64))
250                .cast(&Datetime(TimeUnit::Microseconds, None))
251                .unwrap(),
252            _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
253        }
254    }
255
256    #[doc(hidden)]
257    pub unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
258        // Prevent a rechunk for every individual group.
259        let s = if groups.len() > 1 {
260            self.rechunk()
261        } else {
262            self.clone()
263        };
264
265        use DataType::*;
266        match s.dtype() {
267            Boolean => s.cast(&Float64).unwrap().agg_median(groups),
268            Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_median(groups),
269            Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_median(groups),
270            dt if dt.is_primitive_numeric() => {
271                apply_method_physical_integer!(s, agg_median, groups)
272            },
273            #[cfg(feature = "dtype-decimal")]
274            Decimal(_, _) => self.cast(&Float64).unwrap().agg_median(groups),
275            #[cfg(feature = "dtype-datetime")]
276            dt @ Datetime(_, _) => self
277                .to_physical_repr()
278                .agg_median(groups)
279                .cast(&Int64)
280                .unwrap()
281                .cast(dt)
282                .unwrap(),
283            #[cfg(feature = "dtype-duration")]
284            dt @ Duration(_) => self
285                .to_physical_repr()
286                .agg_median(groups)
287                .cast(&Int64)
288                .unwrap()
289                .cast(dt)
290                .unwrap(),
291            #[cfg(feature = "dtype-time")]
292            Time => self
293                .to_physical_repr()
294                .agg_median(groups)
295                .cast(&Int64)
296                .unwrap()
297                .cast(&Time)
298                .unwrap(),
299            #[cfg(feature = "dtype-date")]
300            Date => (self
301                .to_physical_repr()
302                .agg_median(groups)
303                .cast(&Float64)
304                .unwrap()
305                * (US_IN_DAY as f64))
306                .cast(&Datetime(TimeUnit::Microseconds, None))
307                .unwrap(),
308            _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
309        }
310    }
311
312    #[doc(hidden)]
313    pub unsafe fn agg_quantile(
314        &self,
315        groups: &GroupsType,
316        quantile: f64,
317        method: QuantileMethod,
318    ) -> Series {
319        // Prevent a rechunk for every individual group.
320        let s = if groups.len() > 1 {
321            self.rechunk()
322        } else {
323            self.clone()
324        };
325
326        use DataType::*;
327        match s.dtype() {
328            Float32 => s.f32().unwrap().agg_quantile(groups, quantile, method),
329            Float64 => s.f64().unwrap().agg_quantile(groups, quantile, method),
330            #[cfg(feature = "dtype-decimal")]
331            Decimal(_, _) => s
332                .cast(&DataType::Float64)
333                .unwrap()
334                .agg_quantile(groups, quantile, method),
335            #[cfg(feature = "dtype-datetime")]
336            Datetime(tu, tz) => self
337                .to_physical_repr()
338                .agg_quantile(groups, quantile, method)
339                .cast(&Int64)
340                .unwrap()
341                .into_datetime(*tu, tz.clone()),
342            #[cfg(feature = "dtype-duration")]
343            Duration(tu) => self
344                .to_physical_repr()
345                .agg_quantile(groups, quantile, method)
346                .cast(&Int64)
347                .unwrap()
348                .into_duration(*tu),
349            #[cfg(feature = "dtype-time")]
350            Time => self
351                .to_physical_repr()
352                .agg_quantile(groups, quantile, method)
353                .cast(&Int64)
354                .unwrap()
355                .into_time(),
356            #[cfg(feature = "dtype-date")]
357            Date => (self
358                .to_physical_repr()
359                .agg_quantile(groups, quantile, method)
360                .cast(&Float64)
361                .unwrap()
362                * (US_IN_DAY as f64))
363                .cast(&DataType::Int64)
364                .unwrap()
365                .into_datetime(TimeUnit::Microseconds, None),
366            dt if dt.is_primitive_numeric() => {
367                apply_method_physical_integer!(s, agg_quantile, groups, quantile, method)
368            },
369            _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
370        }
371    }
372
373    #[doc(hidden)]
374    pub unsafe fn agg_last(&self, groups: &GroupsType) -> Series {
375        // Prevent a rechunk for every individual group.
376        let s = if groups.len() > 1 {
377            self.rechunk()
378        } else {
379            self.clone()
380        };
381
382        let out = match groups {
383            GroupsType::Idx(groups) => {
384                let indices = groups
385                    .all()
386                    .iter()
387                    .map(|idx| {
388                        if idx.is_empty() {
389                            None
390                        } else {
391                            Some(idx[idx.len() - 1])
392                        }
393                    })
394                    .collect_ca(PlSmallStr::EMPTY);
395                s.take_unchecked(&indices)
396            },
397            GroupsType::Slice { groups, .. } => {
398                let indices = groups
399                    .iter()
400                    .map(|&[first, len]| {
401                        if len == 0 {
402                            None
403                        } else {
404                            Some(first + len - 1)
405                        }
406                    })
407                    .collect_ca(PlSmallStr::EMPTY);
408                s.take_unchecked(&indices)
409            },
410        };
411        s.restore_logical(out)
412    }
413
414    #[doc(hidden)]
415    pub unsafe fn agg_last_non_null(&self, groups: &GroupsType) -> Series {
416        if !self.has_nulls() {
417            return self.agg_last(groups);
418        }
419
420        // Prevent a rechunk for every individual group.
421        let s = if groups.len() > 1 {
422            self.rechunk()
423        } else {
424            self.clone()
425        };
426
427        let validity = s.rechunk_validity().unwrap();
428        let indices = match groups {
429            GroupsType::Idx(groups) => {
430                groups
431                    .iter()
432                    .map(|(_, idx)| {
433                        // We may or may not find a valid value.
434                        let mut opt_idx = None;
435                        for &ii in idx.iter().rev() {
436                            // SAFETY: index is always in range.
437                            if validity.get_bit_unchecked(ii as usize) {
438                                opt_idx = Some(ii);
439                                break;
440                            }
441                        }
442                        opt_idx
443                    })
444                    .collect_ca(PlSmallStr::EMPTY)
445            },
446            GroupsType::Slice { groups, .. } => {
447                let mask = BitMask::from_bitmap(&validity);
448                groups
449                    .iter()
450                    .map(|&[first, len]| {
451                        // SAFETY: group slice is valid.
452                        let validity = mask.sliced_unchecked(first as usize, len as usize);
453                        let trailing_zeros = validity.trailing_zeros() as IdxSize;
454                        if trailing_zeros == len {
455                            // All values are null, we have no last non-null.
456                            None
457                        } else {
458                            Some(first + len - trailing_zeros - 1)
459                        }
460                    })
461                    .collect_ca(PlSmallStr::EMPTY)
462            },
463        };
464        // SAFETY: groups are always in bounds.
465        let mut out = s.take_unchecked(&indices);
466        if groups.is_sorted_flag() {
467            out.set_sorted_flag(s.is_sorted_flag())
468        }
469        s.restore_logical(out)
470    }
471}