1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3
4use num_traits::NumCast;
5use polars_compute::rolling::QuantileMethod;
6use polars_utils::format_pl_smallstr;
7use polars_utils::hashing::DirtyHash;
8use rayon::prelude::*;
9
10use self::hashing::*;
11use crate::POOL;
12use crate::prelude::*;
13use crate::utils::{_set_partition_size, accumulate_dataframes_vertical};
14
15pub mod aggregations;
16pub mod expr;
17pub(crate) mod hashing;
18mod into_groups;
19mod position;
20
21pub use into_groups::*;
22pub use position::*;
23
24use crate::chunked_array::ops::row_encode::{
25 encode_rows_unordered, encode_rows_vertical_par_unordered,
26};
27
28impl DataFrame {
29 pub fn group_by_with_series(
30 &self,
31 mut by: Vec<Column>,
32 multithreaded: bool,
33 sorted: bool,
34 ) -> PolarsResult<GroupBy<'_>> {
35 polars_ensure!(
36 !by.is_empty(),
37 ComputeError: "at least one key is required in a group_by operation"
38 );
39
40 let common_height = if self.width() > 0 {
44 self.height()
45 } else {
46 by.iter().map(|s| s.len()).max().expect("at least 1 key")
47 };
48 for by_key in by.iter_mut() {
49 if by_key.len() != common_height {
50 polars_ensure!(
51 by_key.len() == 1,
52 ShapeMismatch: "series used as keys should have the same length as the DataFrame"
53 );
54 *by_key = by_key.new_from_index(0, common_height)
55 }
56 }
57
58 let groups = if by.len() == 1 {
59 let column = &by[0];
60 column
61 .as_materialized_series()
62 .group_tuples(multithreaded, sorted)
63 } else if by.iter().any(|s| s.dtype().is_object()) {
64 #[cfg(feature = "object")]
65 {
66 let mut df = DataFrame::new(by.clone()).unwrap();
67 let n = df.height();
68 let rows = df.to_av_rows();
69 let iter = (0..n).map(|i| rows.get(i));
70 Ok(group_by(iter, sorted))
71 }
72 #[cfg(not(feature = "object"))]
73 {
74 unreachable!()
75 }
76 } else {
77 let by = by
79 .iter()
80 .filter(|s| !s.dtype().is_null())
81 .cloned()
82 .collect::<Vec<_>>();
83 if by.is_empty() {
84 let groups = if self.is_empty() {
85 vec![]
86 } else {
87 vec![[0, self.height() as IdxSize]]
88 };
89 Ok(GroupsType::Slice {
90 groups,
91 overlapping: false,
92 })
93 } else {
94 let rows = if multithreaded {
95 encode_rows_vertical_par_unordered(&by)
96 } else {
97 encode_rows_unordered(&by)
98 }?
99 .into_series();
100 rows.group_tuples(multithreaded, sorted)
101 }
102 };
103 Ok(GroupBy::new(self, by, groups?.into_sliceable(), None))
104 }
105
106 pub fn group_by<I, S>(&self, by: I) -> PolarsResult<GroupBy<'_>>
119 where
120 I: IntoIterator<Item = S>,
121 S: Into<PlSmallStr>,
122 {
123 let selected_keys = self.select_columns(by)?;
124 self.group_by_with_series(selected_keys, true, false)
125 }
126
127 pub fn group_by_stable<I, S>(&self, by: I) -> PolarsResult<GroupBy<'_>>
130 where
131 I: IntoIterator<Item = S>,
132 S: Into<PlSmallStr>,
133 {
134 let selected_keys = self.select_columns(by)?;
135 self.group_by_with_series(selected_keys, true, true)
136 }
137}
138
/// The result of grouping a [`DataFrame`]: its key columns and the group
/// positions computed from them, from which keys and aggregations are derived.
#[derive(Debug, Clone)]
pub struct GroupBy<'a> {
    /// The frame whose rows the groups index into.
    pub df: &'a DataFrame,
    /// The key columns the frame was grouped by.
    pub(crate) selected_keys: Vec<Column>,
    /// Positions of every group in `df`.
    groups: GroupPositions,
    /// Names of the columns to aggregate; `None` aggregates every non-key column.
    pub(crate) selected_agg: Option<Vec<PlSmallStr>>,
}
197
198impl<'a> GroupBy<'a> {
199 pub fn new(
200 df: &'a DataFrame,
201 by: Vec<Column>,
202 groups: GroupPositions,
203 selected_agg: Option<Vec<PlSmallStr>>,
204 ) -> Self {
205 GroupBy {
206 df,
207 selected_keys: by,
208 groups,
209 selected_agg,
210 }
211 }
212
213 #[must_use]
219 pub fn select<I: IntoIterator<Item = S>, S: Into<PlSmallStr>>(mut self, selection: I) -> Self {
220 self.selected_agg = Some(selection.into_iter().map(|s| s.into()).collect());
221 self
222 }
223
224 pub fn get_groups(&self) -> &GroupPositions {
229 &self.groups
230 }
231
232 pub unsafe fn get_groups_mut(&mut self) -> &mut GroupPositions {
241 &mut self.groups
242 }
243
244 pub fn take_groups(self) -> GroupPositions {
245 self.groups
246 }
247
248 pub fn take_groups_mut(&mut self) -> GroupPositions {
249 std::mem::take(&mut self.groups)
250 }
251
252 pub fn keys_sliced(&self, slice: Option<(i64, usize)>) -> Vec<Column> {
253 #[allow(unused_assignments)]
254 let mut groups_owned = None;
256
257 let groups = if let Some((offset, len)) = slice {
258 groups_owned = Some(self.groups.slice(offset, len));
259 groups_owned.as_deref().unwrap()
260 } else {
261 &self.groups
262 };
263 POOL.install(|| {
264 self.selected_keys
265 .par_iter()
266 .map(Column::as_materialized_series)
267 .map(|s| {
268 match groups {
269 GroupsType::Idx(groups) => {
270 let mut out = unsafe { s.take_slice_unchecked(groups.first()) };
272 if groups.sorted {
273 out.set_sorted_flag(s.is_sorted_flag());
274 };
275 out
276 },
277 GroupsType::Slice {
278 groups,
279 overlapping,
280 } => {
281 if *overlapping && !groups.is_empty() {
282 let offset = groups[0][0];
284 let [upper_offset, upper_len] = groups[groups.len() - 1];
285 return s.slice(
286 offset as i64,
287 ((upper_offset + upper_len) - offset) as usize,
288 );
289 }
290
291 let indices = groups
292 .iter()
293 .map(|&[first, _len]| first)
294 .collect_ca(PlSmallStr::EMPTY);
295 let mut out = unsafe { s.take_unchecked(&indices) };
297 out.set_sorted_flag(s.is_sorted_flag());
299 out
300 },
301 }
302 })
303 .map(Column::from)
304 .collect()
305 })
306 }
307
308 pub fn keys(&self) -> Vec<Column> {
309 self.keys_sliced(None)
310 }
311
312 fn prepare_agg(&self) -> PolarsResult<(Vec<Column>, Vec<Column>)> {
313 let keys = self.keys();
314
315 let agg_col = match &self.selected_agg {
316 Some(selection) => self.df.select_columns_impl(selection.as_slice()),
317 None => {
318 let by: Vec<_> = self.selected_keys.iter().map(|s| s.name()).collect();
319 let selection = self
320 .df
321 .iter()
322 .map(|s| s.name())
323 .filter(|a| !by.contains(a))
324 .cloned()
325 .collect::<Vec<_>>();
326
327 self.df.select_columns_impl(selection.as_slice())
328 },
329 }?;
330
331 Ok((keys, agg_col))
332 }
333
334 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
360 pub fn mean(&self) -> PolarsResult<DataFrame> {
361 let (mut cols, agg_cols) = self.prepare_agg()?;
362
363 for agg_col in agg_cols {
364 let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Mean);
365 let mut agg = unsafe { agg_col.agg_mean(&self.groups) };
366 agg.rename(new_name);
367 cols.push(agg);
368 }
369 DataFrame::new(cols)
370 }
371
372 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
398 pub fn sum(&self) -> PolarsResult<DataFrame> {
399 let (mut cols, agg_cols) = self.prepare_agg()?;
400
401 for agg_col in agg_cols {
402 let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Sum);
403 let mut agg = unsafe { agg_col.agg_sum(&self.groups) };
404 agg.rename(new_name);
405 cols.push(agg);
406 }
407 DataFrame::new(cols)
408 }
409
410 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
436 pub fn min(&self) -> PolarsResult<DataFrame> {
437 let (mut cols, agg_cols) = self.prepare_agg()?;
438 for agg_col in agg_cols {
439 let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Min);
440 let mut agg = unsafe { agg_col.agg_min(&self.groups) };
441 agg.rename(new_name);
442 cols.push(agg);
443 }
444 DataFrame::new(cols)
445 }
446
447 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
473 pub fn max(&self) -> PolarsResult<DataFrame> {
474 let (mut cols, agg_cols) = self.prepare_agg()?;
475 for agg_col in agg_cols {
476 let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Max);
477 let mut agg = unsafe { agg_col.agg_max(&self.groups) };
478 agg.rename(new_name);
479 cols.push(agg);
480 }
481 DataFrame::new(cols)
482 }
483
484 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
510 pub fn first(&self) -> PolarsResult<DataFrame> {
511 let (mut cols, agg_cols) = self.prepare_agg()?;
512 for agg_col in agg_cols {
513 let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::First);
514 let mut agg = unsafe { agg_col.agg_first(&self.groups) };
515 agg.rename(new_name);
516 cols.push(agg);
517 }
518 DataFrame::new(cols)
519 }
520
521 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
547 pub fn last(&self) -> PolarsResult<DataFrame> {
548 let (mut cols, agg_cols) = self.prepare_agg()?;
549 for agg_col in agg_cols {
550 let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Last);
551 let mut agg = unsafe { agg_col.agg_last(&self.groups) };
552 agg.rename(new_name);
553 cols.push(agg);
554 }
555 DataFrame::new(cols)
556 }
557
558 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
584 pub fn n_unique(&self) -> PolarsResult<DataFrame> {
585 let (mut cols, agg_cols) = self.prepare_agg()?;
586 for agg_col in agg_cols {
587 let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::NUnique);
588 let mut agg = unsafe { agg_col.agg_n_unique(&self.groups) };
589 agg.rename(new_name);
590 cols.push(agg);
591 }
592 DataFrame::new(cols)
593 }
594
595 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
607 pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<DataFrame> {
608 polars_ensure!(
609 (0.0..=1.0).contains(&quantile),
610 ComputeError: "`quantile` should be within 0.0 and 1.0"
611 );
612 let (mut cols, agg_cols) = self.prepare_agg()?;
613 for agg_col in agg_cols {
614 let new_name = fmt_group_by_column(
615 agg_col.name().as_str(),
616 GroupByMethod::Quantile(quantile, method),
617 );
618 let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) };
619 agg.rename(new_name);
620 cols.push(agg);
621 }
622 DataFrame::new(cols)
623 }
624
625 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
636 pub fn median(&self) -> PolarsResult<DataFrame> {
637 let (mut cols, agg_cols) = self.prepare_agg()?;
638 for agg_col in agg_cols {
639 let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Median);
640 let mut agg = unsafe { agg_col.agg_median(&self.groups) };
641 agg.rename(new_name);
642 cols.push(agg);
643 }
644 DataFrame::new(cols)
645 }
646
647 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
649 pub fn var(&self, ddof: u8) -> PolarsResult<DataFrame> {
650 let (mut cols, agg_cols) = self.prepare_agg()?;
651 for agg_col in agg_cols {
652 let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Var(ddof));
653 let mut agg = unsafe { agg_col.agg_var(&self.groups, ddof) };
654 agg.rename(new_name);
655 cols.push(agg);
656 }
657 DataFrame::new(cols)
658 }
659
660 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
662 pub fn std(&self, ddof: u8) -> PolarsResult<DataFrame> {
663 let (mut cols, agg_cols) = self.prepare_agg()?;
664 for agg_col in agg_cols {
665 let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Std(ddof));
666 let mut agg = unsafe { agg_col.agg_std(&self.groups, ddof) };
667 agg.rename(new_name);
668 cols.push(agg);
669 }
670 DataFrame::new(cols)
671 }
672
673 pub fn count(&self) -> PolarsResult<DataFrame> {
699 let (mut cols, agg_cols) = self.prepare_agg()?;
700
701 for agg_col in agg_cols {
702 let new_name = fmt_group_by_column(
703 agg_col.name().as_str(),
704 GroupByMethod::Count {
705 include_nulls: true,
706 },
707 );
708 let mut ca = self.groups.group_count();
709 ca.rename(new_name);
710 cols.push(ca.into_column());
711 }
712 DataFrame::new(cols)
713 }
714
715 pub fn groups(&self) -> PolarsResult<DataFrame> {
741 let mut cols = self.keys();
742 let mut column = self.groups.as_list_chunked();
743 let new_name = fmt_group_by_column("", GroupByMethod::Groups);
744 column.rename(new_name);
745 cols.push(column.into_column());
746 DataFrame::new(cols)
747 }
748
749 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
776 pub fn agg_list(&self) -> PolarsResult<DataFrame> {
777 let (mut cols, agg_cols) = self.prepare_agg()?;
778 for agg_col in agg_cols {
779 let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Implode);
780 let mut agg = unsafe { agg_col.agg_list(&self.groups) };
781 agg.rename(new_name);
782 cols.push(agg);
783 }
784 DataFrame::new(cols)
785 }
786
787 fn prepare_apply(&self) -> PolarsResult<DataFrame> {
788 polars_ensure!(self.df.height() > 0, ComputeError: "cannot group_by + apply on empty 'DataFrame'");
789 if let Some(agg) = &self.selected_agg {
790 if agg.is_empty() {
791 Ok(self.df.clone())
792 } else {
793 let mut new_cols = Vec::with_capacity(self.selected_keys.len() + agg.len());
794 new_cols.extend_from_slice(&self.selected_keys);
795 let cols = self.df.select_columns_impl(agg.as_slice())?;
796 new_cols.extend(cols);
797 Ok(unsafe { DataFrame::new_no_checks(self.df.height(), new_cols) })
798 }
799 } else {
800 Ok(self.df.clone())
801 }
802 }
803
804 #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
806 pub fn par_apply<F>(&self, f: F) -> PolarsResult<DataFrame>
807 where
808 F: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
809 {
810 let df = self.prepare_apply()?;
811 let dfs = self
812 .get_groups()
813 .par_iter()
814 .map(|g| {
815 let sub_df = unsafe { take_df(&df, g) };
818 f(sub_df)
819 })
820 .collect::<PolarsResult<Vec<_>>>()?;
821
822 let mut df = accumulate_dataframes_vertical(dfs)?;
823 df.as_single_chunk_par();
824 Ok(df)
825 }
826
827 pub fn apply<F>(&self, mut f: F) -> PolarsResult<DataFrame>
829 where
830 F: FnMut(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
831 {
832 let df = self.prepare_apply()?;
833 let dfs = self
834 .get_groups()
835 .iter()
836 .map(|g| {
837 let sub_df = unsafe { take_df(&df, g) };
840 f(sub_df)
841 })
842 .collect::<PolarsResult<Vec<_>>>()?;
843
844 let mut df = accumulate_dataframes_vertical(dfs)?;
845 df.as_single_chunk_par();
846 Ok(df)
847 }
848
849 pub fn sliced(mut self, slice: Option<(i64, usize)>) -> Self {
850 match slice {
851 None => self,
852 Some((offset, length)) => {
853 self.groups = self.groups.slice(offset, length);
854 self.selected_keys = self.keys_sliced(slice);
855 self
856 },
857 }
858 }
859}
860
861unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame {
862 match g {
863 GroupsIndicator::Idx(idx) => df.take_slice_unchecked(idx.1),
864 GroupsIndicator::Slice([first, len]) => df.slice(first as i64, len as usize),
865 }
866}
867
/// An aggregation applied per group.
///
/// Its `Display` form is a short lowercase method name; `fmt_group_by_column`
/// derives the output column name from it.
#[derive(Copy, Clone, Debug)]
pub enum GroupByMethod {
    Min,
    /// NOTE(review): presumably a NaN-propagating minimum — confirm against callers.
    NanMin,
    Max,
    /// NOTE(review): presumably a NaN-propagating maximum — confirm against callers.
    NanMax,
    Median,
    Mean,
    First,
    Last,
    Sum,
    /// The group positions themselves (output column `groups`).
    Groups,
    /// Number of unique values per group.
    NUnique,
    /// Quantile (`GroupBy::quantile` requires `0.0..=1.0`) and interpolation method.
    Quantile(f64, QuantileMethod),
    /// Number of rows per group; `include_nulls` presumably controls whether nulls
    /// are counted — confirm against callers.
    Count { include_nulls: bool },
    /// Collect each group's values into a list (displayed as `list`).
    Implode,
    /// Standard deviation with the given delta degrees of freedom.
    Std(u8),
    /// Variance with the given delta degrees of freedom.
    Var(u8),
}
887
888impl Display for GroupByMethod {
889 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
890 use GroupByMethod::*;
891 let s = match self {
892 Min => "min",
893 NanMin => "nan_min",
894 Max => "max",
895 NanMax => "nan_max",
896 Median => "median",
897 Mean => "mean",
898 First => "first",
899 Last => "last",
900 Sum => "sum",
901 Groups => "groups",
902 NUnique => "n_unique",
903 Quantile(_, _) => "quantile",
904 Count { .. } => "count",
905 Implode => "list",
906 Std(_) => "std",
907 Var(_) => "var",
908 };
909 write!(f, "{s}")
910 }
911}
912
913pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr {
915 use GroupByMethod::*;
916 match method {
917 Min => format_pl_smallstr!("{name}_min"),
918 Max => format_pl_smallstr!("{name}_max"),
919 NanMin => format_pl_smallstr!("{name}_nan_min"),
920 NanMax => format_pl_smallstr!("{name}_nan_max"),
921 Median => format_pl_smallstr!("{name}_median"),
922 Mean => format_pl_smallstr!("{name}_mean"),
923 First => format_pl_smallstr!("{name}_first"),
924 Last => format_pl_smallstr!("{name}_last"),
925 Sum => format_pl_smallstr!("{name}_sum"),
926 Groups => PlSmallStr::from_static("groups"),
927 NUnique => format_pl_smallstr!("{name}_n_unique"),
928 Count { .. } => format_pl_smallstr!("{name}_count"),
929 Implode => format_pl_smallstr!("{name}_agg_list"),
930 Quantile(quantile, _interpol) => format_pl_smallstr!("{name}_quantile_{quantile:.2}"),
931 Std(_) => format_pl_smallstr!("{name}_agg_std"),
932 Var(_) => format_pl_smallstr!("{name}_agg_var"),
933 }
934}
935
#[cfg(test)]
mod test {
    use num_traits::FloatConst;

