Skip to main content

polars_core/frame/group_by/
mod.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3
4use num_traits::NumCast;
5use polars_compute::rolling::QuantileMethod;
6use polars_utils::format_pl_smallstr;
7use polars_utils::hashing::DirtyHash;
8use rayon::prelude::*;
9
10use self::hashing::*;
11use crate::POOL;
12use crate::prelude::*;
13use crate::utils::{_set_partition_size, accumulate_dataframes_vertical};
14
15pub mod aggregations;
16pub(crate) mod hashing;
17mod into_groups;
18mod position;
19
20pub use into_groups::*;
21pub use position::*;
22
23use crate::chunked_array::ops::row_encode::{
24    encode_rows_unordered, encode_rows_vertical_par_unordered,
25};
26
impl DataFrame {
    /// Group this `DataFrame` by pre-selected key columns.
    ///
    /// Unit-length keys are broadcast to the frame height; any other length
    /// mismatch is an error. Returns a [`GroupBy`] handle holding the computed
    /// group positions.
    pub fn group_by_with_series(
        &self,
        mut by: Vec<Column>,
        multithreaded: bool,
        sorted: bool,
    ) -> PolarsResult<GroupBy<'_>> {
        polars_ensure!(
            !by.is_empty(),
            ComputeError: "at least one key is required in a group_by operation"
        );

        // Ensure all 'by' columns have the same common_height
        // The condition self.width > 0 ensures we can still call this on a
        // dummy dataframe where we provide the keys
        let common_height = if self.width() > 0 {
            self.height()
        } else {
            by.iter().map(|s| s.len()).max().expect("at least 1 key")
        };
        for by_key in by.iter_mut() {
            if by_key.len() != common_height {
                polars_ensure!(
                    by_key.len() == 1,
                    ShapeMismatch: "series used as keys should have the same length as the DataFrame"
                );
                // Broadcast a unit-length key to the common height.
                *by_key = by_key.new_from_index(0, common_height)
            }
        }

