polars_core/frame/group_by/aggregations/
dispatch.rs1use arrow::bitmap::bitmask::BitMask;
2use polars_compute::unique::{AmortizedUnique, amortized_unique_from_dtype};
3
4use super::*;
5use crate::prelude::row_encode::encode_rows_unordered;
6
/// Group-size cut-off used by `agg_n_unique`.
///
/// Groups with more elements than this are counted by calling `n_unique()` on the
/// gathered/sliced column instead of going through the amortized unique state.
/// NOTE(review): the name suggests the fallback is sort-based — confirm.
const N_UNIQUE_SORT_FALLBACK_THRESHOLD: usize = 16_384;
13
14impl Series {
16 unsafe fn restore_logical(&self, out: Series) -> Series {
17 if self.dtype().is_logical() && !out.dtype().is_logical() {
18 out.from_physical_unchecked(self.dtype()).unwrap()
19 } else {
20 out
21 }
22 }
23
24 #[doc(hidden)]
25 pub unsafe fn agg_valid_count(&self, groups: &GroupsType) -> Series {
26 let valid = self.rechunk_validity();
28
29 match groups {
30 GroupsType::Idx(groups) => agg_helper_idx_on_all::<IdxType, _>(groups, |idxs| {
31 debug_assert!(idxs.len() <= self.len());
32 if let Some(v) = &valid {
33 let mut count = 0;
34 for idx in idxs.iter() {
35 count += unsafe { v.get_bit_unchecked(*idx as usize) as IdxSize };
36 }
37 Some(count)
38 } else {
39 Some(self.len() as IdxSize)
40 }
41 }),
42 GroupsType::Slice { groups, .. } => {
43 _agg_helper_slice::<IdxType, _>(groups, |[first, len]| {
44 debug_assert!(len <= self.len() as IdxSize);
45 if let Some(v) = &valid {
46 let m = BitMask::from_bitmap(v).sliced(first as usize, len as usize);
47 Some(m.set_bits() as IdxSize)
48 } else {
49 Some(self.len() as IdxSize)
50 }
51 })
52 },
53 }
54 }
55
56 #[doc(hidden)]
57 pub unsafe fn agg_first(&self, groups: &GroupsType) -> Series {
58 let s = if groups.len() > 1 {
60 self.rechunk()
61 } else {
62 self.clone()
63 };
64
65 let mut out = match groups {
66 GroupsType::Idx(groups) => {
67 let indices = groups
68 .iter()
69 .map(
70 |(first, idx)| {
71 if idx.is_empty() { None } else { Some(first) }
72 },
73 )
74 .collect_ca(PlSmallStr::EMPTY);
75 s.take_unchecked(&indices)
77 },
78 GroupsType::Slice { groups, .. } => {
79 let indices = groups
80 .iter()
81 .map(|&[first, len]| if len == 0 { None } else { Some(first) })
82 .collect_ca(PlSmallStr::EMPTY);
83 s.take_unchecked(&indices)
85 },
86 };
87 if groups.is_sorted_by_first_idx() {
88 out.set_sorted_flag(s.is_sorted_flag())
89 }
90 s.restore_logical(out)
91 }
92
93 #[doc(hidden)]
94 pub unsafe fn agg_first_non_null(&self, groups: &GroupsType) -> Series {
95 if !self.has_nulls() {
96 return self.agg_first(groups);
97 }
98
99 let s = if groups.len() > 1 {
101 self.rechunk()
102 } else {
103 self.clone()
104 };
105
106 let validity = s.rechunk_validity().unwrap();
107 let indices = match groups {
108 GroupsType::Idx(groups) => {
109 groups
110 .iter()
111 .map(|(_, idx)| {
112 let mut this_idx = None;
113 for &ii in idx.iter() {
114 if validity.get_bit_unchecked(ii as usize) {
116 this_idx = Some(ii);
117 break;
118 }
119 }
120 this_idx
121 })
122 .collect_ca(PlSmallStr::EMPTY)
123 },
124 GroupsType::Slice { groups, .. } => {
125 let mask = BitMask::from_bitmap(&validity);
126 groups
127 .iter()
128 .map(|&[first, len]| {
129 let validity = mask.sliced_unchecked(first as usize, len as usize);
131 let leading_zeros = validity.leading_zeros() as IdxSize;
132 if leading_zeros == len {
133 None
135 } else {
136 Some(first + leading_zeros)
137 }
138 })
139 .collect_ca(PlSmallStr::EMPTY)
140 },
141 };
142 let mut out = s.take_unchecked(&indices);
144 if matches!(groups, GroupsType::Slice { .. }) && !groups.is_overlapping() {
145 out.set_sorted_flag(s.is_sorted_flag())
146 }
147 s.restore_logical(out)
148 }
149
150 #[doc(hidden)]
151 pub unsafe fn agg_arg_first(&self, groups: &GroupsType) -> Series {
152 let out: IdxCa = match groups {
153 GroupsType::Idx(groups) => groups
154 .iter()
155 .map(|(_, idx)| {
156 if idx.is_empty() {
157 None
158 } else {
159 Some(0 as IdxSize)
160 }
161 })
162 .collect_ca(PlSmallStr::EMPTY),
163
164 GroupsType::Slice { groups, .. } => groups
165 .iter()
166 .map(|&[_first, len]| if len == 0 { None } else { Some(0 as IdxSize) })
167 .collect_ca(PlSmallStr::EMPTY),
168 };
169 out.into_series()
170 }
171
172 #[doc(hidden)]
173 pub unsafe fn agg_arg_first_non_null(&self, groups: &GroupsType) -> Series {
174 if !self.has_nulls() {
175 return self.agg_arg_first(groups);
176 }
177
178 let validity = self.rechunk_validity().unwrap();
179
180 let out: IdxCa = match groups {
181 GroupsType::Idx(groups) => groups
182 .iter()
183 .map(|(_, idx)| {
184 let mut pos: Option<IdxSize> = None;
185 for (p, &ii) in idx.iter().enumerate() {
186 if validity.get_bit_unchecked(ii as usize) {
187 pos = Some(p as IdxSize);
188 break;
189 }
190 }
191 pos
192 })
193 .collect_ca(PlSmallStr::EMPTY),
194
195 GroupsType::Slice { groups, .. } => {
196 let mask = BitMask::from_bitmap(&validity);
197 groups
198 .iter()
199 .map(|&[first, len]| {
200 if len == 0 {
201 return None;
202 }
203 let v = mask.sliced_unchecked(first as usize, len as usize);
204 let lz = v.leading_zeros() as IdxSize;
205 if lz == len { None } else { Some(lz) }
206 })
207 .collect_ca(PlSmallStr::EMPTY)
208 },
209 };
210
211 out.into_series()
212 }
213
214 #[doc(hidden)]
215 pub unsafe fn agg_arg_last(&self, groups: &GroupsType) -> Series {
216 let out: IdxCa = match groups {
217 GroupsType::Idx(groups) => groups
218 .all()
219 .iter()
220 .map(|idx| {
221 if idx.is_empty() {
222 None
223 } else {
224 Some((idx.len() - 1) as IdxSize)
225 }
226 })
227 .collect_ca(PlSmallStr::EMPTY),
228
229 GroupsType::Slice { groups, .. } => groups
230 .iter()
231 .map(|&[_first, len]| {
232 if len == 0 {
233 None
234 } else {
235 Some((len - 1) as IdxSize)
236 }
237 })
238 .collect_ca(PlSmallStr::EMPTY),
239 };
240
241 out.into_series()
242 }
243
244 #[doc(hidden)]
245 pub unsafe fn agg_arg_last_non_null(&self, groups: &GroupsType) -> Series {
246 if !self.has_nulls() {
247 return self.agg_arg_last(groups);
248 }
249
250 let validity = self.rechunk_validity().unwrap();
251
252 let out: IdxCa = match groups {
253 GroupsType::Idx(groups) => groups
254 .iter()
255 .map(|(_, idx)| {
256 for (p, &ii) in idx.iter().enumerate().rev() {
257 if validity.get_bit_unchecked(ii as usize) {
258 return Some(p as IdxSize);
259 }
260 }
261 None
262 })
263 .collect_ca(PlSmallStr::EMPTY),
264
265 GroupsType::Slice { groups, .. } => {
266 let mask = BitMask::from_bitmap(&validity);
267 groups
268 .iter()
269 .map(|&[first, len]| {
270 if len == 0 {
271 return None;
272 }
273 let v = mask.sliced_unchecked(first as usize, len as usize);
274 let tz = v.trailing_zeros() as IdxSize;
275 if tz == len { None } else { Some(len - tz - 1) }
276 })
277 .collect_ca(PlSmallStr::EMPTY)
278 },
279 };
280
281 out.into_series()
282 }
283
284 #[doc(hidden)]
285 pub unsafe fn agg_n_unique(&self, groups: &GroupsType) -> Series {
286 let values = self.to_physical_repr();
287 let dtype = values.dtype();
288 let values = if dtype.contains_objects() {
289 panic!("{}", polars_err!(opq = unique, dtype));
290 } else if let Some(ca) = values.try_str() {
291 ca.as_binary().into_column()
292 } else if dtype.is_nested() {
293 encode_rows_unordered(&[values.into_owned().into_column()])
294 .unwrap()
295 .into_column()
296 } else {
297 values.into_owned().into_column()
298 };
299
300 let col = values.clone();
303 let values = values.rechunk_to_arrow(CompatLevel::newest());
304 let values = values.as_ref();
305 let state = amortized_unique_from_dtype(values.dtype());
306
307 struct CloneWrapper(Box<dyn AmortizedUnique>);
308 impl Clone for CloneWrapper {
309 fn clone(&self) -> Self {
310 Self(self.0.new_empty())
311 }
312 }
313
314 RAYON
319 .install(|| match groups {
320 GroupsType::Idx(idx) => idx
321 .all()
322 .into_par_iter()
323 .map_with(CloneWrapper(state), |state, idxs| unsafe {
324 let idxs = idxs.as_slice();
325 if idxs.len() > N_UNIQUE_SORT_FALLBACK_THRESHOLD {
326 col.take_slice_unchecked(idxs).n_unique().unwrap() as IdxSize
327 } else {
328 state.0.n_unique_idx(values, idxs)
329 }
330 })
331 .collect::<NoNull<IdxCa>>(),
332 GroupsType::Slice {
333 groups,
334 overlapping: _,
335 monotonic: _,
336 } => groups
337 .into_par_iter()
338 .map_with(CloneWrapper(state), |state, &[start, len]| {
339 let len_us = len as usize;
340 if len_us > N_UNIQUE_SORT_FALLBACK_THRESHOLD {
341 col.slice(start as i64, len_us).n_unique().unwrap() as IdxSize
342 } else {
343 state.0.n_unique_slice(values, start, len)
344 }
345 })
346 .collect::<NoNull<IdxCa>>(),
347 })
348 .into_inner()
349 .into_series()
350 }
351
352 #[doc(hidden)]
353 pub unsafe fn agg_mean(&self, groups: &GroupsType) -> Series {
354 let s = if groups.len() > 1 {
356 self.rechunk()
357 } else {
358 self.clone()
359 };
360
361 use DataType::*;
362 match s.dtype() {
363 Boolean => s.cast(&Float64).unwrap().agg_mean(groups),
364 Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_mean(groups),
365 Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_mean(groups),
366 dt if dt.is_primitive_numeric() => apply_method_physical_integer!(s, agg_mean, groups),
367 #[cfg(feature = "dtype-decimal")]
368 Decimal(_, _) => self.cast(&Float64).unwrap().agg_mean(groups),
369 #[cfg(feature = "dtype-datetime")]
370 dt @ Datetime(_, _) => self
371 .to_physical_repr()
372 .agg_mean(groups)
373 .cast(&Int64)
374 .unwrap()
375 .cast(dt)
376 .unwrap(),
377 #[cfg(feature = "dtype-duration")]
378 dt @ Duration(_) => self
379 .to_physical_repr()
380 .agg_mean(groups)
381 .cast(&Int64)
382 .unwrap()
383 .cast(dt)
384 .unwrap(),
385 #[cfg(feature = "dtype-time")]
386 Time => self
387 .to_physical_repr()
388 .agg_mean(groups)
389 .cast(&Int64)
390 .unwrap()
391 .cast(&Time)
392 .unwrap(),
393 #[cfg(feature = "dtype-date")]
394 Date => (self
395 .to_physical_repr()
396 .agg_mean(groups)
397 .cast(&Float64)
398 .unwrap()
399 * (US_IN_DAY as f64))
400 .cast(&Datetime(TimeUnit::Microseconds, None))
401 .unwrap(),
402 _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
403 }
404 }
405
406 #[doc(hidden)]
407 pub unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
408 let s = if groups.len() > 1 {
410 self.rechunk()
411 } else {
412 self.clone()
413 };
414
415 use DataType::*;
416 match s.dtype() {
417 Boolean => s.cast(&Float64).unwrap().agg_median(groups),
418 Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_median(groups),
419 Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_median(groups),
420 dt if dt.is_primitive_numeric() => {
421 apply_method_physical_integer!(s, agg_median, groups)
422 },
423 #[cfg(feature = "dtype-decimal")]
424 Decimal(_, _) => self.cast(&Float64).unwrap().agg_median(groups),
425 #[cfg(feature = "dtype-datetime")]
426 dt @ Datetime(_, _) => self
427 .to_physical_repr()
428 .agg_median(groups)
429 .cast(&Int64)
430 .unwrap()
431 .cast(dt)
432 .unwrap(),
433 #[cfg(feature = "dtype-duration")]
434 dt @ Duration(_) => self
435 .to_physical_repr()
436 .agg_median(groups)
437 .cast(&Int64)
438 .unwrap()
439 .cast(dt)
440 .unwrap(),
441 #[cfg(feature = "dtype-time")]
442 Time => self
443 .to_physical_repr()
444 .agg_median(groups)
445 .cast(&Int64)
446 .unwrap()
447 .cast(&Time)
448 .unwrap(),
449 #[cfg(feature = "dtype-date")]
450 Date => (self
451 .to_physical_repr()
452 .agg_median(groups)
453 .cast(&Float64)
454 .unwrap()
455 * (US_IN_DAY as f64))
456 .cast(&Datetime(TimeUnit::Microseconds, None))
457 .unwrap(),
458 _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
459 }
460 }
461
462 #[doc(hidden)]
463 pub unsafe fn agg_quantile(
464 &self,
465 groups: &GroupsType,
466 quantile: f64,
467 method: QuantileMethod,
468 ) -> Series {
469 let s = if groups.len() > 1 {
471 self.rechunk()
472 } else {
473 self.clone()
474 };
475
476 use DataType::*;
477 match s.dtype() {
478 Float32 => s.f32().unwrap().agg_quantile(groups, quantile, method),
479 Float64 => s.f64().unwrap().agg_quantile(groups, quantile, method),
480 #[cfg(feature = "dtype-decimal")]
481 Decimal(_, _) => s
482 .cast(&DataType::Float64)
483 .unwrap()
484 .agg_quantile(groups, quantile, method),
485 #[cfg(feature = "dtype-datetime")]
486 Datetime(tu, tz) => self
487 .to_physical_repr()
488 .agg_quantile(groups, quantile, method)
489 .cast(&Int64)
490 .unwrap()
491 .into_datetime(*tu, tz.clone()),
492 #[cfg(feature = "dtype-duration")]
493 Duration(tu) => self
494 .to_physical_repr()
495 .agg_quantile(groups, quantile, method)
496 .cast(&Int64)
497 .unwrap()
498 .into_duration(*tu),
499 #[cfg(feature = "dtype-time")]
500 Time => self
501 .to_physical_repr()
502 .agg_quantile(groups, quantile, method)
503 .cast(&Int64)
504 .unwrap()
505 .into_time(),
506 #[cfg(feature = "dtype-date")]
507 Date => (self
508 .to_physical_repr()
509 .agg_quantile(groups, quantile, method)
510 .cast(&Float64)
511 .unwrap()
512 * (US_IN_DAY as f64))
513 .cast(&DataType::Int64)
514 .unwrap()
515 .into_datetime(TimeUnit::Microseconds, None),
516 dt if dt.is_primitive_numeric() => {
517 apply_method_physical_integer!(s, agg_quantile, groups, quantile, method)
518 },
519 _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
520 }
521 }
522
523 #[doc(hidden)]
524 pub unsafe fn agg_last(&self, groups: &GroupsType) -> Series {
525 let s = if groups.len() > 1 {
527 self.rechunk()
528 } else {
529 self.clone()
530 };
531
532 let out = match groups {
533 GroupsType::Idx(groups) => {
534 let indices = groups
535 .all()
536 .iter()
537 .map(|idx| {
538 if idx.is_empty() {
539 None
540 } else {
541 Some(idx[idx.len() - 1])
542 }
543 })
544 .collect_ca(PlSmallStr::EMPTY);
545 s.take_unchecked(&indices)
546 },
547 GroupsType::Slice { groups, .. } => {
548 let indices = groups
549 .iter()
550 .map(|&[first, len]| {
551 if len == 0 {
552 None
553 } else {
554 Some(first + len - 1)
555 }
556 })
557 .collect_ca(PlSmallStr::EMPTY);
558 s.take_unchecked(&indices)
559 },
560 };
561 s.restore_logical(out)
562 }
563
564 #[doc(hidden)]
565 pub unsafe fn agg_last_non_null(&self, groups: &GroupsType) -> Series {
566 if !self.has_nulls() {
567 return self.agg_last(groups);
568 }
569
570 let s = if groups.len() > 1 {
572 self.rechunk()
573 } else {
574 self.clone()
575 };
576
577 let validity = s.rechunk_validity().unwrap();
578 let indices = match groups {
579 GroupsType::Idx(groups) => {
580 groups
581 .iter()
582 .map(|(_, idx)| {
583 let mut opt_idx = None;
585 for &ii in idx.iter().rev() {
586 if validity.get_bit_unchecked(ii as usize) {
588 opt_idx = Some(ii);
589 break;
590 }
591 }
592 opt_idx
593 })
594 .collect_ca(PlSmallStr::EMPTY)
595 },
596 GroupsType::Slice { groups, .. } => {
597 let mask = BitMask::from_bitmap(&validity);
598 groups
599 .iter()
600 .map(|&[first, len]| {
601 let validity = mask.sliced_unchecked(first as usize, len as usize);
603 let trailing_zeros = validity.trailing_zeros() as IdxSize;
604 if trailing_zeros == len {
605 None
607 } else {
608 Some(first + len - trailing_zeros - 1)
609 }
610 })
611 .collect_ca(PlSmallStr::EMPTY)
612 },
613 };
614 let mut out = s.take_unchecked(&indices);
616 if groups.is_monotonic() {
617 out.set_sorted_flag(s.is_sorted_flag())
618 }
619 s.restore_logical(out)
620 }
621}