    use crate::prelude::*;

    /// Aggregations on a stable grouping by a date-like string column.
    #[test]
    #[cfg(feature = "dtype-date")]
    #[cfg_attr(miri, ignore)]
    fn test_group_by() -> PolarsResult<()> {
        let s0 = Column::new(
            PlSmallStr::from_static("date"),
            &[
                "2020-08-21",
                "2020-08-21",
                "2020-08-22",
                "2020-08-23",
                "2020-08-22",
            ],
        );
        let s1 = Column::new(PlSmallStr::from_static("temp"), [20, 10, 7, 9, 1]);
        let s2 = Column::new(PlSmallStr::from_static("rain"), [0.2, 0.1, 0.3, 0.1, 0.01]);
        let df = DataFrame::new(vec![s0, s1, s2]).unwrap();

        let out = df.group_by_stable(["date"])?.select(["temp"]).count()?;
        assert_eq!(
            out.column("temp_count")?,
            &Column::new(PlSmallStr::from_static("temp_count"), [2 as IdxSize, 2, 1])
        );

        #[allow(deprecated)]
        let out = df
            .group_by_stable(["date"])?
            .select(["temp", "rain"])
            .mean()?;
        assert_eq!(
            out.column("temp_mean")?,
            &Column::new(PlSmallStr::from_static("temp_mean"), [15.0f64, 4.0, 9.0])
        );

        #[allow(deprecated)]
        let out = df
            .group_by_stable(["date", "temp"])?
            .select(["rain"])
            .mean()?;
        assert!(out.column("rain_mean").is_ok());

        #[allow(deprecated)]
        let out = df.group_by_stable(["date"])?.select(["temp"]).sum()?;
        assert_eq!(
            out.column("temp_sum")?,
            &Column::new(PlSmallStr::from_static("temp_sum"), [30, 8, 9])
        );

        // Without a selection every non-key column is aggregated: 1 key + 2 columns.
        #[allow(deprecated)]
        let gb = df.group_by(["date"]).unwrap().n_unique().unwrap();
        assert_eq!(gb.width(), 3);
        Ok(())
    }