        let groups = if by.len() == 1 {
            // Single key: group directly on the materialized series.
            let column = &by[0];
            column
                .as_materialized_series()
                .group_tuples(multithreaded, sorted)
        } else if by.iter().any(|s| s.dtype().is_object()) {
            // Object dtypes cannot be row-encoded; fall back to row-wise
            // AnyValue grouping (only available with the "object" feature).
            #[cfg(feature = "object")]
            {
                let mut df = DataFrame::new(self.height(), by.clone()).unwrap();
                let n = df.height();
                let rows = df.to_av_rows();
                let iter = (0..n).map(|i| rows.get(i));
                Ok(group_by(iter, sorted))
            }
            #[cfg(not(feature = "object"))]
            {
                unreachable!()
            }
        } else {
            // Skip null dtype.
            let by = by
                .iter()
                .filter(|s| !s.dtype().is_null())
                .cloned()
                .collect::<Vec<_>>();
            if by.is_empty() {
                // All keys were null-typed: one group spanning every row
                // (or no groups at all for an empty frame).
                let groups = if self.height() == 0 {
                    vec![]
                } else {
                    vec![[0, self.height() as IdxSize]]
                };

                Ok(GroupsType::new_slice(groups, false, true))
            } else {
                // Multiple keys: row-encode them into a single series so the
                // combination can be hashed/grouped as one value.
                let rows = if multithreaded {
                    encode_rows_vertical_par_unordered(&by)
                } else {
                    encode_rows_unordered(&by)
                }?
                .into_series();
                rows.group_tuples(multithreaded, sorted)
            }
        };
        Ok(GroupBy::new(self, by, groups?.into_sliceable(), None))
    }

    /// Group DataFrame using a Series column.
    ///
    /// # Example
    ///
    /// ```
    /// use polars_core::prelude::*;
    /// fn group_by_sum(df: &DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["column_name"])?
    ///     .select(["agg_column_name"])
    ///     .sum()
    /// }
    /// ```
    pub fn group_by<I, S>(&self, by: I) -> PolarsResult<GroupBy<'_>>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let selected_keys = self.select_to_vec(by)?;
        self.group_by_with_series(selected_keys, true, false)
    }

    /// Group DataFrame using a Series column.
    /// The groups are ordered by their smallest row index.
    pub fn group_by_stable<I, S>(&self, by: I) -> PolarsResult<GroupBy<'_>>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let selected_keys = self.select_to_vec(by)?;
        self.group_by_with_series(selected_keys, true, true)
    }
}
135
136/// Returned by a group_by operation on a DataFrame. This struct supports
137/// several aggregations.
138///
139/// Until described otherwise, the examples in this struct are performed on the following DataFrame:
140///
141/// ```ignore
142/// use polars_core::prelude::*;
143///
144/// let dates = &[
145/// "2020-08-21",
146/// "2020-08-21",
147/// "2020-08-22",
148/// "2020-08-23",
149/// "2020-08-22",
150/// ];
151/// // date format
152/// let fmt = "%Y-%m-%d";
153/// // create date series
154/// let s0 = DateChunked::parse_from_str_slice("date", dates, fmt)
155///         .into_series();
156/// // create temperature series
157/// let s1 = Series::new("temp".into(), [20, 10, 7, 9, 1]);
158/// // create rain series
159/// let s2 = Series::new("rain".into(), [0.2, 0.1, 0.3, 0.1, 0.01]);
160/// // create a new DataFrame
161/// let df = DataFrame::new_infer_height(vec![s0, s1, s2]).unwrap();
162/// println!("{:?}", df);
163/// ```
164///
165/// Outputs:
166///
167/// ```text
168/// +------------+------+------+
169/// | date       | temp | rain |
170/// | ---        | ---  | ---  |
171/// | Date       | i32  | f64  |
172/// +============+======+======+
173/// | 2020-08-21 | 20   | 0.2  |
174/// +------------+------+------+
175/// | 2020-08-21 | 10   | 0.1  |
176/// +------------+------+------+
177/// | 2020-08-22 | 7    | 0.3  |
178/// +------------+------+------+
179/// | 2020-08-23 | 9    | 0.1  |
180/// +------------+------+------+
181/// | 2020-08-22 | 1    | 0.01 |
182/// +------------+------+------+
183/// ```
184///
#[derive(Debug, Clone)]
pub struct GroupBy<'a> {
    /// The frame the group positions index into.
    pub df: &'a DataFrame,
    /// Materialized (and, where needed, broadcast) key columns.
    pub(crate) selected_keys: Vec<Column>,
    // [first idx, [other idx]]
    groups: GroupPositions,
    // columns selected for aggregation
    pub(crate) selected_agg: Option<Vec<PlSmallStr>>,
}
194
195impl<'a> GroupBy<'a> {
196    pub fn new(
197        df: &'a DataFrame,
198        by: Vec<Column>,
199        groups: GroupPositions,
200        selected_agg: Option<Vec<PlSmallStr>>,
201    ) -> Self {
202        GroupBy {
203            df,
204            selected_keys: by,
205            groups,
206            selected_agg,
207        }
208    }
209
210    /// Select the column(s) that should be aggregated.
211    /// You can select a single column or a slice of columns.
212    ///
213    /// Note that making a selection with this method is not required. If you
214    /// skip it all columns (except for the keys) will be selected for aggregation.
215    #[must_use]
216    pub fn select<I: IntoIterator<Item = S>, S: Into<PlSmallStr>>(mut self, selection: I) -> Self {
217        self.selected_agg = Some(selection.into_iter().map(|s| s.into()).collect());
218        self
219    }
220
    /// Get the internal representation of the GroupBy operation.
    /// The Vec returned contains:
    ///     (first_idx, [`Vec<indexes>`])
    ///     Where second value in the tuple is a vector with all matching indexes.
    ///
    /// See also [`GroupBy::into_groups`] for an owning variant.
    pub fn get_groups(&self) -> &GroupPositions {
        &self.groups
    }
228
    /// Get the internal representation of the GroupBy operation.
    /// The Vec returned contains:
    ///     (first_idx, [`Vec<indexes>`])
    ///     Where second value in the tuple is a vector with all matching indexes.
    ///
    /// # Safety
    /// Groups should always be in bounds of the `DataFrame` hold by this [`GroupBy`].
    /// If you mutate it, you must hold that invariant.
    pub unsafe fn get_groups_mut(&mut self) -> &mut GroupPositions {
        &mut self.groups
    }
240
    /// Consume the `GroupBy` and return the underlying group positions.
    pub fn into_groups(self) -> GroupPositions {
        self.groups
    }
244
    /// Reduce the key columns to one row per group (the group's first row),
    /// optionally restricted to a `(offset, len)` slice of the groups.
    pub fn keys_sliced(&self, slice: Option<(i64, usize)>) -> Vec<Column> {
        #[allow(unused_assignments)]
        // needed to keep the lifetimes valid for this scope
        let mut groups_owned = None;

        let groups = if let Some((offset, len)) = slice {
            groups_owned = Some(self.groups.slice(offset, len));
            groups_owned.as_deref().unwrap()
        } else {
            &self.groups
        };
        // Gather the first row of every group from each key column, in parallel
        // over the key columns.
        POOL.install(|| {
            self.selected_keys
                .par_iter()
                .map(Column::as_materialized_series)
                .map(|s| {
                    match groups {
                        GroupsType::Idx(groups) => {
                            // SAFETY: groups are always in bounds.
                            let mut out = unsafe { s.take_slice_unchecked(groups.first()) };
                            if groups.sorted {
                                out.set_sorted_flag(s.is_sorted_flag());
                            };
                            out
                        },
                        GroupsType::Slice {
                            groups,
                            overlapping,
                            monotonic: _,
                        } => {
                            if *overlapping && !groups.is_empty() {
                                // Groups can be sliced.
                                let offset = groups[0][0];
                                let [upper_offset, upper_len] = groups[groups.len() - 1];
                                return s.slice(
                                    offset as i64,
                                    ((upper_offset + upper_len) - offset) as usize,
                                );
                            }

                            // Non-overlapping slice groups: take each group's
                            // first row index.
                            let indices = groups
                                .iter()
                                .map(|&[first, _len]| first)
                                .collect_ca(PlSmallStr::EMPTY);
                            // SAFETY: groups are always in bounds.
                            let mut out = unsafe { s.take_unchecked(&indices) };
                            // Sliced groups are always in order of discovery.
                            out.set_sorted_flag(s.is_sorted_flag());
                            out
                        },
                    }
                })
                .map(Column::from)
                .collect()
        })
    }
301
    /// The key columns reduced to one row per group (no slicing).
    pub fn keys(&self) -> Vec<Column> {
        self.keys_sliced(None)
    }
305
306    fn prepare_agg(&self) -> PolarsResult<(Vec<Column>, Vec<Column>)> {
307        let keys = self.keys();
308
309        let agg_col = match &self.selected_agg {
310            Some(selection) => self.df.select_to_vec(selection),
311            None => {
312                let by: Vec<_> = self.selected_keys.iter().map(|s| s.name()).collect();
313                let selection = self
314                    .df
315                    .columns()
316                    .iter()
317                    .map(|s| s.name())
318                    .filter(|a| !by.contains(a))
319                    .cloned()
320                    .collect::<Vec<_>>();
321
322                self.df.select_to_vec(selection.as_slice())
323            },
324        }?;
325
326        Ok((keys, agg_col))
327    }
328
329    /// Aggregate grouped series and compute the mean per group.
330    ///
331    /// # Example
332    ///
333    /// ```rust
334    /// # use polars_core::prelude::*;
335    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
336    ///     df.group_by(["date"])?.select(["temp", "rain"]).mean()
337    /// }
338    /// ```
339    /// Returns:
340    ///
341    /// ```text
342    /// +------------+-----------+-----------+
343    /// | date       | temp_mean | rain_mean |
344    /// | ---        | ---       | ---       |
345    /// | Date       | f64       | f64       |
346    /// +============+===========+===========+
347    /// | 2020-08-23 | 9         | 0.1       |
348    /// +------------+-----------+-----------+
349    /// | 2020-08-22 | 4         | 0.155     |
350    /// +------------+-----------+-----------+
351    /// | 2020-08-21 | 15        | 0.15      |
352    /// +------------+-----------+-----------+
353    /// ```
354    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
355    pub fn mean(&self) -> PolarsResult<DataFrame> {
356        let (mut cols, agg_cols) = self.prepare_agg()?;
357
358        for agg_col in agg_cols {
359            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Mean);
360            let mut agg = unsafe { agg_col.agg_mean(&self.groups) };
361            agg.rename(new_name);
362            cols.push(agg);
363        }
364
365        DataFrame::new_infer_height(cols)
366    }
367
368    /// Aggregate grouped series and compute the sum per group.
369    ///
370    /// # Example
371    ///
372    /// ```rust
373    /// # use polars_core::prelude::*;
374    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
375    ///     df.group_by(["date"])?.select(["temp"]).sum()
376    /// }
377    /// ```
378    /// Returns:
379    ///
380    /// ```text
381    /// +------------+----------+
382    /// | date       | temp_sum |
383    /// | ---        | ---      |
384    /// | Date       | i32      |
385    /// +============+==========+
386    /// | 2020-08-23 | 9        |
387    /// +------------+----------+
388    /// | 2020-08-22 | 8        |
389    /// +------------+----------+
390    /// | 2020-08-21 | 30       |
391    /// +------------+----------+
392    /// ```
393    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
394    pub fn sum(&self) -> PolarsResult<DataFrame> {
395        let (mut cols, agg_cols) = self.prepare_agg()?;
396
397        for agg_col in agg_cols {
398            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Sum);
399            let mut agg = unsafe { agg_col.agg_sum(&self.groups) };
400            agg.rename(new_name);
401            cols.push(agg);
402        }
403        DataFrame::new_infer_height(cols)
404    }
405
406    /// Aggregate grouped series and compute the minimal value per group.
407    ///
408    /// # Example
409    ///
410    /// ```rust
411    /// # use polars_core::prelude::*;
412    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
413    ///     df.group_by(["date"])?.select(["temp"]).min()
414    /// }
415    /// ```
416    /// Returns:
417    ///
418    /// ```text
419    /// +------------+----------+
420    /// | date       | temp_min |
421    /// | ---        | ---      |
422    /// | Date       | i32      |
423    /// +============+==========+
424    /// | 2020-08-23 | 9        |
425    /// +------------+----------+
426    /// | 2020-08-22 | 1        |
427    /// +------------+----------+
428    /// | 2020-08-21 | 10       |
429    /// +------------+----------+
430    /// ```
431    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
432    pub fn min(&self) -> PolarsResult<DataFrame> {
433        let (mut cols, agg_cols) = self.prepare_agg()?;
434        for agg_col in agg_cols {
435            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Min);
436            let mut agg = unsafe { agg_col.agg_min(&self.groups) };
437            agg.rename(new_name);
438            cols.push(agg);
439        }
440        DataFrame::new_infer_height(cols)
441    }
442
443    /// Aggregate grouped series and compute the maximum value per group.
444    ///
445    /// # Example
446    ///
447    /// ```rust
448    /// # use polars_core::prelude::*;
449    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
450    ///     df.group_by(["date"])?.select(["temp"]).max()
451    /// }
452    /// ```
453    /// Returns:
454    ///
455    /// ```text
456    /// +------------+----------+
457    /// | date       | temp_max |
458    /// | ---        | ---      |
459    /// | Date       | i32      |
460    /// +============+==========+
461    /// | 2020-08-23 | 9        |
462    /// +------------+----------+
463    /// | 2020-08-22 | 7        |
464    /// +------------+----------+
465    /// | 2020-08-21 | 20       |
466    /// +------------+----------+
467    /// ```
468    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
469    pub fn max(&self) -> PolarsResult<DataFrame> {
470        let (mut cols, agg_cols) = self.prepare_agg()?;
471        for agg_col in agg_cols {
472            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Max);
473            let mut agg = unsafe { agg_col.agg_max(&self.groups) };
474            agg.rename(new_name);
475            cols.push(agg);
476        }
477        DataFrame::new_infer_height(cols)
478    }
479
480    /// Aggregate grouped `Series` and find the first value per group.
481    ///
482    /// # Example
483    ///
484    /// ```rust
485    /// # use polars_core::prelude::*;
486    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
487    ///     df.group_by(["date"])?.select(["temp"]).first()
488    /// }
489    /// ```
490    /// Returns:
491    ///
492    /// ```text
493    /// +------------+------------+
494    /// | date       | temp_first |
495    /// | ---        | ---        |
496    /// | Date       | i32        |
497    /// +============+============+
498    /// | 2020-08-23 | 9          |
499    /// +------------+------------+
500    /// | 2020-08-22 | 7          |
501    /// +------------+------------+
502    /// | 2020-08-21 | 20         |
503    /// +------------+------------+
504    /// ```
505    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
506    pub fn first(&self) -> PolarsResult<DataFrame> {
507        let (mut cols, agg_cols) = self.prepare_agg()?;
508        for agg_col in agg_cols {
509            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::First);
510            let mut agg = unsafe { agg_col.agg_first(&self.groups) };
511            agg.rename(new_name);
512            cols.push(agg);
513        }
514        DataFrame::new_infer_height(cols)
515    }
516
517    /// Aggregate grouped `Series` and return the last value per group.
518    ///
519    /// # Example
520    ///
521    /// ```rust
522    /// # use polars_core::prelude::*;
523    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
524    ///     df.group_by(["date"])?.select(["temp"]).last()
525    /// }
526    /// ```
527    /// Returns:
528    ///
529    /// ```text
530    /// +------------+------------+
531    /// | date       | temp_last |
532    /// | ---        | ---        |
533    /// | Date       | i32        |
534    /// +============+============+
535    /// | 2020-08-23 | 9          |
536    /// +------------+------------+
537    /// | 2020-08-22 | 1          |
538    /// +------------+------------+
539    /// | 2020-08-21 | 10         |
540    /// +------------+------------+
541    /// ```
542    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
543    pub fn last(&self) -> PolarsResult<DataFrame> {
544        let (mut cols, agg_cols) = self.prepare_agg()?;
545        for agg_col in agg_cols {
546            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Last);
547            let mut agg = unsafe { agg_col.agg_last(&self.groups) };
548            agg.rename(new_name);
549            cols.push(agg);
550        }
551        DataFrame::new_infer_height(cols)
552    }
553
554    /// Aggregate grouped `Series` by counting the number of unique values.
555    ///
556    /// # Example
557    ///
558    /// ```rust
559    /// # use polars_core::prelude::*;
560    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
561    ///     df.group_by(["date"])?.select(["temp"]).n_unique()
562    /// }
563    /// ```
564    /// Returns:
565    ///
566    /// ```text
567    /// +------------+---------------+
568    /// | date       | temp_n_unique |
569    /// | ---        | ---           |
570    /// | Date       | u32           |
571    /// +============+===============+
572    /// | 2020-08-23 | 1             |
573    /// +------------+---------------+
574    /// | 2020-08-22 | 2             |
575    /// +------------+---------------+
576    /// | 2020-08-21 | 2             |
577    /// +------------+---------------+
578    /// ```
579    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
580    pub fn n_unique(&self) -> PolarsResult<DataFrame> {
581        let (mut cols, agg_cols) = self.prepare_agg()?;
582        for agg_col in agg_cols {
583            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::NUnique);
584            let mut agg = unsafe { agg_col.agg_n_unique(&self.groups) };
585            agg.rename(new_name);
586            cols.push(agg);
587        }
588        DataFrame::new_infer_height(cols)
589    }
590
591    /// Aggregate grouped [`Series`] and determine the quantile per group.
592    ///
593    /// # Example
594    ///
595    /// ```rust
596    /// # use polars_core::prelude::*;
597    ///
598    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
599    ///     df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileMethod::default())
600    /// }
601    /// ```
602    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
603    pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<DataFrame> {
604        polars_ensure!(
605            (0.0..=1.0).contains(&quantile),
606            ComputeError: "`quantile` should be within 0.0 and 1.0"
607        );
608        let (mut cols, agg_cols) = self.prepare_agg()?;
609        for agg_col in agg_cols {
610            let new_name = fmt_group_by_column(
611                agg_col.name().as_str(),
612                GroupByMethod::Quantile(quantile, method),
613            );
614            let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) };
615            agg.rename(new_name);
616            cols.push(agg);
617        }
618        DataFrame::new_infer_height(cols)
619    }
620
621    /// Aggregate grouped [`Series`] and determine the median per group.
622    ///
623    /// # Example
624    ///
625    /// ```rust
626    /// # use polars_core::prelude::*;
627    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
628    ///     df.group_by(["date"])?.select(["temp"]).median()
629    /// }
630    /// ```
631    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
632    pub fn median(&self) -> PolarsResult<DataFrame> {
633        let (mut cols, agg_cols) = self.prepare_agg()?;
634        for agg_col in agg_cols {
635            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Median);
636            let mut agg = unsafe { agg_col.agg_median(&self.groups) };
637            agg.rename(new_name);
638            cols.push(agg);
639        }
640        DataFrame::new_infer_height(cols)
641    }
642
643    /// Aggregate grouped [`Series`] and determine the variance per group.
644    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
645    pub fn var(&self, ddof: u8) -> PolarsResult<DataFrame> {
646        let (mut cols, agg_cols) = self.prepare_agg()?;
647        for agg_col in agg_cols {
648            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Var(ddof));
649            let mut agg = unsafe { agg_col.agg_var(&self.groups, ddof) };
650            agg.rename(new_name);
651            cols.push(agg);
652        }
653        DataFrame::new_infer_height(cols)
654    }
655
656    /// Aggregate grouped [`Series`] and determine the standard deviation per group.
657    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
658    pub fn std(&self, ddof: u8) -> PolarsResult<DataFrame> {
659        let (mut cols, agg_cols) = self.prepare_agg()?;
660        for agg_col in agg_cols {
661            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Std(ddof));
662            let mut agg = unsafe { agg_col.agg_std(&self.groups, ddof) };
663            agg.rename(new_name);
664            cols.push(agg);
665        }
666        DataFrame::new_infer_height(cols)
667    }
668
669    /// Aggregate grouped series and compute the number of values per group.
670    ///
671    /// # Example
672    ///
673    /// ```rust
674    /// # use polars_core::prelude::*;
675    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
676    ///     df.group_by(["date"])?.select(["temp"]).count()
677    /// }
678    /// ```
679    /// Returns:
680    ///
681    /// ```text
682    /// +------------+------------+
683    /// | date       | temp_count |
684    /// | ---        | ---        |
685    /// | Date       | u32        |
686    /// +============+============+
687    /// | 2020-08-23 | 1          |
688    /// +------------+------------+
689    /// | 2020-08-22 | 2          |
690    /// +------------+------------+
691    /// | 2020-08-21 | 2          |
692    /// +------------+------------+
693    /// ```
694    pub fn count(&self) -> PolarsResult<DataFrame> {
695        let (mut cols, agg_cols) = self.prepare_agg()?;
696
697        for agg_col in agg_cols {
698            let new_name = fmt_group_by_column(
699                agg_col.name().as_str(),
700                GroupByMethod::Count {
701                    include_nulls: true,
702                },
703            );
704            let mut ca = self.groups.group_count();
705            ca.rename(new_name);
706            cols.push(ca.into_column());
707        }
708        DataFrame::new_infer_height(cols)
709    }
710
711    /// Get the group_by group indexes.
712    ///
713    /// # Example
714    ///
715    /// ```rust
716    /// # use polars_core::prelude::*;
717    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
718    ///     df.group_by(["date"])?.groups()
719    /// }
720    /// ```
721    /// Returns:
722    ///
723    /// ```text
724    /// +--------------+------------+
725    /// | date         | groups     |
726    /// | ---          | ---        |
727    /// | Date(days)   | list [u32] |
728    /// +==============+============+
729    /// | 2020-08-23   | "[3]"      |
730    /// +--------------+------------+
731    /// | 2020-08-22   | "[2, 4]"   |
732    /// +--------------+------------+
733    /// | 2020-08-21   | "[0, 1]"   |
734    /// +--------------+------------+
735    /// ```
736    pub fn groups(&self) -> PolarsResult<DataFrame> {
737        let mut cols = self.keys();
738        let mut column = self.groups.as_list_chunked();
739        let new_name = fmt_group_by_column("", GroupByMethod::Groups);
740        column.rename(new_name);
741        cols.push(column.into_column());
742        DataFrame::new_infer_height(cols)
743    }
744
745    /// Aggregate the groups of the group_by operation into lists.
746    ///
747    /// # Example
748    ///
749    /// ```rust
750    /// # use polars_core::prelude::*;
751    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
752    ///     // GroupBy and aggregate to Lists
753    ///     df.group_by(["date"])?.select(["temp"]).agg_list()
754    /// }
755    /// ```
756    /// Returns:
757    ///
758    /// ```text
759    /// +------------+------------------------+
760    /// | date       | temp_agg_list          |
761    /// | ---        | ---                    |
762    /// | Date       | list [i32]             |
763    /// +============+========================+
764    /// | 2020-08-23 | "[Some(9)]"            |
765    /// +------------+------------------------+
766    /// | 2020-08-22 | "[Some(7), Some(1)]"   |
767    /// +------------+------------------------+
768    /// | 2020-08-21 | "[Some(20), Some(10)]" |
769    /// +------------+------------------------+
770    /// ```
771    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
772    pub fn agg_list(&self) -> PolarsResult<DataFrame> {
773        let (mut cols, agg_cols) = self.prepare_agg()?;
774        for agg_col in agg_cols {
775            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Implode);
776            let mut agg = unsafe { agg_col.agg_list(&self.groups) };
777            agg.rename(new_name);
778            cols.push(agg);
779        }
780        DataFrame::new_infer_height(cols)
781    }
782
    /// Build the frame that `apply`/`par_apply` closures receive: key columns
    /// followed by the selected aggregation columns, or a clone of the whole
    /// frame when there is no (non-empty) selection.
    fn prepare_apply(&self) -> PolarsResult<DataFrame> {
        if let Some(agg) = &self.selected_agg {
            if agg.is_empty() {
                Ok(self.df.clone())
            } else {
                let mut new_cols = Vec::with_capacity(self.selected_keys.len() + agg.len());
                new_cols.extend_from_slice(&self.selected_keys);
                let cols = self.df.select_to_vec(agg.as_slice())?;
                new_cols.extend(cols);
                // SAFETY: keys were broadcast to the frame height at
                // construction and the selected columns come from `df` itself,
                // so all columns share the same height.
                Ok(unsafe { DataFrame::new_unchecked(self.df.height(), new_cols) })
            }
        } else {
            Ok(self.df.clone())
        }
    }
