polars_core/frame/group_by/mod.rs

use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;

use num_traits::NumCast;
use polars_compute::rolling::QuantileMethod;
use polars_utils::format_pl_smallstr;
use polars_utils::hashing::DirtyHash;
use rayon::prelude::*;

use self::hashing::*;
use crate::POOL;
use crate::prelude::*;
use crate::utils::{_set_partition_size, accumulate_dataframes_vertical};

pub mod aggregations;
pub mod expr;
pub(crate) mod hashing;
mod into_groups;
mod position;

pub use into_groups::*;
pub use position::*;

use crate::chunked_array::ops::row_encode::{
    encode_rows_unordered, encode_rows_vertical_par_unordered,
};

impl DataFrame {
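    /// Group the `DataFrame` using the provided key [`Column`]s directly.
    ///
    /// Each key must either match the `DataFrame`'s height or have length 1, in
    /// which case it is broadcast. `multithreaded` computes the groups in
    /// parallel, and `sorted` keeps the groups ordered by their smallest row
    /// index (as [`DataFrame::group_by_stable`] does).
    ///
    /// # Example
    ///
    /// A minimal sketch (the column names are placeholders), gathering the key
    /// columns with `select_columns` first:
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: &DataFrame) -> PolarsResult<DataFrame> {
    ///     let keys = df.select_columns(["date"])?;
    ///     df.group_by_with_series(keys, true, false)?
    ///         .select(["temp"])
    ///         .count()
    /// }
    /// ```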
    pub fn group_by_with_series(
        &self,
        mut by: Vec<Column>,
        multithreaded: bool,
        sorted: bool,
    ) -> PolarsResult<GroupBy<'_>> {
        polars_ensure!(
            !by.is_empty(),
            ComputeError: "at least one key is required in a group_by operation"
        );

        // Ensure all 'by' columns have the same length (the common height).
        // The condition self.width() > 0 ensures we can still call this on a
        // dummy DataFrame where only the keys are provided.
        let common_height = if self.width() > 0 {
            self.height()
        } else {
            by.iter().map(|s| s.len()).max().expect("at least 1 key")
        };
        for by_key in by.iter_mut() {
            if by_key.len() != common_height {
                polars_ensure!(
                    by_key.len() == 1,
                    ShapeMismatch: "series used as keys should have the same length as the DataFrame"
                );
                *by_key = by_key.new_from_index(0, common_height)
            }
        }

        let groups = if by.len() == 1 {
            let column = &by[0];
            column
                .as_materialized_series()
                .group_tuples(multithreaded, sorted)
        } else if by.iter().any(|s| s.dtype().is_object()) {
            #[cfg(feature = "object")]
            {
                let mut df = DataFrame::new(by.clone()).unwrap();
                let n = df.height();
                let rows = df.to_av_rows();
                let iter = (0..n).map(|i| rows.get(i));
                Ok(group_by(iter, sorted))
            }
            #[cfg(not(feature = "object"))]
            {
                unreachable!()
            }
        } else {
            // Skip null dtype.
            let by = by
                .iter()
                .filter(|s| !s.dtype().is_null())
                .cloned()
                .collect::<Vec<_>>();
            if by.is_empty() {
                let groups = if self.is_empty() {
                    vec![]
                } else {
                    vec![[0, self.height() as IdxSize]]
                };
                Ok(GroupsType::Slice {
                    groups,
                    overlapping: false,
                })
            } else {
                let rows = if multithreaded {
                    encode_rows_vertical_par_unordered(&by)
                } else {
                    encode_rows_unordered(&by)
                }?
                .into_series();
                rows.group_tuples(multithreaded, sorted)
            }
        };
        Ok(GroupBy::new(self, by, groups?.into_sliceable(), None))
    }

    /// Group the DataFrame by one or more key columns.
    ///
    /// # Example
    ///
    /// ```
    /// use polars_core::prelude::*;
    /// fn group_by_sum(df: &DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["column_name"])?
    ///         .select(["agg_column_name"])
    ///         .sum()
    /// }
    /// ```
    pub fn group_by<I, S>(&self, by: I) -> PolarsResult<GroupBy<'_>>
    where
        I: IntoIterator<Item = S>,
        S: Into<PlSmallStr>,
    {
        let selected_keys = self.select_columns(by)?;
        self.group_by_with_series(selected_keys, true, false)
    }

    /// Group the DataFrame by one or more key columns, keeping a stable order:
    /// the groups are ordered by their smallest row index.
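    ///
    /// # Example
    ///
    /// A minimal sketch (the column names are placeholders):
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: &DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by_stable(["column_name"])?
    ///         .select(["agg_column_name"])
    ///         .count()
    /// }
    /// ```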
    pub fn group_by_stable<I, S>(&self, by: I) -> PolarsResult<GroupBy<'_>>
    where
        I: IntoIterator<Item = S>,
        S: Into<PlSmallStr>,
    {
        let selected_keys = self.select_columns(by)?;
        self.group_by_with_series(selected_keys, true, true)
    }
}

/// Returned by a group_by operation on a DataFrame. This struct supports
/// several aggregations.
///
/// Unless stated otherwise, the examples in this struct operate on the following DataFrame:
///
/// ```ignore
/// use polars_core::prelude::*;
///
/// let dates = &[
///     "2020-08-21",
///     "2020-08-21",
///     "2020-08-22",
///     "2020-08-23",
///     "2020-08-22",
/// ];
/// // date format
/// let fmt = "%Y-%m-%d";
/// // create date series
/// let s0 = DateChunked::parse_from_str_slice("date", dates, fmt)
///         .into_series();
/// // create temperature series
/// let s1 = Series::new("temp".into(), [20, 10, 7, 9, 1]);
/// // create rain series
/// let s2 = Series::new("rain".into(), [0.2, 0.1, 0.3, 0.1, 0.01]);
/// // create a new DataFrame
/// let df = DataFrame::new(vec![s0, s1, s2]).unwrap();
/// println!("{:?}", df);
/// ```
///
/// Outputs:
///
/// ```text
/// +------------+------+------+
/// | date       | temp | rain |
/// | ---        | ---  | ---  |
/// | Date       | i32  | f64  |
/// +============+======+======+
/// | 2020-08-21 | 20   | 0.2  |
/// +------------+------+------+
/// | 2020-08-21 | 10   | 0.1  |
/// +------------+------+------+
/// | 2020-08-22 | 7    | 0.3  |
/// +------------+------+------+
/// | 2020-08-23 | 9    | 0.1  |
/// +------------+------+------+
/// | 2020-08-22 | 1    | 0.01 |
/// +------------+------+------+
/// ```
///
#[derive(Debug, Clone)]
pub struct GroupBy<'a> {
    pub df: &'a DataFrame,
    pub(crate) selected_keys: Vec<Column>,
    // [first idx, [other idx]]
    groups: GroupPositions,
    // columns selected for aggregation
    pub(crate) selected_agg: Option<Vec<PlSmallStr>>,
}

impl<'a> GroupBy<'a> {
    pub fn new(
        df: &'a DataFrame,
        by: Vec<Column>,
        groups: GroupPositions,
        selected_agg: Option<Vec<PlSmallStr>>,
    ) -> Self {
        GroupBy {
            df,
            selected_keys: by,
            groups,
            selected_agg,
        }
    }

    /// Select the column(s) that should be aggregated.
    /// You can select a single column or a slice of columns.
    ///
    /// Note that making a selection with this method is not required. If you
    /// skip it, all columns (except for the keys) will be selected for aggregation.
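    ///
    /// # Example
    ///
    /// A minimal sketch (the column names are placeholders):
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: &DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp", "rain"]).count()
    /// }
    /// ```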
    #[must_use]
    pub fn select<I: IntoIterator<Item = S>, S: Into<PlSmallStr>>(mut self, selection: I) -> Self {
        self.selected_agg = Some(selection.into_iter().map(|s| s.into()).collect());
        self
    }

    /// Get the internal representation of the GroupBy operation.
    /// The Vec returned contains:
    ///     (first_idx, [`Vec<indexes>`])
    ///     where the second value in the tuple is a vector with all matching indexes.
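    ///
    /// # Example
    ///
    /// A rough sketch of walking the groups, mirroring the `take_df` helper further
    /// down in this module (fenced as `ignore`):
    ///
    /// ```ignore
    /// fn example(gb: &GroupBy) {
    ///     for group in gb.get_groups().iter() {
    ///         match group {
    ///             // `idx.1` holds all row indices belonging to this group.
    ///             GroupsIndicator::Idx(idx) => println!("{} rows", idx.1.len()),
    ///             // A slice group is a contiguous `[first, len]` range of rows.
    ///             GroupsIndicator::Slice([first, len]) => println!("rows {first}..{}", first + len),
    ///         }
    ///     }
    /// }
    /// ```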
    pub fn get_groups(&self) -> &GroupPositions {
        &self.groups
    }

    /// Get the internal representation of the GroupBy operation.
    /// The Vec returned contains:
    ///     (first_idx, [`Vec<indexes>`])
    ///     where the second value in the tuple is a vector with all matching indexes.
    ///
    /// # Safety
    /// Groups should always be in bounds of the `DataFrame` held by this [`GroupBy`].
    /// If you mutate it, you must hold that invariant.
    pub unsafe fn get_groups_mut(&mut self) -> &mut GroupPositions {
        &mut self.groups
    }

    pub fn take_groups(self) -> GroupPositions {
        self.groups
    }

    pub fn take_groups_mut(&mut self) -> GroupPositions {
        std::mem::take(&mut self.groups)
    }

    pub fn keys_sliced(&self, slice: Option<(i64, usize)>) -> Vec<Column> {
        #[allow(unused_assignments)]
        // needed to keep the lifetimes valid for this scope
        let mut groups_owned = None;

        let groups = if let Some((offset, len)) = slice {
            groups_owned = Some(self.groups.slice(offset, len));
            groups_owned.as_deref().unwrap()
        } else {
            &self.groups
        };
        POOL.install(|| {
            self.selected_keys
                .par_iter()
                .map(Column::as_materialized_series)
                .map(|s| {
                    match groups {
                        GroupsType::Idx(groups) => {
                            // SAFETY: groups are always in bounds.
                            let mut out = unsafe { s.take_slice_unchecked(groups.first()) };
                            if groups.sorted {
                                out.set_sorted_flag(s.is_sorted_flag());
                            };
                            out
                        },
                        GroupsType::Slice {
                            groups,
                            overlapping,
                        } => {
                            if *overlapping && !groups.is_empty() {
                                // Groups can be sliced.
                                let offset = groups[0][0];
                                let [upper_offset, upper_len] = groups[groups.len() - 1];
                                return s.slice(
                                    offset as i64,
                                    ((upper_offset + upper_len) - offset) as usize,
                                );
                            }

                            let indices = groups
                                .iter()
                                .map(|&[first, _len]| first)
                                .collect_ca(PlSmallStr::EMPTY);
                            // SAFETY: groups are always in bounds.
                            let mut out = unsafe { s.take_unchecked(&indices) };
                            // Sliced groups are always in order of discovery.
                            out.set_sorted_flag(s.is_sorted_flag());
                            out
                        },
                    }
                })
                .map(Column::from)
                .collect()
        })
    }

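    /// Get one key value per group, taken from each group's first row.
    ///
    /// # Example
    ///
    /// A minimal sketch (the column name is a placeholder):
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: &DataFrame) -> PolarsResult<Vec<Column>> {
    ///     Ok(df.group_by(["date"])?.keys())
    /// }
    /// ```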
    pub fn keys(&self) -> Vec<Column> {
        self.keys_sliced(None)
    }

    fn prepare_agg(&self) -> PolarsResult<(Vec<Column>, Vec<Column>)> {
        let keys = self.keys();

        let agg_col = match &self.selected_agg {
            Some(selection) => self.df.select_columns_impl(selection.as_slice()),
            None => {
                let by: Vec<_> = self.selected_keys.iter().map(|s| s.name()).collect();
                let selection = self
                    .df
                    .iter()
                    .map(|s| s.name())
                    .filter(|a| !by.contains(a))
                    .cloned()
                    .collect::<Vec<_>>();

                self.df.select_columns_impl(selection.as_slice())
            },
        }?;

        Ok((keys, agg_col))
    }

    /// Aggregate grouped series and compute the mean per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp", "rain"]).mean()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+-----------+-----------+
    /// | date       | temp_mean | rain_mean |
    /// | ---        | ---       | ---       |
    /// | Date       | f64       | f64       |
    /// +============+===========+===========+
    /// | 2020-08-23 | 9         | 0.1       |
    /// +------------+-----------+-----------+
    /// | 2020-08-22 | 4         | 0.155     |
    /// +------------+-----------+-----------+
    /// | 2020-08-21 | 15        | 0.15      |
    /// +------------+-----------+-----------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn mean(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;

        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Mean);
            let mut agg = unsafe { agg_col.agg_mean(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped series and compute the sum per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).sum()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+----------+
    /// | date       | temp_sum |
    /// | ---        | ---      |
    /// | Date       | i32      |
    /// +============+==========+
    /// | 2020-08-23 | 9        |
    /// +------------+----------+
    /// | 2020-08-22 | 8        |
    /// +------------+----------+
    /// | 2020-08-21 | 30       |
    /// +------------+----------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn sum(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;

        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Sum);
            let mut agg = unsafe { agg_col.agg_sum(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped series and compute the minimal value per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).min()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+----------+
    /// | date       | temp_min |
    /// | ---        | ---      |
    /// | Date       | i32      |
    /// +============+==========+
    /// | 2020-08-23 | 9        |
    /// +------------+----------+
    /// | 2020-08-22 | 1        |
    /// +------------+----------+
    /// | 2020-08-21 | 10       |
    /// +------------+----------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn min(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Min);
            let mut agg = unsafe { agg_col.agg_min(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped series and compute the maximum value per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).max()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+----------+
    /// | date       | temp_max |
    /// | ---        | ---      |
    /// | Date       | i32      |
    /// +============+==========+
    /// | 2020-08-23 | 9        |
    /// +------------+----------+
    /// | 2020-08-22 | 7        |
    /// +------------+----------+
    /// | 2020-08-21 | 20       |
    /// +------------+----------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn max(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Max);
            let mut agg = unsafe { agg_col.agg_max(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped `Series` and find the first value per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).first()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+------------+
    /// | date       | temp_first |
    /// | ---        | ---        |
    /// | Date       | i32        |
    /// +============+============+
    /// | 2020-08-23 | 9          |
    /// +------------+------------+
    /// | 2020-08-22 | 7          |
    /// +------------+------------+
    /// | 2020-08-21 | 20         |
    /// +------------+------------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn first(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::First);
            let mut agg = unsafe { agg_col.agg_first(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped `Series` and return the last value per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).last()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+------------+
    /// | date       | temp_last  |
    /// | ---        | ---        |
    /// | Date       | i32        |
    /// +============+============+
    /// | 2020-08-23 | 9          |
    /// +------------+------------+
    /// | 2020-08-22 | 1          |
    /// +------------+------------+
    /// | 2020-08-21 | 10         |
    /// +------------+------------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn last(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Last);
            let mut agg = unsafe { agg_col.agg_last(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped `Series` by counting the number of unique values.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).n_unique()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+---------------+
    /// | date       | temp_n_unique |
    /// | ---        | ---           |
    /// | Date       | u32           |
    /// +============+===============+
    /// | 2020-08-23 | 1             |
    /// +------------+---------------+
    /// | 2020-08-22 | 2             |
    /// +------------+---------------+
    /// | 2020-08-21 | 2             |
    /// +------------+---------------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn n_unique(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::NUnique);
            let mut agg = unsafe { agg_col.agg_n_unique(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped [`Series`] and determine the quantile per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    ///
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileMethod::default())
    /// }
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<DataFrame> {
        polars_ensure!(
            (0.0..=1.0).contains(&quantile),
            ComputeError: "`quantile` should be within 0.0 and 1.0"
        );
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(
                agg_col.name().as_str(),
                GroupByMethod::Quantile(quantile, method),
            );
            let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped [`Series`] and determine the median per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).median()
    /// }
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn median(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Median);
            let mut agg = unsafe { agg_col.agg_median(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped [`Series`] and determine the variance per group.
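    ///
    /// # Example
    ///
    /// A minimal sketch, following the pattern of the aggregations above:
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).var(1)
    /// }
    /// ```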
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn var(&self, ddof: u8) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Var(ddof));
            let mut agg = unsafe { agg_col.agg_var(&self.groups, ddof) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped [`Series`] and determine the standard deviation per group.
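    ///
    /// # Example
    ///
    /// A minimal sketch, following the pattern of the aggregations above:
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).std(1)
    /// }
    /// ```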
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn std(&self, ddof: u8) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Std(ddof));
            let mut agg = unsafe { agg_col.agg_std(&self.groups, ddof) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped series and compute the number of values per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).count()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+------------+
    /// | date       | temp_count |
    /// | ---        | ---        |
    /// | Date       | u32        |
    /// +============+============+
    /// | 2020-08-23 | 1          |
    /// +------------+------------+
    /// | 2020-08-22 | 2          |
    /// +------------+------------+
    /// | 2020-08-21 | 2          |
    /// +------------+------------+
    /// ```
    pub fn count(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;

        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(
                agg_col.name().as_str(),
                GroupByMethod::Count {
                    include_nulls: true,
                },
            );
            let mut ca = self.groups.group_count();
            ca.rename(new_name);
            cols.push(ca.into_column());
        }
        DataFrame::new(cols)
    }

    /// Get the group_by group indexes.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.groups()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +--------------+------------+
    /// | date         | groups     |
    /// | ---          | ---        |
    /// | Date(days)   | list [u32] |
    /// +==============+============+
    /// | 2020-08-23   | "[3]"      |
    /// +--------------+------------+
    /// | 2020-08-22   | "[2, 4]"   |
    /// +--------------+------------+
    /// | 2020-08-21   | "[0, 1]"   |
    /// +--------------+------------+
    /// ```
    pub fn groups(&self) -> PolarsResult<DataFrame> {
        let mut cols = self.keys();
        let mut column = self.groups.as_list_chunked();
        let new_name = fmt_group_by_column("", GroupByMethod::Groups);
        column.rename(new_name);
        cols.push(column.into_column());
        DataFrame::new(cols)
    }

    /// Aggregate the groups of the group_by operation into lists.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     // GroupBy and aggregate to Lists
    ///     df.group_by(["date"])?.select(["temp"]).agg_list()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+------------------------+
    /// | date       | temp_agg_list          |
    /// | ---        | ---                    |
    /// | Date       | list [i32]             |
    /// +============+========================+
    /// | 2020-08-23 | "[Some(9)]"            |
    /// +------------+------------------------+
    /// | 2020-08-22 | "[Some(7), Some(1)]"   |
    /// +------------+------------------------+
    /// | 2020-08-21 | "[Some(20), Some(10)]" |
    /// +------------+------------------------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn agg_list(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Implode);
            let mut agg = unsafe { agg_col.agg_list(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    fn prepare_apply(&self) -> PolarsResult<DataFrame> {
        polars_ensure!(self.df.height() > 0, ComputeError: "cannot group_by + apply on empty 'DataFrame'");
        if let Some(agg) = &self.selected_agg {
            if agg.is_empty() {
                Ok(self.df.clone())
            } else {
                let mut new_cols = Vec::with_capacity(self.selected_keys.len() + agg.len());
                new_cols.extend_from_slice(&self.selected_keys);
                let cols = self.df.select_columns_impl(agg.as_slice())?;
                new_cols.extend(cols);
                Ok(unsafe { DataFrame::new_no_checks(self.df.height(), new_cols) })
            }
        } else {
            Ok(self.df.clone())
        }
    }

    /// Apply a closure over the groups as a new [`DataFrame`] in parallel.
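    ///
    /// # Example
    ///
    /// A minimal sketch; the closure here simply passes each group through unchanged:
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).par_apply(|group| Ok(group))
    /// }
    /// ```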
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn par_apply<F>(&self, f: F) -> PolarsResult<DataFrame>
    where
        F: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
    {
        let df = self.prepare_apply()?;
        let dfs = self
            .get_groups()
            .par_iter()
            .map(|g| {
                // SAFETY:
                // groups are in bounds
                let sub_df = unsafe { take_df(&df, g) };
                f(sub_df)
            })
            .collect::<PolarsResult<Vec<_>>>()?;

        let mut df = accumulate_dataframes_vertical(dfs)?;
        df.as_single_chunk_par();
        Ok(df)
    }

    /// Apply a closure over the groups as a new [`DataFrame`].
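    ///
    /// # Example
    ///
    /// A minimal sketch; the closure receives each group as its own [`DataFrame`]:
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).apply(|group| Ok(group))
    /// }
    /// ```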
    pub fn apply<F>(&self, mut f: F) -> PolarsResult<DataFrame>
    where
        F: FnMut(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
    {
        let df = self.prepare_apply()?;
        let dfs = self
            .get_groups()
            .iter()
            .map(|g| {
                // SAFETY:
                // groups are in bounds
                let sub_df = unsafe { take_df(&df, g) };
                f(sub_df)
            })
            .collect::<PolarsResult<Vec<_>>>()?;

        let mut df = accumulate_dataframes_vertical(dfs)?;
        df.as_single_chunk_par();
        Ok(df)
    }

    pub fn sliced(mut self, slice: Option<(i64, usize)>) -> Self {
        match slice {
            None => self,
            Some((offset, length)) => {
                self.groups = self.groups.slice(offset, length);
                self.selected_keys = self.keys_sliced(slice);
                self
            },
        }
    }
}

unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame {
    match g {
        GroupsIndicator::Idx(idx) => df.take_slice_unchecked(idx.1),
        GroupsIndicator::Slice([first, len]) => df.slice(first as i64, len as usize),
    }
}

#[derive(Copy, Clone, Debug)]
pub enum GroupByMethod {
    Min,
    NanMin,
    Max,
    NanMax,
    Median,
    Mean,
    First,
    Last,
    Item { allow_empty: bool },
    Sum,
    Groups,
    NUnique,
    Quantile(f64, QuantileMethod),
    Count { include_nulls: bool },
    Implode,
    Std(u8),
    Var(u8),
}

impl Display for GroupByMethod {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        use GroupByMethod::*;
        let s = match self {
            Min => "min",
            NanMin => "nan_min",
            Max => "max",
            NanMax => "nan_max",
            Median => "median",
            Mean => "mean",
            First => "first",
            Last => "last",
            Item { .. } => "item",
            Sum => "sum",
            Groups => "groups",
            NUnique => "n_unique",
            Quantile(_, _) => "quantile",
            Count { .. } => "count",
            Implode => "list",
            Std(_) => "std",
            Var(_) => "var",
        };
        write!(f, "{s}")
    }
}

// Formatting functions used in eager and lazy code for renaming grouped columns
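/// Builds the output column name for an aggregated column; for most methods this is
/// `{name}_{method}` (see the match arms in the source for the exceptions).
///
/// A small illustrative sketch (fenced as `ignore`; the import path is assumed):
///
/// ```ignore
/// use polars_core::frame::group_by::{fmt_group_by_column, GroupByMethod};
///
/// // "temp" aggregated with `mean` becomes "temp_mean".
/// assert_eq!(fmt_group_by_column("temp", GroupByMethod::Mean).as_str(), "temp_mean");
/// ```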
pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr {
    use GroupByMethod::*;
    match method {
        Min => format_pl_smallstr!("{name}_min"),
        Max => format_pl_smallstr!("{name}_max"),
        NanMin => format_pl_smallstr!("{name}_nan_min"),
        NanMax => format_pl_smallstr!("{name}_nan_max"),
        Median => format_pl_smallstr!("{name}_median"),
        Mean => format_pl_smallstr!("{name}_mean"),
        First => format_pl_smallstr!("{name}_first"),
        Last => format_pl_smallstr!("{name}_last"),
        Item { .. } => format_pl_smallstr!("{name}_item"),
        Sum => format_pl_smallstr!("{name}_sum"),
        Groups => PlSmallStr::from_static("groups"),
        NUnique => format_pl_smallstr!("{name}_n_unique"),
        Count { .. } => format_pl_smallstr!("{name}_count"),
        Implode => format_pl_smallstr!("{name}_agg_list"),
        Quantile(quantile, _interpol) => format_pl_smallstr!("{name}_quantile_{quantile:.2}"),
        Std(_) => format_pl_smallstr!("{name}_agg_std"),
        Var(_) => format_pl_smallstr!("{name}_agg_var"),
    }
}

#[cfg(test)]
mod test {
    use num_traits::FloatConst;

    use crate::prelude::*;

    #[test]
    #[cfg(feature = "dtype-date")]
    #[cfg_attr(miri, ignore)]
    fn test_group_by() -> PolarsResult<()> {
        let s0 = Column::new(
            PlSmallStr::from_static("date"),
            &[
                "2020-08-21",
                "2020-08-21",
                "2020-08-22",
                "2020-08-23",
                "2020-08-22",
            ],
        );
        let s1 = Column::new(PlSmallStr::from_static("temp"), [20, 10, 7, 9, 1]);
        let s2 = Column::new(PlSmallStr::from_static("rain"), [0.2, 0.1, 0.3, 0.1, 0.01]);
        let df = DataFrame::new(vec![s0, s1, s2]).unwrap();

        let out = df.group_by_stable(["date"])?.select(["temp"]).count()?;
        assert_eq!(
            out.column("temp_count")?,
            &Column::new(PlSmallStr::from_static("temp_count"), [2 as IdxSize, 2, 1])
        );

        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        // Select multiple
        let out = df
            .group_by_stable(["date"])?
            .select(["temp", "rain"])
            .mean()?;
        assert_eq!(
            out.column("temp_mean")?,
            &Column::new(PlSmallStr::from_static("temp_mean"), [15.0f64, 4.0, 9.0])
        );

        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        // Group by multiple
        let out = df
            .group_by_stable(["date", "temp"])?
            .select(["rain"])
            .mean()?;
        assert!(out.column("rain_mean").is_ok());

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["date"])?.select(["temp"]).sum()?;
        assert_eq!(
            out.column("temp_sum")?,
            &Column::new(PlSmallStr::from_static("temp_sum"), [30, 8, 9])
        );

        // Use of deprecated `n_unique()` for testing purposes
        #[allow(deprecated)]
        // implicitly select all columns and only aggregate with methods that support that aggregation
        let gb = df.group_by(["date"]).unwrap().n_unique().unwrap();
        // check that the group_by column is filtered out.
        assert_eq!(gb.width(), 3);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_static_group_by_by_12_columns() {
        // Build GroupBy DataFrame.
        let s0 = Column::new("G1".into(), ["A", "A", "B", "B", "C"].as_ref());
        let s1 = Column::new("N".into(), [1, 2, 2, 4, 2].as_ref());
        let s2 = Column::new("G2".into(), ["k", "l", "m", "m", "l"].as_ref());
        let s3 = Column::new("G3".into(), ["a", "b", "c", "c", "d"].as_ref());
        let s4 = Column::new("G4".into(), ["1", "2", "3", "3", "4"].as_ref());
        let s5 = Column::new("G5".into(), ["X", "Y", "Z", "Z", "W"].as_ref());
        let s6 = Column::new("G6".into(), [false, true, true, true, false].as_ref());
        let s7 = Column::new("G7".into(), ["r", "x", "q", "q", "o"].as_ref());
        let s8 = Column::new("G8".into(), ["R", "X", "Q", "Q", "O"].as_ref());
        let s9 = Column::new("G9".into(), [1, 2, 3, 3, 4].as_ref());
        let s10 = Column::new("G10".into(), [".", "!", "?", "?", "/"].as_ref());
        let s11 = Column::new("G11".into(), ["(", ")", "@", "@", "$"].as_ref());
        let s12 = Column::new("G12".into(), ["-", "_", ";", ";", ","].as_ref());

        let df =
            DataFrame::new(vec![s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]).unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let adf = df
            .group_by([
                "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12",
            ])
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(1), Some(2), Some(2), Some(6)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_dynamic_group_by_by_13_columns() {
        // The content for every group_by series.
        let series_content = ["A", "A", "B", "B", "C"];

        // The name of every group_by series.
        let series_names = [
            "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12", "G13",
        ];

        // Vector to contain every series.
        let mut columns = Vec::with_capacity(14);

        // Create a series for every group name.
        for series_name in series_names {
            let group_columns = Column::new(series_name.into(), series_content.as_ref());
            columns.push(group_columns);
        }

        // Create a series for the aggregation column.
        let agg_series = Column::new("N".into(), [1, 2, 3, 3, 4].as_ref());
        columns.push(agg_series);

        // Create the dataframe with the computed series.
        let df = DataFrame::new(columns).unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        // Compute the aggregated DataFrame by the 13 columns defined in `series_names`.
        let adf = df
            .group_by(series_names)
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();
        // Check that the results of the group-by are correct. Since the content of every
        // key column is identical, the grouped columns must be equal and in the same order.
        for series_name in &series_names {
            assert_eq!(
                Vec::from(&adf.column(series_name).unwrap().str().unwrap().sort(false)),
                &[Some("A"), Some("B"), Some("C")]
            );
        }

        // Check the aggregated column is the expected one.
        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(3), Some(4), Some(6)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_floats() {
        let df = df! {"flt" => [1., 1., 2., 2., 3.],
                    "val" => [1, 1, 1, 1, 1]
        }
        .unwrap();
        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let res = df.group_by(["flt"]).unwrap().sum().unwrap();
        let res = res.sort(["flt"], SortMultipleOptions::default()).unwrap();
        assert_eq!(
            Vec::from(res.column("val_sum").unwrap().i32().unwrap()),
            &[Some(2), Some(2), Some(1)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_categorical() {
        let mut df = df! {"foo" => ["a", "a", "b", "b", "c"],
                    "ham" => ["a", "a", "b", "b", "c"],
                    "bar" => [1, 1, 1, 1, 1]
        }
        .unwrap();

        df.apply("foo", |s| {
            s.cast(&DataType::from_categories(Categories::global()))
                .unwrap()
        })
        .unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        // check multiple keys and categorical
        let res = df
            .group_by_stable(["foo", "ham"])
            .unwrap()
            .select(["bar"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(
                res.column("bar_sum")
                    .unwrap()
                    .as_materialized_series()
                    .i32()
                    .unwrap()
            ),
            &[Some(2), Some(2), Some(1)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_null_handling() -> PolarsResult<()> {
        let df = df!(
            "a" => ["a", "a", "a", "b", "b"],
            "b" => [Some(1), Some(2), None, None, Some(1)]
        )?;
        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["a"])?.mean()?;

        assert_eq!(
            Vec::from(out.column("b_mean")?.as_materialized_series().f64()?),
            &[Some(1.5), Some(1.0)]
        );
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_var() -> PolarsResult<()> {
        // check variance and proper coercion to f64
        let df = df![
            "g" => ["foo", "foo", "bar"],
            "flt" => [1.0, 2.0, 3.0],
            "int" => [1, 2, 3]
        ]?;

        // Use of deprecated `var()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).var(1)?;

        assert_eq!(out.column("int_agg_var")?.f64()?.get(0), Some(0.5));
        // Use of deprecated `std()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).std(1)?;
        let val = out.column("int_agg_std")?.f64()?.get(0).unwrap();
        let expected = f64::FRAC_1_SQRT_2();
        assert!((val - expected).abs() < 0.000001);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_null_group() -> PolarsResult<()> {
        // check if null is own group
        let mut df = df![
            "g" => [Some("foo"), Some("foo"), Some("bar"), None, None],
            "flt" => [1.0, 2.0, 3.0, 1.0, 1.0],
            "int" => [1, 2, 3, 1, 1]
        ]?;

        df.try_apply("g", |s| {
            s.cast(&DataType::from_categories(Categories::global()))
        })?;

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let _ = df.group_by(["g"])?.sum()?;
        Ok(())
    }
}