    /// Grouping on many key columns at once.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_static_group_by_by_12_columns() {
        let s0 = Column::new("G1".into(), ["A", "A", "B", "B", "C"].as_ref());
        let s1 = Column::new("N".into(), [1, 2, 2, 4, 2].as_ref());
        let s2 = Column::new("G2".into(), ["k", "l", "m", "m", "l"].as_ref());
        let s3 = Column::new("G3".into(), ["a", "b", "c", "c", "d"].as_ref());
        let s4 = Column::new("G4".into(), ["1", "2", "3", "3", "4"].as_ref());
        let s5 = Column::new("G5".into(), ["X", "Y", "Z", "Z", "W"].as_ref());
        let s6 = Column::new("G6".into(), [false, true, true, true, false].as_ref());
        let s7 = Column::new("G7".into(), ["r", "x", "q", "q", "o"].as_ref());
        let s8 = Column::new("G8".into(), ["R", "X", "Q", "Q", "O"].as_ref());
        let s9 = Column::new("G9".into(), [1, 2, 3, 3, 4].as_ref());
        let s10 = Column::new("G10".into(), [".", "!", "?", "?", "/"].as_ref());
        let s11 = Column::new("G11".into(), ["(", ")", "@", "@", "$"].as_ref());
        let s12 = Column::new("G12".into(), ["-", "_", ";", ";", ","].as_ref());