798
    /// Apply a closure over the groups as a new [`DataFrame`] in parallel.
    ///
    /// # Errors
    /// Fails on an empty `DataFrame`, or when `f` fails for any group.
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn par_apply<F>(&self, f: F) -> PolarsResult<DataFrame>
    where
        F: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
    {
        polars_ensure!(self.df.height() > 0, ComputeError: "cannot group_by + apply on empty 'DataFrame'");
        let df = self.prepare_apply()?;
        let dfs = self
            .get_groups()
            .par_iter()
            .map(|g| {
                // SAFETY:
                // groups are in bounds
                let sub_df = unsafe { take_df(&df, g) };
                f(sub_df)
            })
            .collect::<PolarsResult<Vec<_>>>()?;

        // Concatenate the per-group results and compact into contiguous chunks.
        let mut df = accumulate_dataframes_vertical(dfs)?;
        df.rechunk_mut_par();
        Ok(df)
    }
822
    /// Apply a closure over the groups as a new [`DataFrame`].
    ///
    /// Convenience wrapper around [`GroupBy::apply_sliced`] with no slice and
    /// no schema.
    pub fn apply<F>(&self, f: F) -> PolarsResult<DataFrame>
    where
        F: FnMut(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
    {
        self.apply_sliced(None, f, None)
    }
830
    /// Apply `f` over each group, optionally producing only the rows of the
    /// given `(offset, len)` slice of the concatenated result. Group iteration
    /// stops early once enough rows have been produced to satisfy the slice.
    pub fn apply_sliced<F>(
        &self,
        slice: Option<(i64, usize)>,
        mut f: F,
        schema: Option<&SchemaRef>,
    ) -> PolarsResult<DataFrame>
    where
        F: FnMut(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
    {
        if self.df.height() == 0 {
            // return empty dataframe with correct schema
            if let Some(schema) = schema {
                return Ok(DataFrame::empty_with_arc_schema(schema.clone()));
            }

            polars_bail!(ComputeError: "cannot group_by + apply on empty 'DataFrame'");
        }

        let df = self.prepare_apply()?;
        // Upper bound of rows needed; a negative offset cannot be resolved
        // ahead of time, so fall back to usize::MAX (no early stop).
        let max_height = if let Some((offset, len)) = slice {
            offset.try_into().unwrap_or(usize::MAX).saturating_add(len)
        } else {
            usize::MAX
        };
        let mut height = 0;
        let mut dfs = Vec::with_capacity(self.get_groups().len());
        for g in self.get_groups().iter() {
            // SAFETY: groups are in bounds.
            let sub_df = unsafe { take_df(&df, g) };
            let df = f(sub_df)?;
            height += df.height();
            dfs.push(df);

            // Even if max_height is zero we need at least one df, so check
            // after first push.
            if height >= max_height {
                break;
            }
        }

        let mut df = accumulate_dataframes_vertical(dfs)?;
        if let Some((offset, len)) = slice {
            df = df.slice(offset, len);
        }
        Ok(df)
    }
877
878    pub fn sliced(mut self, slice: Option<(i64, usize)>) -> Self {
879        match slice {
880            None => self,
881            Some((offset, length)) => {
882                self.groups = self.groups.slice(offset, length);
883                self.selected_keys = self.keys_sliced(slice);
884                self
885            },
886        }
887    }
888}
889
/// Materialize the rows of group `g` from `df` as a new [`DataFrame`].
///
/// # Safety
/// For the `Idx` variant, every index in the group must be in bounds for
/// `df` — `take_slice_unchecked` performs no bounds checks. The `Slice`
/// variant goes through the safe `DataFrame::slice` and needs no extra
/// invariant from the caller.
unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame {
    match g {
        GroupsIndicator::Idx(idx) => df.take_slice_unchecked(idx.1),
        GroupsIndicator::Slice([first, len]) => df.slice(first as i64, len as usize),
    }
}
896
/// The aggregation applied to each group in a group-by reduction.
///
/// Also used (via `fmt_group_by_column`) to derive the name of the
/// resulting output column.
#[derive(Copy, Clone, Debug)]
pub enum GroupByMethod {
    Min,
    NanMin,
    Max,
    NanMax,
    Median,
    Mean,
    First,
    FirstNonNull,
    Last,
    LastNonNull,
    // NOTE(review): single-item-per-group aggregation; `allow_empty`
    // presumably permits empty groups — confirm in aggregations.rs.
    Item { allow_empty: bool },
    Sum,
    /// Materialize the group indices themselves (output named "groups").
    Groups,
    NUnique,
    /// Quantile with the given probability and interpolation method.
    Quantile(f64, QuantileMethod),
    /// Group length; `include_nulls` controls whether nulls are counted.
    Count { include_nulls: bool },
    /// Collect each group into a list (displayed/named as "list"/"agg_list").
    Implode,
    /// Standard deviation; the `u8` is the ddof correction — TODO confirm.
    Std(u8),
    /// Variance; the `u8` is the ddof correction — TODO confirm.
    Var(u8),
    ArgMin,
    ArgMax,
}
921
922impl Display for GroupByMethod {
923    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
924        use GroupByMethod::*;
925        let s = match self {
926            Min => "min",
927            NanMin => "nan_min",
928            Max => "max",
929            NanMax => "nan_max",
930            Median => "median",
931            Mean => "mean",
932            First => "first",
933            FirstNonNull => "first_non_null",
934            Last => "last",
935            LastNonNull => "last_non_null",
936            Item { .. } => "item",
937            Sum => "sum",
938            Groups => "groups",
939            NUnique => "n_unique",
940            Quantile(_, _) => "quantile",
941            Count { .. } => "count",
942            Implode => "list",
943            Std(_) => "std",
944            Var(_) => "var",
945            ArgMin => "arg_min",
946            ArgMax => "arg_max",
947        };
948        write!(f, "{s}")
949    }
950}
951
952// Formatting functions used in eager and lazy code for renaming grouped columns
953pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr {
954    use GroupByMethod::*;
955    match method {
956        Min => format_pl_smallstr!("{name}_min"),
957        Max => format_pl_smallstr!("{name}_max"),
958        NanMin => format_pl_smallstr!("{name}_nan_min"),
959        NanMax => format_pl_smallstr!("{name}_nan_max"),
960        Median => format_pl_smallstr!("{name}_median"),
961        Mean => format_pl_smallstr!("{name}_mean"),
962        First => format_pl_smallstr!("{name}_first"),
963        FirstNonNull => format_pl_smallstr!("{name}_first_non_null"),
964        Last => format_pl_smallstr!("{name}_last"),
965        LastNonNull => format_pl_smallstr!("{name}_last_non_null"),
966        Item { .. } => format_pl_smallstr!("{name}_item"),
967        Sum => format_pl_smallstr!("{name}_sum"),
968        Groups => PlSmallStr::from_static("groups"),
969        NUnique => format_pl_smallstr!("{name}_n_unique"),
970        Count { .. } => format_pl_smallstr!("{name}_count"),
971        Implode => format_pl_smallstr!("{name}_agg_list"),
972        Quantile(quantile, _interpol) => format_pl_smallstr!("{name}_quantile_{quantile:.2}"),
973        Std(_) => format_pl_smallstr!("{name}_agg_std"),
974        Var(_) => format_pl_smallstr!("{name}_agg_var"),
975        ArgMin => format_pl_smallstr!("{name}_arg_min"),
976        ArgMax => format_pl_smallstr!("{name}_arg_max"),
977    }
978}
979
#[cfg(test)]
mod test {
    //! Unit tests for group-by: key handling, aggregations, categorical keys
    //! and null-group behaviour.
    use num_traits::FloatConst;

