polars_core/frame/group_by/
mod.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3
4use num_traits::NumCast;
5use polars_compute::rolling::QuantileMethod;
6use polars_utils::format_pl_smallstr;
7use polars_utils::hashing::DirtyHash;
8use rayon::prelude::*;
9
10use self::hashing::*;
11use crate::POOL;
12use crate::prelude::*;
13use crate::utils::{_set_partition_size, accumulate_dataframes_vertical};
14
15pub mod aggregations;
16pub mod expr;
17pub(crate) mod hashing;
18mod into_groups;
19mod position;
20
21pub use into_groups::*;
22pub use position::*;
23
24use crate::chunked_array::ops::row_encode::{
25    encode_rows_unordered, encode_rows_vertical_par_unordered,
26};
27
28impl DataFrame {
29    pub fn group_by_with_series(
30        &self,
31        mut by: Vec<Column>,
32        multithreaded: bool,
33        sorted: bool,
34    ) -> PolarsResult<GroupBy<'_>> {
35        polars_ensure!(
36            !by.is_empty(),
37            ComputeError: "at least one key is required in a group_by operation"
38        );
39
40        // Ensure all 'by' columns have the same common_height
41        // The condition self.width > 0 ensures we can still call this on a
42        // dummy dataframe where we provide the keys
43        let common_height = if self.width() > 0 {
44            self.height()
45        } else {
46            by.iter().map(|s| s.len()).max().expect("at least 1 key")
47        };
48        for by_key in by.iter_mut() {
49            if by_key.len() != common_height {
50                polars_ensure!(
51                    by_key.len() == 1,
52                    ShapeMismatch: "series used as keys should have the same length as the DataFrame"
53                );
54                *by_key = by_key.new_from_index(0, common_height)
55            }
56        }
57
58        let groups = if by.len() == 1 {
59            let column = &by[0];
60            column
61                .as_materialized_series()
62                .group_tuples(multithreaded, sorted)
63        } else if by.iter().any(|s| s.dtype().is_object()) {
64            #[cfg(feature = "object")]
65            {
66                let mut df = DataFrame::new(by.clone()).unwrap();
67                let n = df.height();
68                let rows = df.to_av_rows();
69                let iter = (0..n).map(|i| rows.get(i));
70                Ok(group_by(iter, sorted))
71            }
72            #[cfg(not(feature = "object"))]
73            {
74                unreachable!()
75            }
76        } else {
77            // Skip null dtype.
78            let by = by
79                .iter()
80                .filter(|s| !s.dtype().is_null())
81                .cloned()
82                .collect::<Vec<_>>();
83            if by.is_empty() {
84                let groups = if self.is_empty() {
85                    vec![]
86                } else {
87                    vec![[0, self.height() as IdxSize]]
88                };
89                Ok(GroupsType::new_slice(groups, false, true))
90            } else {
91                let rows = if multithreaded {
92                    encode_rows_vertical_par_unordered(&by)
93                } else {
94                    encode_rows_unordered(&by)
95                }?
96                .into_series();
97                rows.group_tuples(multithreaded, sorted)
98            }
99        };
100        Ok(GroupBy::new(self, by, groups?.into_sliceable(), None))
101    }
102
103    /// Group DataFrame using a Series column.
104    ///
105    /// # Example
106    ///
107    /// ```
108    /// use polars_core::prelude::*;
109    /// fn group_by_sum(df: &DataFrame) -> PolarsResult<DataFrame> {
110    ///     df.group_by(["column_name"])?
111    ///     .select(["agg_column_name"])
112    ///     .sum()
113    /// }
114    /// ```
115    pub fn group_by<I, S>(&self, by: I) -> PolarsResult<GroupBy<'_>>
116    where
117        I: IntoIterator<Item = S>,
118        S: Into<PlSmallStr>,
119    {
120        let selected_keys = self.select_columns(by)?;
121        self.group_by_with_series(selected_keys, true, false)
122    }
123
124    /// Group DataFrame using a Series column.
125    /// The groups are ordered by their smallest row index.
126    pub fn group_by_stable<I, S>(&self, by: I) -> PolarsResult<GroupBy<'_>>
127    where
128        I: IntoIterator<Item = S>,
129        S: Into<PlSmallStr>,
130    {
131        let selected_keys = self.select_columns(by)?;
132        self.group_by_with_series(selected_keys, true, true)
133    }
134}
135
136/// Returned by a group_by operation on a DataFrame. This struct supports
137/// several aggregations.
138///
139/// Until described otherwise, the examples in this struct are performed on the following DataFrame:
140///
141/// ```ignore
142/// use polars_core::prelude::*;
143///
144/// let dates = &[
145/// "2020-08-21",
146/// "2020-08-21",
147/// "2020-08-22",
148/// "2020-08-23",
149/// "2020-08-22",
150/// ];
151/// // date format
152/// let fmt = "%Y-%m-%d";
153/// // create date series
154/// let s0 = DateChunked::parse_from_str_slice("date", dates, fmt)
155///         .into_series();
156/// // create temperature series
157/// let s1 = Series::new("temp".into(), [20, 10, 7, 9, 1]);
158/// // create rain series
159/// let s2 = Series::new("rain".into(), [0.2, 0.1, 0.3, 0.1, 0.01]);
160/// // create a new DataFrame
161/// let df = DataFrame::new(vec![s0, s1, s2]).unwrap();
162/// println!("{:?}", df);
163/// ```
164///
165/// Outputs:
166///
167/// ```text
168/// +------------+------+------+
169/// | date       | temp | rain |
170/// | ---        | ---  | ---  |
171/// | Date       | i32  | f64  |
172/// +============+======+======+
173/// | 2020-08-21 | 20   | 0.2  |
174/// +------------+------+------+
175/// | 2020-08-21 | 10   | 0.1  |
176/// +------------+------+------+
177/// | 2020-08-22 | 7    | 0.3  |
178/// +------------+------+------+
179/// | 2020-08-23 | 9    | 0.1  |
180/// +------------+------+------+
181/// | 2020-08-22 | 1    | 0.01 |
182/// +------------+------+------+
183/// ```
184///
#[derive(Debug, Clone)]
pub struct GroupBy<'a> {
    // The frame being grouped; key/aggregation columns are selected from it.
    pub df: &'a DataFrame,
    // The key columns, broadcast to the frame height.
    pub(crate) selected_keys: Vec<Column>,
    // [first idx, [other idx]]
    groups: GroupPositions,
    // columns selected for aggregation
    pub(crate) selected_agg: Option<Vec<PlSmallStr>>,
}
194
195impl<'a> GroupBy<'a> {
196    pub fn new(
197        df: &'a DataFrame,
198        by: Vec<Column>,
199        groups: GroupPositions,
200        selected_agg: Option<Vec<PlSmallStr>>,
201    ) -> Self {
202        GroupBy {
203            df,
204            selected_keys: by,
205            groups,
206            selected_agg,
207        }
208    }
209
210    /// Select the column(s) that should be aggregated.
211    /// You can select a single column or a slice of columns.
212    ///
213    /// Note that making a selection with this method is not required. If you
214    /// skip it all columns (except for the keys) will be selected for aggregation.
215    #[must_use]
216    pub fn select<I: IntoIterator<Item = S>, S: Into<PlSmallStr>>(mut self, selection: I) -> Self {
217        self.selected_agg = Some(selection.into_iter().map(|s| s.into()).collect());
218        self
219    }
220
    /// Get the internal representation of the GroupBy operation.
    /// The Vec returned contains:
    ///     (first_idx, [`Vec<indexes>`])
    ///     Where second value in the tuple is a vector with all matching indexes.
    ///
    /// See [`Self::get_groups_mut`] for mutable access.
    pub fn get_groups(&self) -> &GroupPositions {
        &self.groups
    }
228
    /// Get the internal representation of the GroupBy operation.
    /// The Vec returned contains:
    ///     (first_idx, [`Vec<indexes>`])
    ///     Where second value in the tuple is a vector with all matching indexes.
    ///
    /// # Safety
    /// Groups should always be in bounds of the `DataFrame` held by this [`GroupBy`].
    /// If you mutate it, you must hold that invariant.
    pub unsafe fn get_groups_mut(&mut self) -> &mut GroupPositions {
        &mut self.groups
    }
240
    /// Consume the `GroupBy` and return the underlying group positions.
    pub fn into_groups(self) -> GroupPositions {
        self.groups
    }
244
    /// Materialize the key columns with one row per group (the group's first
    /// row), optionally restricted to an `(offset, len)` slice of the groups.
    pub fn keys_sliced(&self, slice: Option<(i64, usize)>) -> Vec<Column> {
        #[allow(unused_assignments)]
        // needed to keep the lifetimes valid for this scope
        let mut groups_owned = None;

        // Borrow either the sliced (owned) groups or the stored ones.
        let groups = if let Some((offset, len)) = slice {
            groups_owned = Some(self.groups.slice(offset, len));
            groups_owned.as_deref().unwrap()
        } else {
            &self.groups
        };
        // Take the first row of every group from each key column, in parallel.
        POOL.install(|| {
            self.selected_keys
                .par_iter()
                .map(Column::as_materialized_series)
                .map(|s| {
                    match groups {
                        GroupsType::Idx(groups) => {
                            // SAFETY: groups are always in bounds.
                            let mut out = unsafe { s.take_slice_unchecked(groups.first()) };
                            if groups.sorted {
                                out.set_sorted_flag(s.is_sorted_flag());
                            };
                            out
                        },
                        GroupsType::Slice {
                            groups,
                            overlapping,
                            monotonic: _,
                        } => {
                            if *overlapping && !groups.is_empty() {
                                // Groups can be sliced.
                                let offset = groups[0][0];
                                let [upper_offset, upper_len] = groups[groups.len() - 1];
                                return s.slice(
                                    offset as i64,
                                    ((upper_offset + upper_len) - offset) as usize,
                                );
                            }

                            // Gather each group's first row index.
                            let indices = groups
                                .iter()
                                .map(|&[first, _len]| first)
                                .collect_ca(PlSmallStr::EMPTY);
                            // SAFETY: groups are always in bounds.
                            let mut out = unsafe { s.take_unchecked(&indices) };
                            // Sliced groups are always in order of discovery.
                            out.set_sorted_flag(s.is_sorted_flag());
                            out
                        },
                    }
                })
                .map(Column::from)
                .collect()
        })
    }
301
    /// Materialize the key columns, one row per group (no slicing).
    pub fn keys(&self) -> Vec<Column> {
        self.keys_sliced(None)
    }
305
306    fn prepare_agg(&self) -> PolarsResult<(Vec<Column>, Vec<Column>)> {
307        let keys = self.keys();
308
309        let agg_col = match &self.selected_agg {
310            Some(selection) => self.df.select_columns_impl(selection.as_slice()),
311            None => {
312                let by: Vec<_> = self.selected_keys.iter().map(|s| s.name()).collect();
313                let selection = self
314                    .df
315                    .iter()
316                    .map(|s| s.name())
317                    .filter(|a| !by.contains(a))
318                    .cloned()
319                    .collect::<Vec<_>>();
320
321                self.df.select_columns_impl(selection.as_slice())
322            },
323        }?;
324
325        Ok((keys, agg_col))
326    }
327
328    /// Aggregate grouped series and compute the mean per group.
329    ///
330    /// # Example
331    ///
332    /// ```rust
333    /// # use polars_core::prelude::*;
334    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
335    ///     df.group_by(["date"])?.select(["temp", "rain"]).mean()
336    /// }
337    /// ```
338    /// Returns:
339    ///
340    /// ```text
341    /// +------------+-----------+-----------+
342    /// | date       | temp_mean | rain_mean |
343    /// | ---        | ---       | ---       |
344    /// | Date       | f64       | f64       |
345    /// +============+===========+===========+
346    /// | 2020-08-23 | 9         | 0.1       |
347    /// +------------+-----------+-----------+
348    /// | 2020-08-22 | 4         | 0.155     |
349    /// +------------+-----------+-----------+
350    /// | 2020-08-21 | 15        | 0.15      |
351    /// +------------+-----------+-----------+
352    /// ```
353    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
354    pub fn mean(&self) -> PolarsResult<DataFrame> {
355        let (mut cols, agg_cols) = self.prepare_agg()?;
356
357        for agg_col in agg_cols {
358            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Mean);
359            let mut agg = unsafe { agg_col.agg_mean(&self.groups) };
360            agg.rename(new_name);
361            cols.push(agg);
362        }
363        DataFrame::new(cols)
364    }
365
366    /// Aggregate grouped series and compute the sum per group.
367    ///
368    /// # Example
369    ///
370    /// ```rust
371    /// # use polars_core::prelude::*;
372    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
373    ///     df.group_by(["date"])?.select(["temp"]).sum()
374    /// }
375    /// ```
376    /// Returns:
377    ///
378    /// ```text
379    /// +------------+----------+
380    /// | date       | temp_sum |
381    /// | ---        | ---      |
382    /// | Date       | i32      |
383    /// +============+==========+
384    /// | 2020-08-23 | 9        |
385    /// +------------+----------+
386    /// | 2020-08-22 | 8        |
387    /// +------------+----------+
388    /// | 2020-08-21 | 30       |
389    /// +------------+----------+
390    /// ```
391    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
392    pub fn sum(&self) -> PolarsResult<DataFrame> {
393        let (mut cols, agg_cols) = self.prepare_agg()?;
394
395        for agg_col in agg_cols {
396            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Sum);
397            let mut agg = unsafe { agg_col.agg_sum(&self.groups) };
398            agg.rename(new_name);
399            cols.push(agg);
400        }
401        DataFrame::new(cols)
402    }
403
404    /// Aggregate grouped series and compute the minimal value per group.
405    ///
406    /// # Example
407    ///
408    /// ```rust
409    /// # use polars_core::prelude::*;
410    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
411    ///     df.group_by(["date"])?.select(["temp"]).min()
412    /// }
413    /// ```
414    /// Returns:
415    ///
416    /// ```text
417    /// +------------+----------+
418    /// | date       | temp_min |
419    /// | ---        | ---      |
420    /// | Date       | i32      |
421    /// +============+==========+
422    /// | 2020-08-23 | 9        |
423    /// +------------+----------+
424    /// | 2020-08-22 | 1        |
425    /// +------------+----------+
426    /// | 2020-08-21 | 10       |
427    /// +------------+----------+
428    /// ```
429    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
430    pub fn min(&self) -> PolarsResult<DataFrame> {
431        let (mut cols, agg_cols) = self.prepare_agg()?;
432        for agg_col in agg_cols {
433            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Min);
434            let mut agg = unsafe { agg_col.agg_min(&self.groups) };
435            agg.rename(new_name);
436            cols.push(agg);
437        }
438        DataFrame::new(cols)
439    }
440
441    /// Aggregate grouped series and compute the maximum value per group.
442    ///
443    /// # Example
444    ///
445    /// ```rust
446    /// # use polars_core::prelude::*;
447    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
448    ///     df.group_by(["date"])?.select(["temp"]).max()
449    /// }
450    /// ```
451    /// Returns:
452    ///
453    /// ```text
454    /// +------------+----------+
455    /// | date       | temp_max |
456    /// | ---        | ---      |
457    /// | Date       | i32      |
458    /// +============+==========+
459    /// | 2020-08-23 | 9        |
460    /// +------------+----------+
461    /// | 2020-08-22 | 7        |
462    /// +------------+----------+
463    /// | 2020-08-21 | 20       |
464    /// +------------+----------+
465    /// ```
466    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
467    pub fn max(&self) -> PolarsResult<DataFrame> {
468        let (mut cols, agg_cols) = self.prepare_agg()?;
469        for agg_col in agg_cols {
470            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Max);
471            let mut agg = unsafe { agg_col.agg_max(&self.groups) };
472            agg.rename(new_name);
473            cols.push(agg);
474        }
475        DataFrame::new(cols)
476    }
477
478    /// Aggregate grouped `Series` and find the first value per group.
479    ///
480    /// # Example
481    ///
482    /// ```rust
483    /// # use polars_core::prelude::*;
484    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
485    ///     df.group_by(["date"])?.select(["temp"]).first()
486    /// }
487    /// ```
488    /// Returns:
489    ///
490    /// ```text
491    /// +------------+------------+
492    /// | date       | temp_first |
493    /// | ---        | ---        |
494    /// | Date       | i32        |
495    /// +============+============+
496    /// | 2020-08-23 | 9          |
497    /// +------------+------------+
498    /// | 2020-08-22 | 7          |
499    /// +------------+------------+
500    /// | 2020-08-21 | 20         |
501    /// +------------+------------+
502    /// ```
503    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
504    pub fn first(&self) -> PolarsResult<DataFrame> {
505        let (mut cols, agg_cols) = self.prepare_agg()?;
506        for agg_col in agg_cols {
507            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::First);
508            let mut agg = unsafe { agg_col.agg_first(&self.groups) };
509            agg.rename(new_name);
510            cols.push(agg);
511        }
512        DataFrame::new(cols)
513    }
514
515    /// Aggregate grouped `Series` and return the last value per group.
516    ///
517    /// # Example
518    ///
519    /// ```rust
520    /// # use polars_core::prelude::*;
521    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
522    ///     df.group_by(["date"])?.select(["temp"]).last()
523    /// }
524    /// ```
525    /// Returns:
526    ///
527    /// ```text
528    /// +------------+------------+
529    /// | date       | temp_last |
530    /// | ---        | ---        |
531    /// | Date       | i32        |
532    /// +============+============+
533    /// | 2020-08-23 | 9          |
534    /// +------------+------------+
535    /// | 2020-08-22 | 1          |
536    /// +------------+------------+
537    /// | 2020-08-21 | 10         |
538    /// +------------+------------+
539    /// ```
540    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
541    pub fn last(&self) -> PolarsResult<DataFrame> {
542        let (mut cols, agg_cols) = self.prepare_agg()?;
543        for agg_col in agg_cols {
544            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Last);
545            let mut agg = unsafe { agg_col.agg_last(&self.groups) };
546            agg.rename(new_name);
547            cols.push(agg);
548        }
549        DataFrame::new(cols)
550    }
551
552    /// Aggregate grouped `Series` by counting the number of unique values.
553    ///
554    /// # Example
555    ///
556    /// ```rust
557    /// # use polars_core::prelude::*;
558    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
559    ///     df.group_by(["date"])?.select(["temp"]).n_unique()
560    /// }
561    /// ```
562    /// Returns:
563    ///
564    /// ```text
565    /// +------------+---------------+
566    /// | date       | temp_n_unique |
567    /// | ---        | ---           |
568    /// | Date       | u32           |
569    /// +============+===============+
570    /// | 2020-08-23 | 1             |
571    /// +------------+---------------+
572    /// | 2020-08-22 | 2             |
573    /// +------------+---------------+
574    /// | 2020-08-21 | 2             |
575    /// +------------+---------------+
576    /// ```
577    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
578    pub fn n_unique(&self) -> PolarsResult<DataFrame> {
579        let (mut cols, agg_cols) = self.prepare_agg()?;
580        for agg_col in agg_cols {
581            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::NUnique);
582            let mut agg = unsafe { agg_col.agg_n_unique(&self.groups) };
583            agg.rename(new_name);
584            cols.push(agg);
585        }
586        DataFrame::new(cols)
587    }
588
589    /// Aggregate grouped [`Series`] and determine the quantile per group.
590    ///
591    /// # Example
592    ///
593    /// ```rust
594    /// # use polars_core::prelude::*;
595    ///
596    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
597    ///     df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileMethod::default())
598    /// }
599    /// ```
600    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
601    pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<DataFrame> {
602        polars_ensure!(
603            (0.0..=1.0).contains(&quantile),
604            ComputeError: "`quantile` should be within 0.0 and 1.0"
605        );
606        let (mut cols, agg_cols) = self.prepare_agg()?;
607        for agg_col in agg_cols {
608            let new_name = fmt_group_by_column(
609                agg_col.name().as_str(),
610                GroupByMethod::Quantile(quantile, method),
611            );
612            let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) };
613            agg.rename(new_name);
614            cols.push(agg);
615        }
616        DataFrame::new(cols)
617    }
618
619    /// Aggregate grouped [`Series`] and determine the median per group.
620    ///
621    /// # Example
622    ///
623    /// ```rust
624    /// # use polars_core::prelude::*;
625    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
626    ///     df.group_by(["date"])?.select(["temp"]).median()
627    /// }
628    /// ```
629    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
630    pub fn median(&self) -> PolarsResult<DataFrame> {
631        let (mut cols, agg_cols) = self.prepare_agg()?;
632        for agg_col in agg_cols {
633            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Median);
634            let mut agg = unsafe { agg_col.agg_median(&self.groups) };
635            agg.rename(new_name);
636            cols.push(agg);
637        }
638        DataFrame::new(cols)
639    }
640
641    /// Aggregate grouped [`Series`] and determine the variance per group.
642    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
643    pub fn var(&self, ddof: u8) -> PolarsResult<DataFrame> {
644        let (mut cols, agg_cols) = self.prepare_agg()?;
645        for agg_col in agg_cols {
646            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Var(ddof));
647            let mut agg = unsafe { agg_col.agg_var(&self.groups, ddof) };
648            agg.rename(new_name);
649            cols.push(agg);
650        }
651        DataFrame::new(cols)
652    }
653
654    /// Aggregate grouped [`Series`] and determine the standard deviation per group.
655    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
656    pub fn std(&self, ddof: u8) -> PolarsResult<DataFrame> {
657        let (mut cols, agg_cols) = self.prepare_agg()?;
658        for agg_col in agg_cols {
659            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Std(ddof));
660            let mut agg = unsafe { agg_col.agg_std(&self.groups, ddof) };
661            agg.rename(new_name);
662            cols.push(agg);
663        }
664        DataFrame::new(cols)
665    }
666
667    /// Aggregate grouped series and compute the number of values per group.
668    ///
669    /// # Example
670    ///
671    /// ```rust
672    /// # use polars_core::prelude::*;
673    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
674    ///     df.group_by(["date"])?.select(["temp"]).count()
675    /// }
676    /// ```
677    /// Returns:
678    ///
679    /// ```text
680    /// +------------+------------+
681    /// | date       | temp_count |
682    /// | ---        | ---        |
683    /// | Date       | u32        |
684    /// +============+============+
685    /// | 2020-08-23 | 1          |
686    /// +------------+------------+
687    /// | 2020-08-22 | 2          |
688    /// +------------+------------+
689    /// | 2020-08-21 | 2          |
690    /// +------------+------------+
691    /// ```
692    pub fn count(&self) -> PolarsResult<DataFrame> {
693        let (mut cols, agg_cols) = self.prepare_agg()?;
694
695        for agg_col in agg_cols {
696            let new_name = fmt_group_by_column(
697                agg_col.name().as_str(),
698                GroupByMethod::Count {
699                    include_nulls: true,
700                },
701            );
702            let mut ca = self.groups.group_count();
703            ca.rename(new_name);
704            cols.push(ca.into_column());
705        }
706        DataFrame::new(cols)
707    }
708
709    /// Get the group_by group indexes.
710    ///
711    /// # Example
712    ///
713    /// ```rust
714    /// # use polars_core::prelude::*;
715    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
716    ///     df.group_by(["date"])?.groups()
717    /// }
718    /// ```
719    /// Returns:
720    ///
721    /// ```text
722    /// +--------------+------------+
723    /// | date         | groups     |
724    /// | ---          | ---        |
725    /// | Date(days)   | list [u32] |
726    /// +==============+============+
727    /// | 2020-08-23   | "[3]"      |
728    /// +--------------+------------+
729    /// | 2020-08-22   | "[2, 4]"   |
730    /// +--------------+------------+
731    /// | 2020-08-21   | "[0, 1]"   |
732    /// +--------------+------------+
733    /// ```
734    pub fn groups(&self) -> PolarsResult<DataFrame> {
735        let mut cols = self.keys();
736        let mut column = self.groups.as_list_chunked();
737        let new_name = fmt_group_by_column("", GroupByMethod::Groups);
738        column.rename(new_name);
739        cols.push(column.into_column());
740        DataFrame::new(cols)
741    }
742
743    /// Aggregate the groups of the group_by operation into lists.
744    ///
745    /// # Example
746    ///
747    /// ```rust
748    /// # use polars_core::prelude::*;
749    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
750    ///     // GroupBy and aggregate to Lists
751    ///     df.group_by(["date"])?.select(["temp"]).agg_list()
752    /// }
753    /// ```
754    /// Returns:
755    ///
756    /// ```text
757    /// +------------+------------------------+
758    /// | date       | temp_agg_list          |
759    /// | ---        | ---                    |
760    /// | Date       | list [i32]             |
761    /// +============+========================+
762    /// | 2020-08-23 | "[Some(9)]"            |
763    /// +------------+------------------------+
764    /// | 2020-08-22 | "[Some(7), Some(1)]"   |
765    /// +------------+------------------------+
766    /// | 2020-08-21 | "[Some(20), Some(10)]" |
767    /// +------------+------------------------+
768    /// ```
769    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
770    pub fn agg_list(&self) -> PolarsResult<DataFrame> {
771        let (mut cols, agg_cols) = self.prepare_agg()?;
772        for agg_col in agg_cols {
773            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Implode);
774            let mut agg = unsafe { agg_col.agg_list(&self.groups) };
775            agg.rename(new_name);
776            cols.push(agg);
777        }
778        DataFrame::new(cols)
779    }
780
781    fn prepare_apply(&self) -> PolarsResult<DataFrame> {
782        polars_ensure!(self.df.height() > 0, ComputeError: "cannot group_by + apply on empty 'DataFrame'");
783        if let Some(agg) = &self.selected_agg {
784            if agg.is_empty() {
785                Ok(self.df.clone())
786            } else {
787                let mut new_cols = Vec::with_capacity(self.selected_keys.len() + agg.len());
788                new_cols.extend_from_slice(&self.selected_keys);
789                let cols = self.df.select_columns_impl(agg.as_slice())?;
790                new_cols.extend(cols);
791                Ok(unsafe { DataFrame::new_no_checks(self.df.height(), new_cols) })
792            }
793        } else {
794            Ok(self.df.clone())
795        }
796    }
797
    /// Apply a closure over the groups as a new [`DataFrame`] in parallel.
    ///
    /// Each group is materialized as a sub-frame, passed to `f`, and the
    /// results are concatenated vertically (group order preserved) and
    /// rechunked.
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn par_apply<F>(&self, f: F) -> PolarsResult<DataFrame>
    where
        F: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
    {
        let df = self.prepare_apply()?;
        let dfs = self
            .get_groups()
            .par_iter()
            .map(|g| {
                // SAFETY:
                // groups are in bounds
                let sub_df = unsafe { take_df(&df, g) };
                f(sub_df)
            })
            .collect::<PolarsResult<Vec<_>>>()?;

        let mut df = accumulate_dataframes_vertical(dfs)?;
        df.as_single_chunk_par();
        Ok(df)
    }
