polars_core/frame/group_by/mod.rs

use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;

use num_traits::NumCast;
use polars_compute::rolling::QuantileMethod;
use polars_utils::format_pl_smallstr;
use polars_utils::hashing::DirtyHash;
use rayon::prelude::*;

use self::hashing::*;
use crate::POOL;
use crate::prelude::*;
use crate::utils::{_set_partition_size, accumulate_dataframes_vertical};

pub mod aggregations;
pub mod expr;
pub(crate) mod hashing;
mod into_groups;
mod perfect;
mod position;

pub use into_groups::*;
pub use position::*;

use crate::chunked_array::ops::row_encode::{
    encode_rows_unordered, encode_rows_vertical_par_unordered,
};

impl DataFrame {
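    /// Group the DataFrame using the provided key columns. The keys may also be
    /// constructed externally: each key must either match the DataFrame's height
    /// or have length 1, in which case it is broadcast.
    ///
    /// # Example
    ///
    /// A minimal sketch of manual usage (the key name is illustrative):
    ///
    /// ```ignore
    /// use polars_core::prelude::*;
    ///
    /// fn example(df: &DataFrame) -> PolarsResult<GroupBy<'_>> {
    ///     let keys = df.select_columns(["date"])?;
    ///     // multithreaded: true, sorted: false
    ///     df.group_by_with_series(keys, true, false)
    /// }
    /// ```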
    pub fn group_by_with_series(
        &self,
        mut by: Vec<Column>,
        multithreaded: bool,
        sorted: bool,
    ) -> PolarsResult<GroupBy> {
        polars_ensure!(
            !by.is_empty(),
            ComputeError: "at least one key is required in a group_by operation"
        );

        // Ensure all 'by' columns have the same common_height
        // The condition self.width > 0 ensures we can still call this on a
        // dummy dataframe where we provide the keys
        let common_height = if self.width() > 0 {
            self.height()
        } else {
            by.iter().map(|s| s.len()).max().expect("at least 1 key")
        };
        for by_key in by.iter_mut() {
            if by_key.len() != common_height {
                polars_ensure!(
                    by_key.len() == 1,
                    ShapeMismatch: "series used as keys should have the same length as the DataFrame"
                );
                *by_key = by_key.new_from_index(0, common_height)
            }
        }

        let groups = if by.len() == 1 {
            let column = &by[0];
            column
                .as_materialized_series()
                .group_tuples(multithreaded, sorted)
        } else if by.iter().any(|s| s.dtype().is_object()) {
            #[cfg(feature = "object")]
            {
                let mut df = DataFrame::new(by.clone()).unwrap();
                let n = df.height();
                let rows = df.to_av_rows();
                let iter = (0..n).map(|i| rows.get(i));
                Ok(group_by(iter, sorted))
            }
            #[cfg(not(feature = "object"))]
            {
                unreachable!()
            }
        } else {
            // Skip null dtype.
            let by = by
                .iter()
                .filter(|s| !s.dtype().is_null())
                .cloned()
                .collect::<Vec<_>>();
            if by.is_empty() {
                let groups = if self.is_empty() {
                    vec![]
                } else {
                    vec![[0, self.height() as IdxSize]]
                };
                Ok(GroupsType::Slice {
                    groups,
                    rolling: false,
                })
            } else {
                let rows = if multithreaded {
                    encode_rows_vertical_par_unordered(&by)
                } else {
                    encode_rows_unordered(&by)
                }?
                .into_series();
                rows.group_tuples(multithreaded, sorted)
            }
        };
        Ok(GroupBy::new(self, by, groups?.into_sliceable(), None))
    }

    /// Group the DataFrame by one or more columns.
    ///
    /// # Example
    ///
    /// ```
    /// use polars_core::prelude::*;
    /// fn group_by_sum(df: &DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["column_name"])?
    ///     .select(["agg_column_name"])
    ///     .sum()
    /// }
    /// ```
    pub fn group_by<I, S>(&self, by: I) -> PolarsResult<GroupBy>
    where
        I: IntoIterator<Item = S>,
        S: Into<PlSmallStr>,
    {
        let selected_keys = self.select_columns(by)?;
        self.group_by_with_series(selected_keys, true, false)
    }

    /// Group the DataFrame by one or more columns.
    /// The groups are ordered by their smallest row index.
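    ///
    /// # Example
    ///
    /// As with [`DataFrame::group_by`]; the column names are placeholders:
    ///
    /// ```
    /// use polars_core::prelude::*;
    /// fn group_by_sum_stable(df: &DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by_stable(["column_name"])?
    ///     .select(["agg_column_name"])
    ///     .sum()
    /// }
    /// ```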
    pub fn group_by_stable<I, S>(&self, by: I) -> PolarsResult<GroupBy>
    where
        I: IntoIterator<Item = S>,
        S: Into<PlSmallStr>,
    {
        let selected_keys = self.select_columns(by)?;
        self.group_by_with_series(selected_keys, true, true)
    }
}

/// Returned by a group_by operation on a DataFrame. This struct supports
/// several aggregations.
///
/// Unless stated otherwise, the examples in this struct are performed on the following DataFrame:
///
/// ```ignore
/// use polars_core::prelude::*;
///
/// let dates = &[
/// "2020-08-21",
/// "2020-08-21",
/// "2020-08-22",
/// "2020-08-23",
/// "2020-08-22",
/// ];
/// // date format
/// let fmt = "%Y-%m-%d";
/// // create date series
/// let s0 = DateChunked::parse_from_str_slice("date", dates, fmt)
///         .into_series();
/// // create temperature series
/// let s1 = Series::new("temp".into(), [20, 10, 7, 9, 1]);
/// // create rain series
/// let s2 = Series::new("rain".into(), [0.2, 0.1, 0.3, 0.1, 0.01]);
/// // create a new DataFrame
/// let df = DataFrame::new(vec![s0, s1, s2]).unwrap();
/// println!("{:?}", df);
/// ```
///
/// Outputs:
///
/// ```text
/// +------------+------+------+
/// | date       | temp | rain |
/// | ---        | ---  | ---  |
/// | Date       | i32  | f64  |
/// +============+======+======+
/// | 2020-08-21 | 20   | 0.2  |
/// +------------+------+------+
/// | 2020-08-21 | 10   | 0.1  |
/// +------------+------+------+
/// | 2020-08-22 | 7    | 0.3  |
/// +------------+------+------+
/// | 2020-08-23 | 9    | 0.1  |
/// +------------+------+------+
/// | 2020-08-22 | 1    | 0.01 |
/// +------------+------+------+
/// ```
///
#[derive(Debug, Clone)]
pub struct GroupBy<'a> {
    pub df: &'a DataFrame,
    pub(crate) selected_keys: Vec<Column>,
    // [first idx, [other idx]]
    groups: GroupPositions,
    // columns selected for aggregation
    pub(crate) selected_agg: Option<Vec<PlSmallStr>>,
}

impl<'a> GroupBy<'a> {
    pub fn new(
        df: &'a DataFrame,
        by: Vec<Column>,
        groups: GroupPositions,
        selected_agg: Option<Vec<PlSmallStr>>,
    ) -> Self {
        GroupBy {
            df,
            selected_keys: by,
            groups,
            selected_agg,
        }
    }

    /// Select the column(s) that should be aggregated.
    /// You can select a single column or a slice of columns.
    ///
    /// Note that making a selection with this method is not required. If you
    /// skip it, all columns (except for the keys) will be selected for aggregation.
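    ///
    /// # Example
    ///
    /// An illustrative sketch (column names follow the DataFrame described above):
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp", "rain"]).count()
    /// }
    /// ```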
    #[must_use]
    pub fn select<I: IntoIterator<Item = S>, S: Into<PlSmallStr>>(mut self, selection: I) -> Self {
        self.selected_agg = Some(selection.into_iter().map(|s| s.into()).collect());
        self
    }

    /// Get the internal representation of the GroupBy operation.
    /// Each group is represented as:
    ///     (first_idx, [`Vec<indexes>`])
    /// where the second value in the tuple is a vector with all matching indexes.
    pub fn get_groups(&self) -> &GroupPositions {
        &self.groups
    }

    /// Get the internal representation of the GroupBy operation.
    /// Each group is represented as:
    ///     (first_idx, [`Vec<indexes>`])
    /// where the second value in the tuple is a vector with all matching indexes.
    ///
    /// # Safety
    /// Groups should always be in bounds of the `DataFrame` held by this [`GroupBy`].
    /// If you mutate it, you must hold that invariant.
    pub unsafe fn get_groups_mut(&mut self) -> &mut GroupPositions {
        &mut self.groups
    }

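    /// Take ownership of the group positions, consuming this [`GroupBy`].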
    pub fn take_groups(self) -> GroupPositions {
        self.groups
    }

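    /// Take the group positions, leaving an empty (default) set of groups behind.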
    pub fn take_groups_mut(&mut self) -> GroupPositions {
        std::mem::take(&mut self.groups)
    }

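    /// Get the key columns: one value per group, taken from each group's first row.
    /// `slice` optionally restricts the result to an `(offset, len)` window over the groups.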
    pub fn keys_sliced(&self, slice: Option<(i64, usize)>) -> Vec<Column> {
        #[allow(unused_assignments)]
        // needed to keep the lifetimes valid for this scope
        let mut groups_owned = None;

        let groups = if let Some((offset, len)) = slice {
            groups_owned = Some(self.groups.slice(offset, len));
            groups_owned.as_deref().unwrap()
        } else {
            &self.groups
        };
        POOL.install(|| {
            self.selected_keys
                .par_iter()
                .map(Column::as_materialized_series)
                .map(|s| {
                    match groups {
                        GroupsType::Idx(groups) => {
                            // SAFETY: groups are always in bounds.
                            let mut out = unsafe { s.take_slice_unchecked(groups.first()) };
                            if groups.sorted {
                                out.set_sorted_flag(s.is_sorted_flag());
                            };
                            out
                        },
                        GroupsType::Slice { groups, rolling } => {
                            if *rolling && !groups.is_empty() {
                                // Groups can be sliced.
                                let offset = groups[0][0];
                                let [upper_offset, upper_len] = groups[groups.len() - 1];
                                return s.slice(
                                    offset as i64,
                                    ((upper_offset + upper_len) - offset) as usize,
                                );
                            }

                            let indices = groups
                                .iter()
                                .map(|&[first, _len]| first)
                                .collect_ca(PlSmallStr::EMPTY);
                            // SAFETY: groups are always in bounds.
                            let mut out = unsafe { s.take_unchecked(&indices) };
                            // Sliced groups are always in order of discovery.
                            out.set_sorted_flag(s.is_sorted_flag());
                            out
                        },
                    }
                })
                .map(Column::from)
                .collect()
        })
    }

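    /// Get the key columns, one value per group.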
    pub fn keys(&self) -> Vec<Column> {
        self.keys_sliced(None)
    }

    fn prepare_agg(&self) -> PolarsResult<(Vec<Column>, Vec<Column>)> {
        let keys = self.keys();

        let agg_col = match &self.selected_agg {
            Some(selection) => self.df.select_columns_impl(selection.as_slice()),
            None => {
                let by: Vec<_> = self.selected_keys.iter().map(|s| s.name()).collect();
                let selection = self
                    .df
                    .iter()
                    .map(|s| s.name())
                    .filter(|a| !by.contains(a))
                    .cloned()
                    .collect::<Vec<_>>();

                self.df.select_columns_impl(selection.as_slice())
            },
        }?;

        Ok((keys, agg_col))
    }

    /// Aggregate grouped series and compute the mean per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp", "rain"]).mean()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+-----------+-----------+
    /// | date       | temp_mean | rain_mean |
    /// | ---        | ---       | ---       |
    /// | Date       | f64       | f64       |
    /// +============+===========+===========+
    /// | 2020-08-23 | 9         | 0.1       |
    /// +------------+-----------+-----------+
    /// | 2020-08-22 | 4         | 0.155     |
    /// +------------+-----------+-----------+
    /// | 2020-08-21 | 15        | 0.15      |
    /// +------------+-----------+-----------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn mean(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;

        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Mean);
            let mut agg = unsafe { agg_col.agg_mean(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped series and compute the sum per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).sum()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+----------+
    /// | date       | temp_sum |
    /// | ---        | ---      |
    /// | Date       | i32      |
    /// +============+==========+
    /// | 2020-08-23 | 9        |
    /// +------------+----------+
    /// | 2020-08-22 | 8        |
    /// +------------+----------+
    /// | 2020-08-21 | 30       |
    /// +------------+----------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn sum(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;

        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Sum);
            let mut agg = unsafe { agg_col.agg_sum(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped series and compute the minimal value per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).min()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+----------+
    /// | date       | temp_min |
    /// | ---        | ---      |
    /// | Date       | i32      |
    /// +============+==========+
    /// | 2020-08-23 | 9        |
    /// +------------+----------+
    /// | 2020-08-22 | 1        |
    /// +------------+----------+
    /// | 2020-08-21 | 10       |
    /// +------------+----------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn min(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Min);
            let mut agg = unsafe { agg_col.agg_min(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped series and compute the maximum value per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).max()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+----------+
    /// | date       | temp_max |
    /// | ---        | ---      |
    /// | Date       | i32      |
    /// +============+==========+
    /// | 2020-08-23 | 9        |
    /// +------------+----------+
    /// | 2020-08-22 | 7        |
    /// +------------+----------+
    /// | 2020-08-21 | 20       |
    /// +------------+----------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn max(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Max);
            let mut agg = unsafe { agg_col.agg_max(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped `Series` and find the first value per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).first()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+------------+
    /// | date       | temp_first |
    /// | ---        | ---        |
    /// | Date       | i32        |
    /// +============+============+
    /// | 2020-08-23 | 9          |
    /// +------------+------------+
    /// | 2020-08-22 | 7          |
    /// +------------+------------+
    /// | 2020-08-21 | 20         |
    /// +------------+------------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn first(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::First);
            let mut agg = unsafe { agg_col.agg_first(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped `Series` and return the last value per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).last()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+------------+
    /// | date       | temp_last  |
    /// | ---        | ---        |
    /// | Date       | i32        |
    /// +============+============+
    /// | 2020-08-23 | 9          |
    /// +------------+------------+
    /// | 2020-08-22 | 1          |
    /// +------------+------------+
    /// | 2020-08-21 | 10         |
    /// +------------+------------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn last(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Last);
            let mut agg = unsafe { agg_col.agg_last(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped `Series` by counting the number of unique values.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).n_unique()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+---------------+
    /// | date       | temp_n_unique |
    /// | ---        | ---           |
    /// | Date       | u32           |
    /// +============+===============+
    /// | 2020-08-23 | 1             |
    /// +------------+---------------+
    /// | 2020-08-22 | 2             |
    /// +------------+---------------+
    /// | 2020-08-21 | 2             |
    /// +------------+---------------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn n_unique(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::NUnique);
            let mut agg = unsafe { agg_col.agg_n_unique(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped [`Series`] and determine the quantile per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    ///
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileMethod::default())
    /// }
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<DataFrame> {
        polars_ensure!(
            (0.0..=1.0).contains(&quantile),
            ComputeError: "`quantile` should be within 0.0 and 1.0"
        );
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(
                agg_col.name().as_str(),
                GroupByMethod::Quantile(quantile, method),
            );
            let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped [`Series`] and determine the median per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).median()
    /// }
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn median(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Median);
            let mut agg = unsafe { agg_col.agg_median(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped [`Series`] and determine the variance per group.
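    ///
    /// # Example
    ///
    /// An illustrative sketch in the style of the aggregations above (`ddof = 1`):
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).var(1)
    /// }
    /// ```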
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn var(&self, ddof: u8) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Var(ddof));
            let mut agg = unsafe { agg_col.agg_var(&self.groups, ddof) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped [`Series`] and determine the standard deviation per group.
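    ///
    /// # Example
    ///
    /// An illustrative sketch in the style of the aggregations above (`ddof = 1`):
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).std(1)
    /// }
    /// ```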
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn std(&self, ddof: u8) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Std(ddof));
            let mut agg = unsafe { agg_col.agg_std(&self.groups, ddof) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    /// Aggregate grouped series and compute the number of values per group.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.select(["temp"]).count()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+------------+
    /// | date       | temp_count |
    /// | ---        | ---        |
    /// | Date       | u32        |
    /// +============+============+
    /// | 2020-08-23 | 1          |
    /// +------------+------------+
    /// | 2020-08-22 | 2          |
    /// +------------+------------+
    /// | 2020-08-21 | 2          |
    /// +------------+------------+
    /// ```
    pub fn count(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;

        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(
                agg_col.name().as_str(),
                GroupByMethod::Count {
                    include_nulls: true,
                },
            );
            let mut ca = self.groups.group_count();
            ca.rename(new_name);
            cols.push(ca.into_column());
        }
        DataFrame::new(cols)
    }

    /// Get the group_by group indexes.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.groups()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +--------------+------------+
    /// | date         | groups     |
    /// | ---          | ---        |
    /// | Date(days)   | list [u32] |
    /// +==============+============+
    /// | 2020-08-23   | "[3]"      |
    /// +--------------+------------+
    /// | 2020-08-22   | "[2, 4]"   |
    /// +--------------+------------+
    /// | 2020-08-21   | "[0, 1]"   |
    /// +--------------+------------+
    /// ```
    pub fn groups(&self) -> PolarsResult<DataFrame> {
        let mut cols = self.keys();
        let mut column = self.groups.as_list_chunked();
        let new_name = fmt_group_by_column("", GroupByMethod::Groups);
        column.rename(new_name);
        cols.push(column.into_column());
        DataFrame::new(cols)
    }

    /// Aggregate the groups of the group_by operation into lists.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     // GroupBy and aggregate to Lists
    ///     df.group_by(["date"])?.select(["temp"]).agg_list()
    /// }
    /// ```
    /// Returns:
    ///
    /// ```text
    /// +------------+------------------------+
    /// | date       | temp_agg_list          |
    /// | ---        | ---                    |
    /// | Date       | list [i32]             |
    /// +============+========================+
    /// | 2020-08-23 | "[Some(9)]"            |
    /// +------------+------------------------+
    /// | 2020-08-22 | "[Some(7), Some(1)]"   |
    /// +------------+------------------------+
    /// | 2020-08-21 | "[Some(20), Some(10)]" |
    /// +------------+------------------------+
    /// ```
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn agg_list(&self) -> PolarsResult<DataFrame> {
        let (mut cols, agg_cols) = self.prepare_agg()?;
        for agg_col in agg_cols {
            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Implode);
            let mut agg = unsafe { agg_col.agg_list(&self.groups) };
            agg.rename(new_name);
            cols.push(agg);
        }
        DataFrame::new(cols)
    }

    fn prepare_apply(&self) -> PolarsResult<DataFrame> {
        polars_ensure!(self.df.height() > 0, ComputeError: "cannot group_by + apply on empty 'DataFrame'");
        if let Some(agg) = &self.selected_agg {
            if agg.is_empty() {
                Ok(self.df.clone())
            } else {
                let mut new_cols = Vec::with_capacity(self.selected_keys.len() + agg.len());
                new_cols.extend_from_slice(&self.selected_keys);
                let cols = self.df.select_columns_impl(agg.as_slice())?;
                new_cols.extend(cols);
                Ok(unsafe { DataFrame::new_no_checks(self.df.height(), new_cols) })
            }
        } else {
            Ok(self.df.clone())
        }
    }

    /// Apply a closure over the groups as a new [`DataFrame`] in parallel.
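    ///
    /// # Example
    ///
    /// A minimal sketch; the identity closure simply concatenates the groups back together
    /// in group order:
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.par_apply(|group| Ok(group))
    /// }
    /// ```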
    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
    pub fn par_apply<F>(&self, f: F) -> PolarsResult<DataFrame>
    where
        F: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
    {
        let df = self.prepare_apply()?;
        let dfs = self
            .get_groups()
            .par_iter()
            .map(|g| {
                // SAFETY:
                // groups are in bounds
                let sub_df = unsafe { take_df(&df, g) };
                f(sub_df)
            })
            .collect::<PolarsResult<Vec<_>>>()?;

        let mut df = accumulate_dataframes_vertical(dfs)?;
        df.as_single_chunk_par();
        Ok(df)
    }

    /// Apply a closure over the groups as a new [`DataFrame`].
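    ///
    /// # Example
    ///
    /// A minimal sketch; the identity closure simply concatenates the groups back together
    /// in group order:
    ///
    /// ```rust
    /// # use polars_core::prelude::*;
    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
    ///     df.group_by(["date"])?.apply(|group| Ok(group))
    /// }
    /// ```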
    pub fn apply<F>(&self, mut f: F) -> PolarsResult<DataFrame>
    where
        F: FnMut(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
    {
        let df = self.prepare_apply()?;
        let dfs = self
            .get_groups()
            .iter()
            .map(|g| {
                // SAFETY:
                // groups are in bounds
                let sub_df = unsafe { take_df(&df, g) };
                f(sub_df)
            })
            .collect::<PolarsResult<Vec<_>>>()?;

        let mut df = accumulate_dataframes_vertical(dfs)?;
        df.as_single_chunk_par();
        Ok(df)
    }

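    /// Restrict this [`GroupBy`] to a slice `(offset, length)` of its groups,
    /// updating the selected keys accordingly.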
    pub fn sliced(mut self, slice: Option<(i64, usize)>) -> Self {
        match slice {
            None => self,
            Some((offset, length)) => {
                self.groups = (self.groups.slice(offset, length)).clone();
                self.selected_keys = self.keys_sliced(slice);
                self
            },
        }
    }
}

unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame {
    match g {
        GroupsIndicator::Idx(idx) => df.take_slice_unchecked(idx.1),
        GroupsIndicator::Slice([first, len]) => df.slice(first as i64, len as usize),
    }
}

#[derive(Copy, Clone, Debug)]
pub enum GroupByMethod {
    Min,
    NanMin,
    Max,
    NanMax,
    Median,
    Mean,
    First,
    Last,
    Sum,
    Groups,
    NUnique,
    Quantile(f64, QuantileMethod),
    Count { include_nulls: bool },
    Implode,
    Std(u8),
    Var(u8),
}

impl Display for GroupByMethod {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        use GroupByMethod::*;
        let s = match self {
            Min => "min",
            NanMin => "nan_min",
            Max => "max",
            NanMax => "nan_max",
            Median => "median",
            Mean => "mean",
            First => "first",
            Last => "last",
            Sum => "sum",
            Groups => "groups",
            NUnique => "n_unique",
            Quantile(_, _) => "quantile",
            Count { .. } => "count",
            Implode => "list",
            Std(_) => "std",
            Var(_) => "var",
        };
        write!(f, "{s}")
    }
}

// Formatting functions used in eager and lazy code for renaming grouped columns
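// e.g. `fmt_group_by_column("temp", GroupByMethod::Mean)` yields "temp_mean".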
pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr {
    use GroupByMethod::*;
    match method {
        Min => format_pl_smallstr!("{name}_min"),
        Max => format_pl_smallstr!("{name}_max"),
        NanMin => format_pl_smallstr!("{name}_nan_min"),
        NanMax => format_pl_smallstr!("{name}_nan_max"),
        Median => format_pl_smallstr!("{name}_median"),
        Mean => format_pl_smallstr!("{name}_mean"),
        First => format_pl_smallstr!("{name}_first"),
        Last => format_pl_smallstr!("{name}_last"),
        Sum => format_pl_smallstr!("{name}_sum"),
        Groups => PlSmallStr::from_static("groups"),
        NUnique => format_pl_smallstr!("{name}_n_unique"),
        Count { .. } => format_pl_smallstr!("{name}_count"),
        Implode => format_pl_smallstr!("{name}_agg_list"),
        Quantile(quantile, _interpol) => format_pl_smallstr!("{name}_quantile_{quantile:.2}"),
        Std(_) => format_pl_smallstr!("{name}_agg_std"),
        Var(_) => format_pl_smallstr!("{name}_agg_var"),
    }
}

#[cfg(test)]
mod test {
    use num_traits::FloatConst;

    use crate::prelude::*;

    #[test]
    #[cfg(feature = "dtype-date")]
    #[cfg_attr(miri, ignore)]
    fn test_group_by() -> PolarsResult<()> {
        let s0 = Column::new(
            PlSmallStr::from_static("date"),
            &[
                "2020-08-21",
                "2020-08-21",
                "2020-08-22",
                "2020-08-23",
                "2020-08-22",
            ],
        );
        let s1 = Column::new(PlSmallStr::from_static("temp"), [20, 10, 7, 9, 1]);
        let s2 = Column::new(PlSmallStr::from_static("rain"), [0.2, 0.1, 0.3, 0.1, 0.01]);
        let df = DataFrame::new(vec![s0, s1, s2]).unwrap();

        let out = df.group_by_stable(["date"])?.select(["temp"]).count()?;
        assert_eq!(
            out.column("temp_count")?,
            &Column::new(PlSmallStr::from_static("temp_count"), [2 as IdxSize, 2, 1])
        );

        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        // Select multiple
        let out = df
            .group_by_stable(["date"])?
            .select(["temp", "rain"])
            .mean()?;
        assert_eq!(
            out.column("temp_mean")?,
            &Column::new(PlSmallStr::from_static("temp_mean"), [15.0f64, 4.0, 9.0])
        );

        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        // Group by multiple
        let out = df
            .group_by_stable(["date", "temp"])?
            .select(["rain"])
            .mean()?;
        assert!(out.column("rain_mean").is_ok());

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["date"])?.select(["temp"]).sum()?;
        assert_eq!(
            out.column("temp_sum")?,
            &Column::new(PlSmallStr::from_static("temp_sum"), [30, 8, 9])
        );

        // Use of deprecated `n_unique()` for testing purposes
        #[allow(deprecated)]
        // implicit select all and only aggregate on methods that support that aggregation
        let gb = df.group_by(["date"]).unwrap().n_unique().unwrap();
        // check the group by column is filtered out.
        assert_eq!(gb.width(), 3);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_static_group_by_by_12_columns() {
        // Build GroupBy DataFrame.
        let s0 = Column::new("G1".into(), ["A", "A", "B", "B", "C"].as_ref());
        let s1 = Column::new("N".into(), [1, 2, 2, 4, 2].as_ref());
        let s2 = Column::new("G2".into(), ["k", "l", "m", "m", "l"].as_ref());
        let s3 = Column::new("G3".into(), ["a", "b", "c", "c", "d"].as_ref());
        let s4 = Column::new("G4".into(), ["1", "2", "3", "3", "4"].as_ref());
        let s5 = Column::new("G5".into(), ["X", "Y", "Z", "Z", "W"].as_ref());
        let s6 = Column::new("G6".into(), [false, true, true, true, false].as_ref());
        let s7 = Column::new("G7".into(), ["r", "x", "q", "q", "o"].as_ref());
        let s8 = Column::new("G8".into(), ["R", "X", "Q", "Q", "O"].as_ref());
        let s9 = Column::new("G9".into(), [1, 2, 3, 3, 4].as_ref());
        let s10 = Column::new("G10".into(), [".", "!", "?", "?", "/"].as_ref());
        let s11 = Column::new("G11".into(), ["(", ")", "@", "@", "$"].as_ref());
        let s12 = Column::new("G12".into(), ["-", "_", ";", ";", ","].as_ref());

        let df =
            DataFrame::new(vec![s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]).unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let adf = df
            .group_by([
                "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12",
            ])
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(1), Some(2), Some(2), Some(6)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_dynamic_group_by_by_13_columns() {
        // The content for every group_by series.
        let series_content = ["A", "A", "B", "B", "C"];

        // The name of every group_by series.
        let series_names = [
            "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12", "G13",
        ];

        // Vector to contain every series.
        let mut columns = Vec::with_capacity(14);

        // Create a series for every group name.
        for series_name in series_names {
            let group_columns = Column::new(series_name.into(), series_content.as_ref());
            columns.push(group_columns);
        }

        // Create a series for the aggregation column.
        let agg_series = Column::new("N".into(), [1, 2, 3, 3, 4].as_ref());
        columns.push(agg_series);

        // Create the dataframe with the computed series.
        let df = DataFrame::new(columns).unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        // Compute the aggregated DataFrame by the 13 columns defined in `series_names`.
        let adf = df
            .group_by(series_names)
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        // Check that the group-by results are correct. Every key column has identical
        // content, so the grouped key columns should be equal and in the same order.
        for series_name in &series_names {
            assert_eq!(
                Vec::from(&adf.column(series_name).unwrap().str().unwrap().sort(false)),
                &[Some("A"), Some("B"), Some("C")]
            );
        }

        // Check the aggregated column is the expected one.
        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(3), Some(4), Some(6)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_floats() {
        let df = df! {"flt" => [1., 1., 2., 2., 3.],
                    "val" => [1, 1, 1, 1, 1]
        }
        .unwrap();
        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let res = df.group_by(["flt"]).unwrap().sum().unwrap();
        let res = res.sort(["flt"], SortMultipleOptions::default()).unwrap();
        assert_eq!(
            Vec::from(res.column("val_sum").unwrap().i32().unwrap()),
            &[Some(2), Some(2), Some(1)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_categorical() {
        let mut df = df! {"foo" => ["a", "a", "b", "b", "c"],
                    "ham" => ["a", "a", "b", "b", "c"],
                    "bar" => [1, 1, 1, 1, 1]
        }
        .unwrap();

        df.apply("foo", |s| {
            s.cast(&DataType::Categorical(None, Default::default()))
                .unwrap()
        })
        .unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        // check multiple keys and categorical
        let res = df
            .group_by_stable(["foo", "ham"])
            .unwrap()
            .select(["bar"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(
                res.column("bar_sum")
                    .unwrap()
                    .as_materialized_series()
                    .i32()
                    .unwrap()
            ),
            &[Some(2), Some(2), Some(1)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_null_handling() -> PolarsResult<()> {
        let df = df!(
            "a" => ["a", "a", "a", "b", "b"],
            "b" => [Some(1), Some(2), None, None, Some(1)]
        )?;
        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["a"])?.mean()?;

        assert_eq!(
            Vec::from(out.column("b_mean")?.as_materialized_series().f64()?),
            &[Some(1.5), Some(1.0)]
        );
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_var() -> PolarsResult<()> {
        // check variance and proper coercion to f64
        let df = df![
            "g" => ["foo", "foo", "bar"],
            "flt" => [1.0, 2.0, 3.0],
            "int" => [1, 2, 3]
        ]?;

        // Use of deprecated `var()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).var(1)?;

        assert_eq!(out.column("int_agg_var")?.f64()?.get(0), Some(0.5));
        // Use of deprecated `std()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).std(1)?;
        let val = out.column("int_agg_std")?.f64()?.get(0).unwrap();
        let expected = f64::FRAC_1_SQRT_2();
        assert!((val - expected).abs() < 0.000001);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_null_group() -> PolarsResult<()> {
        // check if null is its own group
        let mut df = df![
            "g" => [Some("foo"), Some("foo"), Some("bar"), None, None],
            "flt" => [1.0, 2.0, 3.0, 1.0, 1.0],
            "int" => [1, 2, 3, 1, 1]
        ]?;

        df.try_apply("g", |s| {
            s.cast(&DataType::Categorical(None, Default::default()))
        })?;

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let _ = df.group_by(["g"])?.sum()?;
        Ok(())
    }
}