    use crate::prelude::*;

    #[test]
    #[cfg(feature = "dtype-date")]
    #[cfg_attr(miri, ignore)]
    fn test_group_by() -> PolarsResult<()> {
        let s0 = Column::new(
            PlSmallStr::from_static("date"),
            &[
                "2020-08-21",
                "2020-08-21",
                "2020-08-22",
                "2020-08-23",
                "2020-08-22",
            ],
        );
        let s1 = Column::new(PlSmallStr::from_static("temp"), [20, 10, 7, 9, 1]);
        let s2 = Column::new(PlSmallStr::from_static("rain"), [0.2, 0.1, 0.3, 0.1, 0.01]);
        let df = DataFrame::new_infer_height(vec![s0, s1, s2]).unwrap();

        let out = df.group_by_stable(["date"])?.select(["temp"]).count()?;
        assert_eq!(
            out.column("temp_count")?,
            &Column::new(PlSmallStr::from_static("temp_count"), [2 as IdxSize, 2, 1])
        );

        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        // Select multiple
        let out = df
            .group_by_stable(["date"])?
            .select(["temp", "rain"])
            .mean()?;
        assert_eq!(
            out.column("temp_mean")?,
            &Column::new(PlSmallStr::from_static("temp_mean"), [15.0f64, 4.0, 9.0])
        );

        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        // Group by multiple
        let out = df
            .group_by_stable(["date", "temp"])?
            .select(["rain"])
            .mean()?;
        assert!(out.column("rain_mean").is_ok());

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["date"])?.select(["temp"]).sum()?;
        assert_eq!(
            out.column("temp_sum")?,
            &Column::new(PlSmallStr::from_static("temp_sum"), [30, 8, 9])
        );