820
    /// Apply a closure over the groups as a new [`DataFrame`].
    ///
    /// Sequential counterpart of [`Self::par_apply`]; equivalent to
    /// [`Self::apply_sliced`] without a slice.
    pub fn apply<F>(&self, f: F) -> PolarsResult<DataFrame>
    where
        F: FnMut(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
    {
        self.apply_sliced(None, f)
    }
828
    /// Apply `f` over the groups sequentially, stopping early once the
    /// accumulated result height covers the requested `(offset, len)` slice,
    /// then slice the concatenated result.
    pub fn apply_sliced<F>(&self, slice: Option<(i64, usize)>, mut f: F) -> PolarsResult<DataFrame>
    where
        F: FnMut(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
    {
        let df = self.prepare_apply()?;
        // Rows needed before the final slice is fully covered; a negative
        // offset cannot be resolved up front, so it disables early stopping.
        let max_height = if let Some((offset, len)) = slice {
            offset.try_into().unwrap_or(usize::MAX).saturating_add(len)
        } else {
            usize::MAX
        };
        let mut height = 0;
        let mut dfs = Vec::with_capacity(self.get_groups().len());
        for g in self.get_groups().iter() {
            // SAFETY: groups are in bounds.
            let sub_df = unsafe { take_df(&df, g) };
            let df = f(sub_df)?;
            height += df.height();
            dfs.push(df);

            // Even if max_height is zero we need at least one df, so check
            // after first push.
            if height >= max_height {
                break;
            }
        }

        let mut df = accumulate_dataframes_vertical(dfs)?;
        if let Some((offset, len)) = slice {
            df = df.slice(offset, len);
        }
        Ok(df)
    }
861
862    pub fn sliced(mut self, slice: Option<(i64, usize)>) -> Self {
863        match slice {
864            None => self,
865            Some((offset, length)) => {
866                self.groups = self.groups.slice(offset, length);
867                self.selected_keys = self.keys_sliced(slice);
868                self
869            },
870        }
871    }
872}
873
/// Materialize the rows of a single group as a new [`DataFrame`].
///
/// # Safety
/// For the `Idx` variant, every index in the group must be in bounds for
/// `df`: the rows are gathered with `take_slice_unchecked`, which performs
/// no bounds checking. The `Slice` variant delegates to `DataFrame::slice`.
unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame {
    match g {
        GroupsIndicator::Idx(idx) => df.take_slice_unchecked(idx.1),
        GroupsIndicator::Slice([first, len]) => df.slice(first as i64, len as usize),
    }
}
880
/// The aggregation to apply to each group in a group-by operation.
///
/// Shared by eager and lazy code; `fmt_group_by_column` derives the output
/// column name from the chosen method.
#[derive(Copy, Clone, Debug)]
pub enum GroupByMethod {
    Min,
    // NOTE(review): NaN-aware variant — exact NaN ordering semantics are
    // inferred from the name; confirm against the aggregation kernels.
    NanMin,
    Max,
    NanMax,
    Median,
    Mean,
    First,
    FirstNonNull,
    Last,
    LastNonNull,
    // A single item per group; presumably `allow_empty` controls whether an
    // empty group is accepted — TODO confirm at the implementation site.
    Item { allow_empty: bool },
    Sum,
    // Returns the group indices themselves rather than an aggregated value.
    Groups,
    NUnique,
    // Quantile value plus the interpolation method used to compute it.
    Quantile(f64, QuantileMethod),
    // Row count per group; `include_nulls` selects whether nulls count.
    Count { include_nulls: bool },
    // Collect each group into a list ("agg_list" in output column names).
    Implode,
    // The `u8` is the degrees-of-freedom adjustment (ddof), as exercised by
    // `test_group_by_var` with ddof = 1.
    Std(u8),
    Var(u8),
}
903
904impl Display for GroupByMethod {
905    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
906        use GroupByMethod::*;
907        let s = match self {
908            Min => "min",
909            NanMin => "nan_min",
910            Max => "max",
911            NanMax => "nan_max",
912            Median => "median",
913            Mean => "mean",
914            First => "first",
915            FirstNonNull => "first_non_null",
916            Last => "last",
917            LastNonNull => "last_non_null",
918            Item { .. } => "item",
919            Sum => "sum",
920            Groups => "groups",
921            NUnique => "n_unique",
922            Quantile(_, _) => "quantile",
923            Count { .. } => "count",
924            Implode => "list",
925            Std(_) => "std",
926            Var(_) => "var",
927        };
928        write!(f, "{s}")
929    }
930}
931
932// Formatting functions used in eager and lazy code for renaming grouped columns
933pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr {
934    use GroupByMethod::*;
935    match method {
936        Min => format_pl_smallstr!("{name}_min"),
937        Max => format_pl_smallstr!("{name}_max"),
938        NanMin => format_pl_smallstr!("{name}_nan_min"),
939        NanMax => format_pl_smallstr!("{name}_nan_max"),
940        Median => format_pl_smallstr!("{name}_median"),
941        Mean => format_pl_smallstr!("{name}_mean"),
942        First => format_pl_smallstr!("{name}_first"),
943        FirstNonNull => format_pl_smallstr!("{name}_first_non_null"),
944        Last => format_pl_smallstr!("{name}_last"),
945        LastNonNull => format_pl_smallstr!("{name}_last_non_null"),
946        Item { .. } => format_pl_smallstr!("{name}_item"),
947        Sum => format_pl_smallstr!("{name}_sum"),
948        Groups => PlSmallStr::from_static("groups"),
949        NUnique => format_pl_smallstr!("{name}_n_unique"),
950        Count { .. } => format_pl_smallstr!("{name}_count"),
951        Implode => format_pl_smallstr!("{name}_agg_list"),
952        Quantile(quantile, _interpol) => format_pl_smallstr!("{name}_quantile_{quantile:.2}"),
953        Std(_) => format_pl_smallstr!("{name}_agg_std"),
954        Var(_) => format_pl_smallstr!("{name}_agg_var"),
955    }
956}
957
#[cfg(test)]
mod test {
    use num_traits::FloatConst;

