// polars_core/frame/group_by/aggregations/dispatch.rs

1use super::*;
2
3// implemented on the series because we don't need types
4impl Series {
5    fn slice_from_offsets(&self, first: IdxSize, len: IdxSize) -> Self {
6        self.slice(first as i64, len as usize)
7    }
8
9    fn restore_logical(&self, out: Series) -> Series {
10        if self.dtype().is_logical() {
11            out.cast(self.dtype()).unwrap()
12        } else {
13            out
14        }
15    }
16
17    #[doc(hidden)]
18    pub unsafe fn agg_valid_count(&self, groups: &GroupsType) -> Series {
19        // Prevent a rechunk for every individual group.
20        let s = if groups.len() > 1 && self.null_count() > 0 {
21            self.rechunk()
22        } else {
23            self.clone()
24        };
25
26        match groups {
27            GroupsType::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
28                debug_assert!(idx.len() <= s.len());
29                if idx.is_empty() {
30                    None
31                } else if s.null_count() == 0 {
32                    Some(idx.len() as IdxSize)
33                } else {
34                    let take = unsafe { s.take_slice_unchecked(idx) };
35                    Some((take.len() - take.null_count()) as IdxSize)
36                }
37            }),
38            GroupsType::Slice { groups, .. } => {
39                _agg_helper_slice::<IdxType, _>(groups, |[first, len]| {
40                    debug_assert!(len <= s.len() as IdxSize);
41                    if len == 0 {
42                        None
43                    } else if s.null_count() == 0 {
44                        Some(len)
45                    } else {
46                        let take = s.slice_from_offsets(first, len);
47                        Some((take.len() - take.null_count()) as IdxSize)
48                    }
49                })
50            },
51        }
52    }
53
54    #[doc(hidden)]
55    pub unsafe fn agg_first(&self, groups: &GroupsType) -> Series {
56        // Prevent a rechunk for every individual group.
57        let s = if groups.len() > 1 {
58            self.rechunk()
59        } else {
60            self.clone()
61        };
62
63        let mut out = match groups {
64            GroupsType::Idx(groups) => {
65                let indices = groups
66                    .iter()
67                    .map(
68                        |(first, idx)| {
69                            if idx.is_empty() { None } else { Some(first) }
70                        },
71                    )
72                    .collect_ca(PlSmallStr::EMPTY);
73                // SAFETY: groups are always in bounds.
74                s.take_unchecked(&indices)
75            },
76            GroupsType::Slice { groups, .. } => {
77                let indices = groups
78                    .iter()
79                    .map(|&[first, len]| if len == 0 { None } else { Some(first) })
80                    .collect_ca(PlSmallStr::EMPTY);
81                // SAFETY: groups are always in bounds.
82                s.take_unchecked(&indices)
83            },
84        };
85        if groups.is_sorted_flag() {
86            out.set_sorted_flag(s.is_sorted_flag())
87        }
88        s.restore_logical(out)
89    }
90
91    #[doc(hidden)]
92    pub unsafe fn agg_n_unique(&self, groups: &GroupsType) -> Series {
93        // Prevent a rechunk for every individual group.
94        let s = if groups.len() > 1 {
95            self.rechunk()
96        } else {
97            self.clone()
98        };
99
100        match groups {
101            GroupsType::Idx(groups) => agg_helper_idx_on_all_no_null::<IdxType, _>(groups, |idx| {
102                debug_assert!(idx.len() <= s.len());
103                if idx.is_empty() {
104                    0
105                } else {
106                    let take = s.take_slice_unchecked(idx);
107                    take.n_unique().unwrap() as IdxSize
108                }
109            }),
110            GroupsType::Slice { groups, .. } => {
111                _agg_helper_slice_no_null::<IdxType, _>(groups, |[first, len]| {
112                    debug_assert!(len <= s.len() as IdxSize);
113                    if len == 0 {
114                        0
115                    } else {
116                        let take = s.slice_from_offsets(first, len);
117                        take.n_unique().unwrap() as IdxSize
118                    }
119                })
120            },
121        }
122    }
123
124    #[doc(hidden)]
125    pub unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
126        // Prevent a rechunk for every individual group.
127        let s = if groups.len() > 1 {
128            self.rechunk()
129        } else {
130            self.clone()
131        };
132
133        use DataType::*;
134        match s.dtype() {
135            Boolean => s.cast(&Float64).unwrap().agg_mean(groups),
136            Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_mean(groups),
137            Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_mean(groups),
138            dt if dt.is_primitive_numeric() => apply_method_physical_integer!(s, agg_mean, groups),
139            #[cfg(feature = "dtype-datetime")]
140            dt @ Datetime(_, _) => self
141                .to_physical_repr()
142                .agg_mean(groups)
143                .cast(&Int64)
144                .unwrap()
145                .cast(dt)
146                .unwrap(),
147            #[cfg(feature = "dtype-duration")]
148            dt @ Duration(_) => self
149                .to_physical_repr()
150                .agg_mean(groups)
151                .cast(&Int64)
152                .unwrap()
153                .cast(dt)
154                .unwrap(),
155            #[cfg(feature = "dtype-time")]
156            Time => self
157                .to_physical_repr()
158                .agg_mean(groups)
159                .cast(&Int64)
160                .unwrap()
161                .cast(&Time)
162                .unwrap(),
163            #[cfg(feature = "dtype-date")]
164            Date => (self
165                .to_physical_repr()
166                .agg_mean(groups)
167                .cast(&Float64)
168                .unwrap()
169                * (MS_IN_DAY as f64))
170                .cast(&Datetime(TimeUnit::Milliseconds, None))
171                .unwrap(),
172            _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
173        }
174    }
175
176    #[doc(hidden)]
177    pub unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
178        // Prevent a rechunk for every individual group.
179        let s = if groups.len() > 1 {
180            self.rechunk()
181        } else {
182            self.clone()
183        };
184
185        use DataType::*;
186        match s.dtype() {
187            Boolean => s.cast(&Float64).unwrap().agg_median(groups),
188            Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_median(groups),
189            Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_median(groups),
190            dt if dt.is_primitive_numeric() => {
191                apply_method_physical_integer!(s, agg_median, groups)
192            },
193            #[cfg(feature = "dtype-datetime")]
194            dt @ Datetime(_, _) => self
195                .to_physical_repr()
196                .agg_median(groups)
197                .cast(&Int64)
198                .unwrap()
199                .cast(dt)
200                .unwrap(),
201            #[cfg(feature = "dtype-duration")]
202            dt @ Duration(_) => self
203                .to_physical_repr()
204                .agg_median(groups)
205                .cast(&Int64)
206                .unwrap()
207                .cast(dt)
208                .unwrap(),
209            #[cfg(feature = "dtype-time")]
210            Time => self
211                .to_physical_repr()
212                .agg_median(groups)
213                .cast(&Int64)
214                .unwrap()
215                .cast(&Time)
216                .unwrap(),
217            #[cfg(feature = "dtype-date")]
218            Date => (self
219                .to_physical_repr()
220                .agg_median(groups)
221                .cast(&Float64)
222                .unwrap()
223                * (MS_IN_DAY as f64))
224                .cast(&Datetime(TimeUnit::Milliseconds, None))
225                .unwrap(),
226            _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
227        }
228    }
229
230    #[doc(hidden)]
231    pub unsafe fn agg_quantile(
232        &self,
233        groups: &GroupsType,
234        quantile: f64,
235        method: QuantileMethod,
236    ) -> Series {
237        // Prevent a rechunk for every individual group.
238        let s = if groups.len() > 1 {
239            self.rechunk()
240        } else {
241            self.clone()
242        };
243
244        use DataType::*;
245        match s.dtype() {
246            Float32 => s.f32().unwrap().agg_quantile(groups, quantile, method),
247            Float64 => s.f64().unwrap().agg_quantile(groups, quantile, method),
248            dt if dt.is_primitive_numeric() || dt.is_temporal() => {
249                let ca = s.to_physical_repr();
250                let physical_type = ca.dtype();
251                let s = apply_method_physical_integer!(ca, agg_quantile, groups, quantile, method);
252                if dt.is_logical() {
253                    // back to physical and then
254                    // back to logical type
255                    s.cast(physical_type).unwrap().cast(dt).unwrap()
256                } else {
257                    s
258                }
259            },
260            _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
261        }
262    }
263
264    #[doc(hidden)]
265    pub unsafe fn agg_last(&self, groups: &GroupsType) -> Series {
266        // Prevent a rechunk for every individual group.
267        let s = if groups.len() > 1 {
268            self.rechunk()
269        } else {
270            self.clone()
271        };
272
273        let out = match groups {
274            GroupsType::Idx(groups) => {
275                let indices = groups
276                    .all()
277                    .iter()
278                    .map(|idx| {
279                        if idx.is_empty() {
280                            None
281                        } else {
282                            Some(idx[idx.len() - 1])
283                        }
284                    })
285                    .collect_ca(PlSmallStr::EMPTY);
286                s.take_unchecked(&indices)
287            },
288            GroupsType::Slice { groups, .. } => {
289                let indices = groups
290                    .iter()
291                    .map(|&[first, len]| {
292                        if len == 0 {
293                            None
294                        } else {
295                            Some(first + len - 1)
296                        }
297                    })
298                    .collect_ca(PlSmallStr::EMPTY);
299                s.take_unchecked(&indices)
300            },
301        };
302        s.restore_logical(out)
303    }
304}