polars_core/frame/group_by/aggregations/
dispatch.rs1use arrow::bitmap::bitmask::BitMask;
2use polars_compute::unique::{AmortizedUnique, amortized_unique_from_dtype};
3
4use super::*;
5use crate::prelude::row_encode::encode_rows_unordered;
6
7impl Series {
9 fn slice_from_offsets(&self, first: IdxSize, len: IdxSize) -> Self {
10 self.slice(first as i64, len as usize)
11 }
12
13 unsafe fn restore_logical(&self, out: Series) -> Series {
14 if self.dtype().is_logical() && !out.dtype().is_logical() {
15 out.from_physical_unchecked(self.dtype()).unwrap()
16 } else {
17 out
18 }
19 }
20
21 #[doc(hidden)]
22 pub unsafe fn agg_valid_count(&self, groups: &GroupsType) -> Series {
23 let s = if groups.len() > 1 && self.null_count() > 0 {
25 self.rechunk()
26 } else {
27 self.clone()
28 };
29
30 match groups {
31 GroupsType::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idx| {
32 debug_assert!(idx.len() <= s.len());
33 if idx.is_empty() {
34 None
35 } else if s.null_count() == 0 {
36 Some(idx.len() as IdxSize)
37 } else {
38 let take = unsafe { s.take_slice_unchecked(idx) };
39 Some((take.len() - take.null_count()) as IdxSize)
40 }
41 }),
42 GroupsType::Slice { groups, .. } => {
43 _agg_helper_slice::<IdxType, _>(groups, |[first, len]| {
44 debug_assert!(len <= s.len() as IdxSize);
45 if len == 0 {
46 None
47 } else if s.null_count() == 0 {
48 Some(len)
49 } else {
50 let take = s.slice_from_offsets(first, len);
51 Some((take.len() - take.null_count()) as IdxSize)
52 }
53 })
54 },
55 }
56 }
57
58 #[doc(hidden)]
59 pub unsafe fn agg_first(&self, groups: &GroupsType) -> Series {
60 let s = if groups.len() > 1 {
62 self.rechunk()
63 } else {
64 self.clone()
65 };
66
67 let mut out = match groups {
68 GroupsType::Idx(groups) => {
69 let indices = groups
70 .iter()
71 .map(
72 |(first, idx)| {
73 if idx.is_empty() { None } else { Some(first) }
74 },
75 )
76 .collect_ca(PlSmallStr::EMPTY);
77 s.take_unchecked(&indices)
79 },
80 GroupsType::Slice { groups, .. } => {
81 let indices = groups
82 .iter()
83 .map(|&[first, len]| if len == 0 { None } else { Some(first) })
84 .collect_ca(PlSmallStr::EMPTY);
85 s.take_unchecked(&indices)
87 },
88 };
89 if groups.is_sorted_flag() {
90 out.set_sorted_flag(s.is_sorted_flag())
91 }
92 s.restore_logical(out)
93 }
94
95 #[doc(hidden)]
96 pub unsafe fn agg_first_non_null(&self, groups: &GroupsType) -> Series {
97 if !self.has_nulls() {
98 return self.agg_first(groups);
99 }
100
101 let s = if groups.len() > 1 {
103 self.rechunk()
104 } else {
105 self.clone()
106 };
107
108 let validity = s.rechunk_validity().unwrap();
109 let indices = match groups {
110 GroupsType::Idx(groups) => {
111 groups
112 .iter()
113 .map(|(_, idx)| {
114 let mut this_idx = None;
115 for &ii in idx.iter() {
116 if validity.get_bit_unchecked(ii as usize) {
118 this_idx = Some(ii);
119 break;
120 }
121 }
122 this_idx
123 })
124 .collect_ca(PlSmallStr::EMPTY)
125 },
126 GroupsType::Slice { groups, .. } => {
127 let mask = BitMask::from_bitmap(&validity);
128 groups
129 .iter()
130 .map(|&[first, len]| {
131 let validity = mask.sliced_unchecked(first as usize, len as usize);
133 let leading_zeros = validity.leading_zeros() as IdxSize;
134 if leading_zeros == len {
135 None
137 } else {
138 Some(first + leading_zeros)
139 }
140 })
141 .collect_ca(PlSmallStr::EMPTY)
142 },
143 };
144 let mut out = s.take_unchecked(&indices);
146 if groups.is_sorted_flag() {
147 out.set_sorted_flag(s.is_sorted_flag())
148 }
149 s.restore_logical(out)
150 }
151
152 #[doc(hidden)]
153 pub unsafe fn agg_n_unique(&self, groups: &GroupsType) -> Series {
154 let values = self.to_physical_repr();
155 let dtype = values.dtype();
156 let values = if dtype.contains_objects() {
157 panic!("{}", polars_err!(opq = unique, dtype));
158 } else if let Some(ca) = values.try_str() {
159 ca.as_binary().into_column()
160 } else if dtype.is_nested() {
161 encode_rows_unordered(&[values.into_owned().into_column()])
162 .unwrap()
163 .into_column()
164 } else {
165 values.into_owned().into_column()
166 };
167
168 let values = values.rechunk_to_arrow(CompatLevel::newest());
169 let values = values.as_ref();
170 let state = amortized_unique_from_dtype(values.dtype());
171
172 struct CloneWrapper(Box<dyn AmortizedUnique>);
173 impl Clone for CloneWrapper {
174 fn clone(&self) -> Self {
175 Self(self.0.new_empty())
176 }
177 }
178
179 POOL.install(|| match groups {
180 GroupsType::Idx(idx) => idx
181 .all()
182 .into_par_iter()
183 .map_with(CloneWrapper(state), |state, idxs| unsafe {
184 state.0.n_unique_idx(values, idxs.as_slice())
185 })
186 .collect::<NoNull<IdxCa>>(),
187 GroupsType::Slice {
188 groups,
189 overlapping: _,
190 monotonic: _,
191 } => groups
192 .into_par_iter()
193 .map_with(CloneWrapper(state), |state, [start, len]| {
194 state.0.n_unique_slice(values, *start, *len)
195 })
196 .collect::<NoNull<IdxCa>>(),
197 })
198 .into_inner()
199 .into_series()
200 }
201
202 #[doc(hidden)]
203 pub unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
204 let s = if groups.len() > 1 {
206 self.rechunk()
207 } else {
208 self.clone()
209 };
210
211 use DataType::*;
212 match s.dtype() {
213 Boolean => s.cast(&Float64).unwrap().agg_mean(groups),
214 Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_mean(groups),
215 Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_mean(groups),
216 dt if dt.is_primitive_numeric() => apply_method_physical_integer!(s, agg_mean, groups),
217 #[cfg(feature = "dtype-decimal")]
218 Decimal(_, _) => self.cast(&Float64).unwrap().agg_mean(groups),
219 #[cfg(feature = "dtype-datetime")]
220 dt @ Datetime(_, _) => self
221 .to_physical_repr()
222 .agg_mean(groups)
223 .cast(&Int64)
224 .unwrap()
225 .cast(dt)
226 .unwrap(),
227 #[cfg(feature = "dtype-duration")]
228 dt @ Duration(_) => self
229 .to_physical_repr()
230 .agg_mean(groups)
231 .cast(&Int64)
232 .unwrap()
233 .cast(dt)
234 .unwrap(),
235 #[cfg(feature = "dtype-time")]
236 Time => self
237 .to_physical_repr()
238 .agg_mean(groups)
239 .cast(&Int64)
240 .unwrap()
241 .cast(&Time)
242 .unwrap(),
243 #[cfg(feature = "dtype-date")]
244 Date => (self
245 .to_physical_repr()
246 .agg_mean(groups)
247 .cast(&Float64)
248 .unwrap()
249 * (US_IN_DAY as f64))
250 .cast(&Datetime(TimeUnit::Microseconds, None))
251 .unwrap(),
252 _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
253 }
254 }
255
256 #[doc(hidden)]
257 pub unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
258 let s = if groups.len() > 1 {
260 self.rechunk()
261 } else {
262 self.clone()
263 };
264
265 use DataType::*;
266 match s.dtype() {
267 Boolean => s.cast(&Float64).unwrap().agg_median(groups),
268 Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_median(groups),
269 Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_median(groups),
270 dt if dt.is_primitive_numeric() => {
271 apply_method_physical_integer!(s, agg_median, groups)
272 },
273 #[cfg(feature = "dtype-decimal")]
274 Decimal(_, _) => self.cast(&Float64).unwrap().agg_median(groups),
275 #[cfg(feature = "dtype-datetime")]
276 dt @ Datetime(_, _) => self
277 .to_physical_repr()
278 .agg_median(groups)
279 .cast(&Int64)
280 .unwrap()
281 .cast(dt)
282 .unwrap(),
283 #[cfg(feature = "dtype-duration")]
284 dt @ Duration(_) => self
285 .to_physical_repr()
286 .agg_median(groups)
287 .cast(&Int64)
288 .unwrap()
289 .cast(dt)
290 .unwrap(),
291 #[cfg(feature = "dtype-time")]
292 Time => self
293 .to_physical_repr()
294 .agg_median(groups)
295 .cast(&Int64)
296 .unwrap()
297 .cast(&Time)
298 .unwrap(),
299 #[cfg(feature = "dtype-date")]
300 Date => (self
301 .to_physical_repr()
302 .agg_median(groups)
303 .cast(&Float64)
304 .unwrap()
305 * (US_IN_DAY as f64))
306 .cast(&Datetime(TimeUnit::Microseconds, None))
307 .unwrap(),
308 _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
309 }
310 }
311
312 #[doc(hidden)]
313 pub unsafe fn agg_quantile(
314 &self,
315 groups: &GroupsType,
316 quantile: f64,
317 method: QuantileMethod,
318 ) -> Series {
319 let s = if groups.len() > 1 {
321 self.rechunk()
322 } else {
323 self.clone()
324 };
325
326 use DataType::*;
327 match s.dtype() {
328 Float32 => s.f32().unwrap().agg_quantile(groups, quantile, method),
329 Float64 => s.f64().unwrap().agg_quantile(groups, quantile, method),
330 #[cfg(feature = "dtype-decimal")]
331 Decimal(_, _) => s
332 .cast(&DataType::Float64)
333 .unwrap()
334 .agg_quantile(groups, quantile, method),
335 #[cfg(feature = "dtype-datetime")]
336 Datetime(tu, tz) => self
337 .to_physical_repr()
338 .agg_quantile(groups, quantile, method)
339 .cast(&Int64)
340 .unwrap()
341 .into_datetime(*tu, tz.clone()),
342 #[cfg(feature = "dtype-duration")]
343 Duration(tu) => self
344 .to_physical_repr()
345 .agg_quantile(groups, quantile, method)
346 .cast(&Int64)
347 .unwrap()
348 .into_duration(*tu),
349 #[cfg(feature = "dtype-time")]
350 Time => self
351 .to_physical_repr()
352 .agg_quantile(groups, quantile, method)
353 .cast(&Int64)
354 .unwrap()
355 .into_time(),
356 #[cfg(feature = "dtype-date")]
357 Date => (self
358 .to_physical_repr()
359 .agg_quantile(groups, quantile, method)
360 .cast(&Float64)
361 .unwrap()
362 * (US_IN_DAY as f64))
363 .cast(&DataType::Int64)
364 .unwrap()
365 .into_datetime(TimeUnit::Microseconds, None),
366 dt if dt.is_primitive_numeric() => {
367 apply_method_physical_integer!(s, agg_quantile, groups, quantile, method)
368 },
369 _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
370 }
371 }
372
373 #[doc(hidden)]
374 pub unsafe fn agg_last(&self, groups: &GroupsType) -> Series {
375 let s = if groups.len() > 1 {
377 self.rechunk()
378 } else {
379 self.clone()
380 };
381
382 let out = match groups {
383 GroupsType::Idx(groups) => {
384 let indices = groups
385 .all()
386 .iter()
387 .map(|idx| {
388 if idx.is_empty() {
389 None
390 } else {
391 Some(idx[idx.len() - 1])
392 }
393 })
394 .collect_ca(PlSmallStr::EMPTY);
395 s.take_unchecked(&indices)
396 },
397 GroupsType::Slice { groups, .. } => {
398 let indices = groups
399 .iter()
400 .map(|&[first, len]| {
401 if len == 0 {
402 None
403 } else {
404 Some(first + len - 1)
405 }
406 })
407 .collect_ca(PlSmallStr::EMPTY);
408 s.take_unchecked(&indices)
409 },
410 };
411 s.restore_logical(out)
412 }
413
414 #[doc(hidden)]
415 pub unsafe fn agg_last_non_null(&self, groups: &GroupsType) -> Series {
416 if !self.has_nulls() {
417 return self.agg_last(groups);
418 }
419
420 let s = if groups.len() > 1 {
422 self.rechunk()
423 } else {
424 self.clone()
425 };
426
427 let validity = s.rechunk_validity().unwrap();
428 let indices = match groups {
429 GroupsType::Idx(groups) => {
430 groups
431 .iter()
432 .map(|(_, idx)| {
433 let mut opt_idx = None;
435 for &ii in idx.iter().rev() {
436 if validity.get_bit_unchecked(ii as usize) {
438 opt_idx = Some(ii);
439 break;
440 }
441 }
442 opt_idx
443 })
444 .collect_ca(PlSmallStr::EMPTY)
445 },
446 GroupsType::Slice { groups, .. } => {
447 let mask = BitMask::from_bitmap(&validity);
448 groups
449 .iter()
450 .map(|&[first, len]| {
451 let validity = mask.sliced_unchecked(first as usize, len as usize);
453 let trailing_zeros = validity.trailing_zeros() as IdxSize;
454 if trailing_zeros == len {
455 None
457 } else {
458 Some(first + len - trailing_zeros - 1)
459 }
460 })
461 .collect_ca(PlSmallStr::EMPTY)
462 },
463 };
464 let mut out = s.take_unchecked(&indices);
466 if groups.is_sorted_flag() {
467 out.set_sorted_flag(s.is_sorted_flag())
468 }
469 s.restore_logical(out)
470 }
471}