    use crate::prelude::*;

    /// Group by a single string key and check the `count`, `mean`, `sum`
    /// and `n_unique` aggregations on the resulting groups.
    #[test]
    #[cfg(feature = "dtype-date")]
    #[cfg_attr(miri, ignore)]
    fn test_group_by() -> PolarsResult<()> {
        let s0 = Column::new(
            PlSmallStr::from_static("date"),
            &[
                "2020-08-21",
                "2020-08-21",
                "2020-08-22",
                "2020-08-23",
                "2020-08-22",
            ],
        );
        let s1 = Column::new(PlSmallStr::from_static("temp"), [20, 10, 7, 9, 1]);
        let s2 = Column::new(PlSmallStr::from_static("rain"), [0.2, 0.1, 0.3, 0.1, 0.01]);
        let df = DataFrame::new(vec![s0, s1, s2]).unwrap();

        let out = df.group_by_stable(["date"])?.select(["temp"]).count()?;
        assert_eq!(
            out.column("temp_count")?,
            &Column::new(PlSmallStr::from_static("temp_count"), [2 as IdxSize, 2, 1])
        );

        // Use of deprecated mean() for testing purposes
        #[allow(deprecated)]
        // Select multiple
        let out = df
            .group_by_stable(["date"])?
            .select(["temp", "rain"])
            .mean()?;
        assert_eq!(
            out.column("temp_mean")?,
            &Column::new(PlSmallStr::from_static("temp_mean"), [15.0f64, 4.0, 9.0])
        );

        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        // Group by multiple
        let out = df
            .group_by_stable(["date", "temp"])?
            .select(["rain"])
            .mean()?;
        assert!(out.column("rain_mean").is_ok());

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["date"])?.select(["temp"]).sum()?;
        assert_eq!(
            out.column("temp_sum")?,
            &Column::new(PlSmallStr::from_static("temp_sum"), [30, 8, 9])
        );