        // Use of deprecated `n_unique()` for testing purposes
        #[allow(deprecated)]
        // implicit select all and only aggregate on methods that support that aggregation
        let gb = df.group_by(["date"]).unwrap().n_unique().unwrap();
        // check the group by column is filtered out.
        assert_eq!(gb.width(), 3);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_static_group_by_by_12_columns() {
        // Build GroupBy DataFrame.
        let s0 = Column::new("G1".into(), ["A", "A", "B", "B", "C"].as_ref());
        let s1 = Column::new("N".into(), [1, 2, 2, 4, 2].as_ref());
        let s2 = Column::new("G2".into(), ["k", "l", "m", "m", "l"].as_ref());
        let s3 = Column::new("G3".into(), ["a", "b", "c", "c", "d"].as_ref());
        let s4 = Column::new("G4".into(), ["1", "2", "3", "3", "4"].as_ref());
        let s5 = Column::new("G5".into(), ["X", "Y", "Z", "Z", "W"].as_ref());
        let s6 = Column::new("G6".into(), [false, true, true, true, false].as_ref());
        let s7 = Column::new("G7".into(), ["r", "x", "q", "q", "o"].as_ref());
        let s8 = Column::new("G8".into(), ["R", "X", "Q", "Q", "O"].as_ref());
        let s9 = Column::new("G9".into(), [1, 2, 3, 3, 4].as_ref());
        let s10 = Column::new("G10".into(), [".", "!", "?", "?", "/"].as_ref());
        let s11 = Column::new("G11".into(), ["(", ")", "@", "@", "$"].as_ref());
        let s12 = Column::new("G12".into(), ["-", "_", ";", ";", ","].as_ref());

        let df = DataFrame::new_infer_height(vec![
            s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
        ])
        .unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let adf = df
            .group_by([
                "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12",
            ])
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(1), Some(2), Some(2), Some(6)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_dynamic_group_by_by_13_columns() {
        // The content for every group_by series.
        let series_content = ["A", "A", "B", "B", "C"];

        // The name of every group_by series.
        let series_names = [
            "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12", "G13",
        ];

        // Vector to contain every series.
        let mut columns = Vec::with_capacity(14);

        // Create a series for every group name.
        for series_name in series_names {
            let group_columns = Column::new(series_name.into(), series_content.as_ref());
            columns.push(group_columns);
        }

        // Create a series for the aggregation column.
        let agg_series = Column::new("N".into(), [1, 2, 3, 3, 4].as_ref());
        columns.push(agg_series);

        // Create the dataframe with the computed series.
        let df = DataFrame::new_infer_height(columns).unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        // Compute the aggregated DataFrame by the 13 columns defined in `series_names`.
        let adf = df
            .group_by(series_names)
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        // Check that the results of the group-by are correct. The content of every column
        // is equal, then, the grouped columns shall be equal and in the same order.
        for series_name in &series_names {
            assert_eq!(
                Vec::from(&adf.column(series_name).unwrap().str().unwrap().sort(false)),
                &[Some("A"), Some("B"), Some("C")]
            );
        }

        // Check the aggregated column is the expected one.
        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(3), Some(4), Some(6)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_floats() {
        // Float keys must group correctly (bitwise/total-ordering hashing).
        let df = df! {"flt" => [1., 1., 2., 2., 3.],
                    "val" => [1, 1, 1, 1, 1]
        }
        .unwrap();
        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let res = df.group_by(["flt"]).unwrap().sum().unwrap();
        let res = res.sort(["flt"], SortMultipleOptions::default()).unwrap();
        assert_eq!(
            Vec::from(res.column("val_sum").unwrap().i32().unwrap()),
            &[Some(2), Some(2), Some(1)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_categorical() {
        let mut df = df! {"foo" => ["a", "a", "b", "b", "c"],
                    "ham" => ["a", "a", "b", "b", "c"],
                    "bar" => [1, 1, 1, 1, 1]
        }
        .unwrap();

        // Cast one key to categorical so the mixed (categorical + string)
        // multi-key path is exercised.
        df.apply("foo", |s| {
            s.cast(&DataType::from_categories(Categories::global()))
                .unwrap()
        })
        .unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        // check multiple keys and categorical
        let res = df
            .group_by_stable(["foo", "ham"])
            .unwrap()
            .select(["bar"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(
                res.column("bar_sum")
                    .unwrap()
                    .as_materialized_series()
                    .i32()
                    .unwrap()
            ),
            &[Some(2), Some(2), Some(1)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_null_handling() -> PolarsResult<()> {
        let df = df!(
            "a" => ["a", "a", "a", "b", "b"],
            "b" => [Some(1), Some(2), None, None, Some(1)]
        )?;
        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["a"])?.mean()?;

        // Nulls must be ignored by the mean, not treated as zero.
        assert_eq!(
            Vec::from(out.column("b_mean")?.as_materialized_series().f64()?),
            &[Some(1.5), Some(1.0)]
        );
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_var() -> PolarsResult<()> {
        // check variance and proper coercion to f64
        let df = df![
            "g" => ["foo", "foo", "bar"],
            "flt" => [1.0, 2.0, 3.0],
            "int" => [1, 2, 3]
        ]?;

        // Use of deprecated `var()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).var(1)?;

        assert_eq!(out.column("int_agg_var")?.f64()?.get(0), Some(0.5));
        // Use of deprecated `std()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).std(1)?;
        let val = out.column("int_agg_std")?.f64()?.get(0).unwrap();
        // std of [1, 2] with ddof=1 is 1/sqrt(2).
        let expected = f64::FRAC_1_SQRT_2();
        assert!((val - expected).abs() < 0.000001);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_null_group() -> PolarsResult<()> {
        // check if null is own group
        let mut df = df![
            "g" => [Some("foo"), Some("foo"), Some("bar"), None, None],
            "flt" => [1.0, 2.0, 3.0, 1.0, 1.0],
            "int" => [1, 2, 3, 1, 1]
        ]?;

        df.try_apply("g", |s| {
            s.cast(&DataType::from_categories(Categories::global()))
        })?;

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let _ = df.group_by(["g"])?.sum()?;
        Ok(())
    }
}