        let df =
            DataFrame::new(vec![s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]).unwrap();

        #[allow(deprecated)]
        let adf = df
            .group_by([
                "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12",
            ])
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(1), Some(2), Some(2), Some(6)]
        );
    }

    /// Grouping on 13 identical key columns built in a loop.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_dynamic_group_by_by_13_columns() {
        let series_content = ["A", "A", "B", "B", "C"];

        let series_names = [
            "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12", "G13",
        ];

        let mut columns = Vec::with_capacity(14);

        for series_name in series_names {
            let group_columns = Column::new(series_name.into(), series_content.as_ref());
            columns.push(group_columns);
        }

        let agg_series = Column::new("N".into(), [1, 2, 3, 3, 4].as_ref());
        columns.push(agg_series);

        let df = DataFrame::new(columns).unwrap();

        #[allow(deprecated)]
        let adf = df
            .group_by(series_names)
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        for series_name in &series_names {
            assert_eq!(
                Vec::from(&adf.column(series_name).unwrap().str().unwrap().sort(false)),
                &[Some("A"), Some("B"), Some("C")]
            );
        }

        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(3), Some(4), Some(6)]
        );
    }

    /// Grouping on a float key column.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_floats() {
        let df = df! {"flt" => [1., 1., 2., 2., 3.],
                    "val" => [1, 1, 1, 1, 1]
        }
        .unwrap();
        #[allow(deprecated)]
        let res = df.group_by(["flt"]).unwrap().sum().unwrap();
        let res = res.sort(["flt"], SortMultipleOptions::default()).unwrap();
        assert_eq!(
            Vec::from(res.column("val_sum").unwrap().i32().unwrap()),
            &[Some(2), Some(2), Some(1)]
        );
    }

    /// Grouping on a categorical key combined with a string key.
    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_categorical() {
        let mut df = df! {"foo" => ["a", "a", "b", "b", "c"],
                    "ham" => ["a", "a", "b", "b", "c"],
                    "bar" => [1, 1, 1, 1, 1]
        }
        .unwrap();

        df.apply("foo", |s| {
            s.cast(&DataType::from_categories(Categories::global()))
                .unwrap()
        })
        .unwrap();

        #[allow(deprecated)]
        let res = df
            .group_by_stable(["foo", "ham"])
            .unwrap()
            .select(["bar"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(
                res.column("bar_sum")
                    .unwrap()
                    .as_materialized_series()
                    .i32()
                    .unwrap()
            ),
            &[Some(2), Some(2), Some(1)]
        );
    }

    /// Nulls in an aggregated column are skipped by `mean`.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_null_handling() -> PolarsResult<()> {
        let df = df!(
            "a" => ["a", "a", "a", "b", "b"],
            "b" => [Some(1), Some(2), None, None, Some(1)]
        )?;
        #[allow(deprecated)]
        let out = df.group_by_stable(["a"])?.mean()?;

        assert_eq!(
            Vec::from(out.column("b_mean")?.as_materialized_series().f64()?),
            &[Some(1.5), Some(1.0)]
        );
        Ok(())
    }

    /// Sample variance / standard deviation with `ddof = 1`.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_var() -> PolarsResult<()> {
        let df = df![
            "g" => ["foo", "foo", "bar"],
            "flt" => [1.0, 2.0, 3.0],
            "int" => [1, 2, 3]
        ]?;

        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).var(1)?;

        assert_eq!(out.column("int_agg_var")?.f64()?.get(0), Some(0.5));
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).std(1)?;
        let val = out.column("int_agg_std")?.f64()?.get(0).unwrap();
        let expected = f64::FRAC_1_SQRT_2();
        assert!((val - expected).abs() < 0.000001);
        Ok(())
    }

    /// A categorical key containing nulls can be grouped and summed.
    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_null_group() -> PolarsResult<()> {
        let mut df = df![
            "g" => [Some("foo"), Some("foo"), Some("bar"), None, None],
            "flt" => [1.0, 2.0, 3.0, 1.0, 1.0],
            "int" => [1, 2, 3, 1, 1]
        ]?;

        df.try_apply("g", |s| {
            s.cast(&DataType::from_categories(Categories::global()))
        })?;

        #[allow(deprecated)]
        let _ = df.group_by(["g"])?.sum()?;
        Ok(())
    }

    /// Output column naming for the aggregation methods.
    #[test]
    fn test_fmt_group_by_column() {
        use super::{GroupByMethod, fmt_group_by_column};

        assert_eq!(fmt_group_by_column("a", GroupByMethod::Sum).as_str(), "a_sum");
        assert_eq!(
            fmt_group_by_column("a", GroupByMethod::NUnique).as_str(),
            "a_n_unique"
        );
        assert_eq!(
            fmt_group_by_column("a", GroupByMethod::Implode).as_str(),
            "a_agg_list"
        );
        assert_eq!(
            fmt_group_by_column("a", GroupByMethod::Var(1)).as_str(),
            "a_agg_var"
        );
        assert_eq!(
            fmt_group_by_column(
                "a",
                GroupByMethod::Count {
                    include_nulls: true
                }
            )
            .as_str(),
            "a_count"
        );
        assert_eq!(
            fmt_group_by_column("ignored", GroupByMethod::Groups).as_str(),
            "groups"
        );
    }

    /// Unit-length keys are broadcast; other length mismatches are rejected.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_key_broadcast_and_mismatch() -> PolarsResult<()> {
        let df = df!("a" => [1, 2, 3])?;

        let unit_key = Column::new("k".into(), [1i32]);
        let gb = df.group_by_with_series(vec![unit_key], false, true)?;
        assert_eq!(gb.get_groups().len(), 1);
        assert_eq!(gb.keys()[0].len(), 1);

        let bad_key = Column::new("k".into(), [1i32, 2]);
        assert!(df.group_by_with_series(vec![bad_key], false, true).is_err());
        Ok(())
    }
}