        // Use of deprecated `n_unique()` for testing purposes
        #[allow(deprecated)]
        // implicit select all and only aggregate on methods that support that aggregation
        let gb = df.group_by(["date"]).unwrap().n_unique().unwrap();
        // check the group by column is filtered out.
        assert_eq!(gb.width(), 3);
        Ok(())
    }

    /// Group by twelve key columns at once, exercising the multi-key path.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_static_group_by_by_12_columns() {
        // Build GroupBy DataFrame.
        let s0 = Column::new("G1".into(), ["A", "A", "B", "B", "C"].as_ref());
        let s1 = Column::new("N".into(), [1, 2, 2, 4, 2].as_ref());
        let s2 = Column::new("G2".into(), ["k", "l", "m", "m", "l"].as_ref());
        let s3 = Column::new("G3".into(), ["a", "b", "c", "c", "d"].as_ref());
        let s4 = Column::new("G4".into(), ["1", "2", "3", "3", "4"].as_ref());
        let s5 = Column::new("G5".into(), ["X", "Y", "Z", "Z", "W"].as_ref());
        let s6 = Column::new("G6".into(), [false, true, true, true, false].as_ref());
        let s7 = Column::new("G7".into(), ["r", "x", "q", "q", "o"].as_ref());
        let s8 = Column::new("G8".into(), ["R", "X", "Q", "Q", "O"].as_ref());
        let s9 = Column::new("G9".into(), [1, 2, 3, 3, 4].as_ref());
        let s10 = Column::new("G10".into(), [".", "!", "?", "?", "/"].as_ref());
        let s11 = Column::new("G11".into(), ["(", ")", "@", "@", "$"].as_ref());
        let s12 = Column::new("G12".into(), ["-", "_", ";", ";", ","].as_ref());

        let df =
            DataFrame::new(vec![s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]).unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let adf = df
            .group_by([
                "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12",
            ])
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(1), Some(2), Some(2), Some(6)]
        );
    }

    /// Same as the 12-column test but with keys built programmatically,
    /// pushing past any fixed-arity key handling.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_dynamic_group_by_by_13_columns() {
        // The content for every group_by series.
        let series_content = ["A", "A", "B", "B", "C"];

        // The name of every group_by series.
        let series_names = [
            "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12", "G13",
        ];

        // Vector to contain every series.
        let mut columns = Vec::with_capacity(14);

        // Create a series for every group name.
        for series_name in series_names {
            let group_columns = Column::new(series_name.into(), series_content.as_ref());
            columns.push(group_columns);
        }

        // Create a series for the aggregation column.
        let agg_series = Column::new("N".into(), [1, 2, 3, 3, 4].as_ref());
        columns.push(agg_series);

        // Create the dataframe with the computed series.
        let df = DataFrame::new(columns).unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        // Compute the aggregated DataFrame by the 13 columns defined in `series_names`.
        let adf = df
            .group_by(series_names)
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        // Check that the results of the group-by are correct. The content of every column
        // is equal, then, the grouped columns shall be equal and in the same order.
        for series_name in &series_names {
            assert_eq!(
                Vec::from(&adf.column(series_name).unwrap().str().unwrap().sort(false)),
                &[Some("A"), Some("B"), Some("C")]
            );
        }

        // Check the aggregated column is the expected one.
        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(3), Some(4), Some(6)]
        );
    }

    /// Floats as group keys: equal float values must land in the same group.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_floats() {
        let df = df! {"flt" => [1., 1., 2., 2., 3.],
                    "val" => [1, 1, 1, 1, 1]
        }
        .unwrap();
        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let res = df.group_by(["flt"]).unwrap().sum().unwrap();
        let res = res.sort(["flt"], SortMultipleOptions::default()).unwrap();
        assert_eq!(
            Vec::from(res.column("val_sum").unwrap().i32().unwrap()),
            &[Some(2), Some(2), Some(1)]
        );
    }

    /// Group by a categorical key combined with a string key.
    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_categorical() {
        let mut df = df! {"foo" => ["a", "a", "b", "b", "c"],
                    "ham" => ["a", "a", "b", "b", "c"],
                    "bar" => [1, 1, 1, 1, 1]
        }
        .unwrap();

        // Cast "foo" to a (global) categorical dtype before grouping.
        df.apply("foo", |s| {
            s.cast(&DataType::from_categories(Categories::global()))
                .unwrap()
        })
        .unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        // check multiple keys and categorical
        let res = df
            .group_by_stable(["foo", "ham"])
            .unwrap()
            .select(["bar"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(
                res.column("bar_sum")
                    .unwrap()
                    .as_materialized_series()
                    .i32()
                    .unwrap()
            ),
            &[Some(2), Some(2), Some(1)]
        );
    }

    /// Nulls in an aggregated column must be skipped by `mean`.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_null_handling() -> PolarsResult<()> {
        let df = df!(
            "a" => ["a", "a", "a", "b", "b"],
            "b" => [Some(1), Some(2), None, None, Some(1)]
        )?;
        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["a"])?.mean()?;

        assert_eq!(
            Vec::from(out.column("b_mean")?.as_materialized_series().f64()?),
            &[Some(1.5), Some(1.0)]
        );
        Ok(())
    }

    /// Variance/std with ddof = 1, including coercion of ints to f64.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_var() -> PolarsResult<()> {
        // check variance and proper coercion to f64
        let df = df![
            "g" => ["foo", "foo", "bar"],
            "flt" => [1.0, 2.0, 3.0],
            "int" => [1, 2, 3]
        ]?;

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).var(1)?;

        assert_eq!(out.column("int_agg_var")?.f64()?.get(0), Some(0.5));
        // Use of deprecated `std()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).std(1)?;
        let val = out.column("int_agg_std")?.f64()?.get(0).unwrap();
        let expected = f64::FRAC_1_SQRT_2();
        assert!((val - expected).abs() < 0.000001);
        Ok(())
    }

    /// Null keys must form their own group rather than being dropped.
    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_null_group() -> PolarsResult<()> {
        // check if null is own group
        let mut df = df![
            "g" => [Some("foo"), Some("foo"), Some("bar"), None, None],
            "flt" => [1.0, 2.0, 3.0, 1.0, 1.0],
            "int" => [1, 2, 3, 1, 1]
        ]?;

        df.try_apply("g", |s| {
            s.cast(&DataType::from_categories(Categories::global()))
        })?;

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        // The test only asserts that aggregation over a null group succeeds.
        let _ = df.group_by(["g"])?.sum()?;
        Ok(())
    }
}