// polars_core/frame/group_by/mod.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3
4use num_traits::NumCast;
5use polars_compute::rolling::QuantileMethod;
6use polars_utils::format_pl_smallstr;
7use polars_utils::hashing::DirtyHash;
8use rayon::prelude::*;
9
10use self::hashing::*;
11use crate::POOL;
12use crate::prelude::*;
13use crate::utils::{_set_partition_size, accumulate_dataframes_vertical};
14
15pub mod aggregations;
16pub mod expr;
17pub(crate) mod hashing;
18mod into_groups;
19mod position;
20
21pub use into_groups::*;
22pub use position::*;
23
24use crate::chunked_array::ops::row_encode::{
25    encode_rows_unordered, encode_rows_vertical_par_unordered,
26};
27
28impl DataFrame {
29    pub fn group_by_with_series(
30        &self,
31        mut by: Vec<Column>,
32        multithreaded: bool,
33        sorted: bool,
34    ) -> PolarsResult<GroupBy<'_>> {
35        polars_ensure!(
36            !by.is_empty(),
37            ComputeError: "at least one key is required in a group_by operation"
38        );
39
40        // Ensure all 'by' columns have the same common_height
41        // The condition self.width > 0 ensures we can still call this on a
42        // dummy dataframe where we provide the keys
43        let common_height = if self.width() > 0 {
44            self.height()
45        } else {
46            by.iter().map(|s| s.len()).max().expect("at least 1 key")
47        };
48        for by_key in by.iter_mut() {
49            if by_key.len() != common_height {
50                polars_ensure!(
51                    by_key.len() == 1,
52                    ShapeMismatch: "series used as keys should have the same length as the DataFrame"
53                );
54                *by_key = by_key.new_from_index(0, common_height)
55            }
56        }
57
58        let groups = if by.len() == 1 {
59            let column = &by[0];
60            column
61                .as_materialized_series()
62                .group_tuples(multithreaded, sorted)
63        } else if by.iter().any(|s| s.dtype().is_object()) {
64            #[cfg(feature = "object")]
65            {
66                let mut df = DataFrame::new(self.height(), by.clone()).unwrap();
67                let n = df.height();
68                let rows = df.to_av_rows();
69                let iter = (0..n).map(|i| rows.get(i));
70                Ok(group_by(iter, sorted))
71            }
72            #[cfg(not(feature = "object"))]
73            {
74                unreachable!()
75            }
76        } else {
77            // Skip null dtype.
78            let by = by
79                .iter()
80                .filter(|s| !s.dtype().is_null())
81                .cloned()
82                .collect::<Vec<_>>();
83            if by.is_empty() {
84                let groups = if self.height() == 0 {
85                    vec![]
86                } else {
87                    vec![[0, self.height() as IdxSize]]
88                };
89
90                Ok(GroupsType::new_slice(groups, false, true))
91            } else {
92                let rows = if multithreaded {
93                    encode_rows_vertical_par_unordered(&by)
94                } else {
95                    encode_rows_unordered(&by)
96                }?
97                .into_series();
98                rows.group_tuples(multithreaded, sorted)
99            }
100        };
101        Ok(GroupBy::new(self, by, groups?.into_sliceable(), None))
102    }
103
104    /// Group DataFrame using a Series column.
105    ///
106    /// # Example
107    ///
108    /// ```
109    /// use polars_core::prelude::*;
110    /// fn group_by_sum(df: &DataFrame) -> PolarsResult<DataFrame> {
111    ///     df.group_by(["column_name"])?
112    ///     .select(["agg_column_name"])
113    ///     .sum()
114    /// }
115    /// ```
116    pub fn group_by<I, S>(&self, by: I) -> PolarsResult<GroupBy<'_>>
117    where
118        I: IntoIterator<Item = S>,
119        S: AsRef<str>,
120    {
121        let selected_keys = self.select_to_vec(by)?;
122        self.group_by_with_series(selected_keys, true, false)
123    }
124
125    /// Group DataFrame using a Series column.
126    /// The groups are ordered by their smallest row index.
127    pub fn group_by_stable<I, S>(&self, by: I) -> PolarsResult<GroupBy<'_>>
128    where
129        I: IntoIterator<Item = S>,
130        S: AsRef<str>,
131    {
132        let selected_keys = self.select_to_vec(by)?;
133        self.group_by_with_series(selected_keys, true, true)
134    }
135}
136
137/// Returned by a group_by operation on a DataFrame. This struct supports
138/// several aggregations.
139///
140/// Until described otherwise, the examples in this struct are performed on the following DataFrame:
141///
142/// ```ignore
143/// use polars_core::prelude::*;
144///
145/// let dates = &[
146/// "2020-08-21",
147/// "2020-08-21",
148/// "2020-08-22",
149/// "2020-08-23",
150/// "2020-08-22",
151/// ];
152/// // date format
153/// let fmt = "%Y-%m-%d";
154/// // create date series
155/// let s0 = DateChunked::parse_from_str_slice("date", dates, fmt)
156///         .into_series();
157/// // create temperature series
158/// let s1 = Series::new("temp".into(), [20, 10, 7, 9, 1]);
159/// // create rain series
160/// let s2 = Series::new("rain".into(), [0.2, 0.1, 0.3, 0.1, 0.01]);
161/// // create a new DataFrame
162/// let df = DataFrame::new_infer_height(vec![s0, s1, s2]).unwrap();
163/// println!("{:?}", df);
164/// ```
165///
166/// Outputs:
167///
168/// ```text
169/// +------------+------+------+
170/// | date       | temp | rain |
171/// | ---        | ---  | ---  |
172/// | Date       | i32  | f64  |
173/// +============+======+======+
174/// | 2020-08-21 | 20   | 0.2  |
175/// +------------+------+------+
176/// | 2020-08-21 | 10   | 0.1  |
177/// +------------+------+------+
178/// | 2020-08-22 | 7    | 0.3  |
179/// +------------+------+------+
180/// | 2020-08-23 | 9    | 0.1  |
181/// +------------+------+------+
182/// | 2020-08-22 | 1    | 0.01 |
183/// +------------+------+------+
184/// ```
185///
186#[derive(Debug, Clone)]
187pub struct GroupBy<'a> {
188    pub df: &'a DataFrame,
189    pub(crate) selected_keys: Vec<Column>,
190    // [first idx, [other idx]]
191    groups: GroupPositions,
192    // columns selected for aggregation
193    pub(crate) selected_agg: Option<Vec<PlSmallStr>>,
194}
195
196impl<'a> GroupBy<'a> {
197    pub fn new(
198        df: &'a DataFrame,
199        by: Vec<Column>,
200        groups: GroupPositions,
201        selected_agg: Option<Vec<PlSmallStr>>,
202    ) -> Self {
203        GroupBy {
204            df,
205            selected_keys: by,
206            groups,
207            selected_agg,
208        }
209    }
210
211    /// Select the column(s) that should be aggregated.
212    /// You can select a single column or a slice of columns.
213    ///
214    /// Note that making a selection with this method is not required. If you
215    /// skip it all columns (except for the keys) will be selected for aggregation.
216    #[must_use]
217    pub fn select<I: IntoIterator<Item = S>, S: Into<PlSmallStr>>(mut self, selection: I) -> Self {
218        self.selected_agg = Some(selection.into_iter().map(|s| s.into()).collect());
219        self
220    }
221
222    /// Get the internal representation of the GroupBy operation.
223    /// The Vec returned contains:
224    ///     (first_idx, [`Vec<indexes>`])
225    ///     Where second value in the tuple is a vector with all matching indexes.
226    pub fn get_groups(&self) -> &GroupPositions {
227        &self.groups
228    }
229
230    /// Get the internal representation of the GroupBy operation.
231    /// The Vec returned contains:
232    ///     (first_idx, [`Vec<indexes>`])
233    ///     Where second value in the tuple is a vector with all matching indexes.
234    ///
235    /// # Safety
236    /// Groups should always be in bounds of the `DataFrame` hold by this [`GroupBy`].
237    /// If you mutate it, you must hold that invariant.
238    pub unsafe fn get_groups_mut(&mut self) -> &mut GroupPositions {
239        &mut self.groups
240    }
241
242    pub fn into_groups(self) -> GroupPositions {
243        self.groups
244    }
245
246    pub fn keys_sliced(&self, slice: Option<(i64, usize)>) -> Vec<Column> {
247        #[allow(unused_assignments)]
248        // needed to keep the lifetimes valid for this scope
249        let mut groups_owned = None;
250
251        let groups = if let Some((offset, len)) = slice {
252            groups_owned = Some(self.groups.slice(offset, len));
253            groups_owned.as_deref().unwrap()
254        } else {
255            &self.groups
256        };
257        POOL.install(|| {
258            self.selected_keys
259                .par_iter()
260                .map(Column::as_materialized_series)
261                .map(|s| {
262                    match groups {
263                        GroupsType::Idx(groups) => {
264                            // SAFETY: groups are always in bounds.
265                            let mut out = unsafe { s.take_slice_unchecked(groups.first()) };
266                            if groups.sorted {
267                                out.set_sorted_flag(s.is_sorted_flag());
268                            };
269                            out
270                        },
271                        GroupsType::Slice {
272                            groups,
273                            overlapping,
274                            monotonic: _,
275                        } => {
276                            if *overlapping && !groups.is_empty() {
277                                // Groups can be sliced.
278                                let offset = groups[0][0];
279                                let [upper_offset, upper_len] = groups[groups.len() - 1];
280                                return s.slice(
281                                    offset as i64,
282                                    ((upper_offset + upper_len) - offset) as usize,
283                                );
284                            }
285
286                            let indices = groups
287                                .iter()
288                                .map(|&[first, _len]| first)
289                                .collect_ca(PlSmallStr::EMPTY);
290                            // SAFETY: groups are always in bounds.
291                            let mut out = unsafe { s.take_unchecked(&indices) };
292                            // Sliced groups are always in order of discovery.
293                            out.set_sorted_flag(s.is_sorted_flag());
294                            out
295                        },
296                    }
297                })
298                .map(Column::from)
299                .collect()
300        })
301    }
302
303    pub fn keys(&self) -> Vec<Column> {
304        self.keys_sliced(None)
305    }
306
307    fn prepare_agg(&self) -> PolarsResult<(Vec<Column>, Vec<Column>)> {
308        let keys = self.keys();
309
310        let agg_col = match &self.selected_agg {
311            Some(selection) => self.df.select_to_vec(selection),
312            None => {
313                let by: Vec<_> = self.selected_keys.iter().map(|s| s.name()).collect();
314                let selection = self
315                    .df
316                    .columns()
317                    .iter()
318                    .map(|s| s.name())
319                    .filter(|a| !by.contains(a))
320                    .cloned()
321                    .collect::<Vec<_>>();
322
323                self.df.select_to_vec(selection.as_slice())
324            },
325        }?;
326
327        Ok((keys, agg_col))
328    }
329
330    /// Aggregate grouped series and compute the mean per group.
331    ///
332    /// # Example
333    ///
334    /// ```rust
335    /// # use polars_core::prelude::*;
336    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
337    ///     df.group_by(["date"])?.select(["temp", "rain"]).mean()
338    /// }
339    /// ```
340    /// Returns:
341    ///
342    /// ```text
343    /// +------------+-----------+-----------+
344    /// | date       | temp_mean | rain_mean |
345    /// | ---        | ---       | ---       |
346    /// | Date       | f64       | f64       |
347    /// +============+===========+===========+
348    /// | 2020-08-23 | 9         | 0.1       |
349    /// +------------+-----------+-----------+
350    /// | 2020-08-22 | 4         | 0.155     |
351    /// +------------+-----------+-----------+
352    /// | 2020-08-21 | 15        | 0.15      |
353    /// +------------+-----------+-----------+
354    /// ```
355    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
356    pub fn mean(&self) -> PolarsResult<DataFrame> {
357        let (mut cols, agg_cols) = self.prepare_agg()?;
358
359        for agg_col in agg_cols {
360            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Mean);
361            let mut agg = unsafe { agg_col.agg_mean(&self.groups) };
362            agg.rename(new_name);
363            cols.push(agg);
364        }
365
366        DataFrame::new_infer_height(cols)
367    }
368
369    /// Aggregate grouped series and compute the sum per group.
370    ///
371    /// # Example
372    ///
373    /// ```rust
374    /// # use polars_core::prelude::*;
375    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
376    ///     df.group_by(["date"])?.select(["temp"]).sum()
377    /// }
378    /// ```
379    /// Returns:
380    ///
381    /// ```text
382    /// +------------+----------+
383    /// | date       | temp_sum |
384    /// | ---        | ---      |
385    /// | Date       | i32      |
386    /// +============+==========+
387    /// | 2020-08-23 | 9        |
388    /// +------------+----------+
389    /// | 2020-08-22 | 8        |
390    /// +------------+----------+
391    /// | 2020-08-21 | 30       |
392    /// +------------+----------+
393    /// ```
394    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
395    pub fn sum(&self) -> PolarsResult<DataFrame> {
396        let (mut cols, agg_cols) = self.prepare_agg()?;
397
398        for agg_col in agg_cols {
399            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Sum);
400            let mut agg = unsafe { agg_col.agg_sum(&self.groups) };
401            agg.rename(new_name);
402            cols.push(agg);
403        }
404        DataFrame::new_infer_height(cols)
405    }
406
407    /// Aggregate grouped series and compute the minimal value per group.
408    ///
409    /// # Example
410    ///
411    /// ```rust
412    /// # use polars_core::prelude::*;
413    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
414    ///     df.group_by(["date"])?.select(["temp"]).min()
415    /// }
416    /// ```
417    /// Returns:
418    ///
419    /// ```text
420    /// +------------+----------+
421    /// | date       | temp_min |
422    /// | ---        | ---      |
423    /// | Date       | i32      |
424    /// +============+==========+
425    /// | 2020-08-23 | 9        |
426    /// +------------+----------+
427    /// | 2020-08-22 | 1        |
428    /// +------------+----------+
429    /// | 2020-08-21 | 10       |
430    /// +------------+----------+
431    /// ```
432    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
433    pub fn min(&self) -> PolarsResult<DataFrame> {
434        let (mut cols, agg_cols) = self.prepare_agg()?;
435        for agg_col in agg_cols {
436            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Min);
437            let mut agg = unsafe { agg_col.agg_min(&self.groups) };
438            agg.rename(new_name);
439            cols.push(agg);
440        }
441        DataFrame::new_infer_height(cols)
442    }
443
444    /// Aggregate grouped series and compute the maximum value per group.
445    ///
446    /// # Example
447    ///
448    /// ```rust
449    /// # use polars_core::prelude::*;
450    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
451    ///     df.group_by(["date"])?.select(["temp"]).max()
452    /// }
453    /// ```
454    /// Returns:
455    ///
456    /// ```text
457    /// +------------+----------+
458    /// | date       | temp_max |
459    /// | ---        | ---      |
460    /// | Date       | i32      |
461    /// +============+==========+
462    /// | 2020-08-23 | 9        |
463    /// +------------+----------+
464    /// | 2020-08-22 | 7        |
465    /// +------------+----------+
466    /// | 2020-08-21 | 20       |
467    /// +------------+----------+
468    /// ```
469    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
470    pub fn max(&self) -> PolarsResult<DataFrame> {
471        let (mut cols, agg_cols) = self.prepare_agg()?;
472        for agg_col in agg_cols {
473            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Max);
474            let mut agg = unsafe { agg_col.agg_max(&self.groups) };
475            agg.rename(new_name);
476            cols.push(agg);
477        }
478        DataFrame::new_infer_height(cols)
479    }
480
481    /// Aggregate grouped `Series` and find the first value per group.
482    ///
483    /// # Example
484    ///
485    /// ```rust
486    /// # use polars_core::prelude::*;
487    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
488    ///     df.group_by(["date"])?.select(["temp"]).first()
489    /// }
490    /// ```
491    /// Returns:
492    ///
493    /// ```text
494    /// +------------+------------+
495    /// | date       | temp_first |
496    /// | ---        | ---        |
497    /// | Date       | i32        |
498    /// +============+============+
499    /// | 2020-08-23 | 9          |
500    /// +------------+------------+
501    /// | 2020-08-22 | 7          |
502    /// +------------+------------+
503    /// | 2020-08-21 | 20         |
504    /// +------------+------------+
505    /// ```
506    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
507    pub fn first(&self) -> PolarsResult<DataFrame> {
508        let (mut cols, agg_cols) = self.prepare_agg()?;
509        for agg_col in agg_cols {
510            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::First);
511            let mut agg = unsafe { agg_col.agg_first(&self.groups) };
512            agg.rename(new_name);
513            cols.push(agg);
514        }
515        DataFrame::new_infer_height(cols)
516    }
517
518    /// Aggregate grouped `Series` and return the last value per group.
519    ///
520    /// # Example
521    ///
522    /// ```rust
523    /// # use polars_core::prelude::*;
524    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
525    ///     df.group_by(["date"])?.select(["temp"]).last()
526    /// }
527    /// ```
528    /// Returns:
529    ///
530    /// ```text
531    /// +------------+------------+
532    /// | date       | temp_last |
533    /// | ---        | ---        |
534    /// | Date       | i32        |
535    /// +============+============+
536    /// | 2020-08-23 | 9          |
537    /// +------------+------------+
538    /// | 2020-08-22 | 1          |
539    /// +------------+------------+
540    /// | 2020-08-21 | 10         |
541    /// +------------+------------+
542    /// ```
543    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
544    pub fn last(&self) -> PolarsResult<DataFrame> {
545        let (mut cols, agg_cols) = self.prepare_agg()?;
546        for agg_col in agg_cols {
547            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Last);
548            let mut agg = unsafe { agg_col.agg_last(&self.groups) };
549            agg.rename(new_name);
550            cols.push(agg);
551        }
552        DataFrame::new_infer_height(cols)
553    }
554
555    /// Aggregate grouped `Series` by counting the number of unique values.
556    ///
557    /// # Example
558    ///
559    /// ```rust
560    /// # use polars_core::prelude::*;
561    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
562    ///     df.group_by(["date"])?.select(["temp"]).n_unique()
563    /// }
564    /// ```
565    /// Returns:
566    ///
567    /// ```text
568    /// +------------+---------------+
569    /// | date       | temp_n_unique |
570    /// | ---        | ---           |
571    /// | Date       | u32           |
572    /// +============+===============+
573    /// | 2020-08-23 | 1             |
574    /// +------------+---------------+
575    /// | 2020-08-22 | 2             |
576    /// +------------+---------------+
577    /// | 2020-08-21 | 2             |
578    /// +------------+---------------+
579    /// ```
580    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
581    pub fn n_unique(&self) -> PolarsResult<DataFrame> {
582        let (mut cols, agg_cols) = self.prepare_agg()?;
583        for agg_col in agg_cols {
584            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::NUnique);
585            let mut agg = unsafe { agg_col.agg_n_unique(&self.groups) };
586            agg.rename(new_name);
587            cols.push(agg);
588        }
589        DataFrame::new_infer_height(cols)
590    }
591
592    /// Aggregate grouped [`Series`] and determine the quantile per group.
593    ///
594    /// # Example
595    ///
596    /// ```rust
597    /// # use polars_core::prelude::*;
598    ///
599    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
600    ///     df.group_by(["date"])?.select(["temp"]).quantile(0.2, QuantileMethod::default())
601    /// }
602    /// ```
603    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
604    pub fn quantile(&self, quantile: f64, method: QuantileMethod) -> PolarsResult<DataFrame> {
605        polars_ensure!(
606            (0.0..=1.0).contains(&quantile),
607            ComputeError: "`quantile` should be within 0.0 and 1.0"
608        );
609        let (mut cols, agg_cols) = self.prepare_agg()?;
610        for agg_col in agg_cols {
611            let new_name = fmt_group_by_column(
612                agg_col.name().as_str(),
613                GroupByMethod::Quantile(quantile, method),
614            );
615            let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, method) };
616            agg.rename(new_name);
617            cols.push(agg);
618        }
619        DataFrame::new_infer_height(cols)
620    }
621
622    /// Aggregate grouped [`Series`] and determine the median per group.
623    ///
624    /// # Example
625    ///
626    /// ```rust
627    /// # use polars_core::prelude::*;
628    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
629    ///     df.group_by(["date"])?.select(["temp"]).median()
630    /// }
631    /// ```
632    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
633    pub fn median(&self) -> PolarsResult<DataFrame> {
634        let (mut cols, agg_cols) = self.prepare_agg()?;
635        for agg_col in agg_cols {
636            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Median);
637            let mut agg = unsafe { agg_col.agg_median(&self.groups) };
638            agg.rename(new_name);
639            cols.push(agg);
640        }
641        DataFrame::new_infer_height(cols)
642    }
643
644    /// Aggregate grouped [`Series`] and determine the variance per group.
645    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
646    pub fn var(&self, ddof: u8) -> PolarsResult<DataFrame> {
647        let (mut cols, agg_cols) = self.prepare_agg()?;
648        for agg_col in agg_cols {
649            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Var(ddof));
650            let mut agg = unsafe { agg_col.agg_var(&self.groups, ddof) };
651            agg.rename(new_name);
652            cols.push(agg);
653        }
654        DataFrame::new_infer_height(cols)
655    }
656
657    /// Aggregate grouped [`Series`] and determine the standard deviation per group.
658    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
659    pub fn std(&self, ddof: u8) -> PolarsResult<DataFrame> {
660        let (mut cols, agg_cols) = self.prepare_agg()?;
661        for agg_col in agg_cols {
662            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Std(ddof));
663            let mut agg = unsafe { agg_col.agg_std(&self.groups, ddof) };
664            agg.rename(new_name);
665            cols.push(agg);
666        }
667        DataFrame::new_infer_height(cols)
668    }
669
670    /// Aggregate grouped series and compute the number of values per group.
671    ///
672    /// # Example
673    ///
674    /// ```rust
675    /// # use polars_core::prelude::*;
676    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
677    ///     df.group_by(["date"])?.select(["temp"]).count()
678    /// }
679    /// ```
680    /// Returns:
681    ///
682    /// ```text
683    /// +------------+------------+
684    /// | date       | temp_count |
685    /// | ---        | ---        |
686    /// | Date       | u32        |
687    /// +============+============+
688    /// | 2020-08-23 | 1          |
689    /// +------------+------------+
690    /// | 2020-08-22 | 2          |
691    /// +------------+------------+
692    /// | 2020-08-21 | 2          |
693    /// +------------+------------+
694    /// ```
695    pub fn count(&self) -> PolarsResult<DataFrame> {
696        let (mut cols, agg_cols) = self.prepare_agg()?;
697
698        for agg_col in agg_cols {
699            let new_name = fmt_group_by_column(
700                agg_col.name().as_str(),
701                GroupByMethod::Count {
702                    include_nulls: true,
703                },
704            );
705            let mut ca = self.groups.group_count();
706            ca.rename(new_name);
707            cols.push(ca.into_column());
708        }
709        DataFrame::new_infer_height(cols)
710    }
711
712    /// Get the group_by group indexes.
713    ///
714    /// # Example
715    ///
716    /// ```rust
717    /// # use polars_core::prelude::*;
718    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
719    ///     df.group_by(["date"])?.groups()
720    /// }
721    /// ```
722    /// Returns:
723    ///
724    /// ```text
725    /// +--------------+------------+
726    /// | date         | groups     |
727    /// | ---          | ---        |
728    /// | Date(days)   | list [u32] |
729    /// +==============+============+
730    /// | 2020-08-23   | "[3]"      |
731    /// +--------------+------------+
732    /// | 2020-08-22   | "[2, 4]"   |
733    /// +--------------+------------+
734    /// | 2020-08-21   | "[0, 1]"   |
735    /// +--------------+------------+
736    /// ```
737    pub fn groups(&self) -> PolarsResult<DataFrame> {
738        let mut cols = self.keys();
739        let mut column = self.groups.as_list_chunked();
740        let new_name = fmt_group_by_column("", GroupByMethod::Groups);
741        column.rename(new_name);
742        cols.push(column.into_column());
743        DataFrame::new_infer_height(cols)
744    }
745
746    /// Aggregate the groups of the group_by operation into lists.
747    ///
748    /// # Example
749    ///
750    /// ```rust
751    /// # use polars_core::prelude::*;
752    /// fn example(df: DataFrame) -> PolarsResult<DataFrame> {
753    ///     // GroupBy and aggregate to Lists
754    ///     df.group_by(["date"])?.select(["temp"]).agg_list()
755    /// }
756    /// ```
757    /// Returns:
758    ///
759    /// ```text
760    /// +------------+------------------------+
761    /// | date       | temp_agg_list          |
762    /// | ---        | ---                    |
763    /// | Date       | list [i32]             |
764    /// +============+========================+
765    /// | 2020-08-23 | "[Some(9)]"            |
766    /// +------------+------------------------+
767    /// | 2020-08-22 | "[Some(7), Some(1)]"   |
768    /// +------------+------------------------+
769    /// | 2020-08-21 | "[Some(20), Some(10)]" |
770    /// +------------+------------------------+
771    /// ```
772    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
773    pub fn agg_list(&self) -> PolarsResult<DataFrame> {
774        let (mut cols, agg_cols) = self.prepare_agg()?;
775        for agg_col in agg_cols {
776            let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Implode);
777            let mut agg = unsafe { agg_col.agg_list(&self.groups) };
778            agg.rename(new_name);
779            cols.push(agg);
780        }
781        DataFrame::new_infer_height(cols)
782    }
783
784    fn prepare_apply(&self) -> PolarsResult<DataFrame> {
785        polars_ensure!(self.df.height() > 0, ComputeError: "cannot group_by + apply on empty 'DataFrame'");
786        if let Some(agg) = &self.selected_agg {
787            if agg.is_empty() {
788                Ok(self.df.clone())
789            } else {
790                let mut new_cols = Vec::with_capacity(self.selected_keys.len() + agg.len());
791                new_cols.extend_from_slice(&self.selected_keys);
792                let cols = self.df.select_to_vec(agg.as_slice())?;
793                new_cols.extend(cols);
794                Ok(unsafe { DataFrame::new_unchecked(self.df.height(), new_cols) })
795            }
796        } else {
797            Ok(self.df.clone())
798        }
799    }
800
801    /// Apply a closure over the groups as a new [`DataFrame`] in parallel.
802    #[deprecated(since = "0.24.1", note = "use polars.lazy aggregations")]
803    pub fn par_apply<F>(&self, f: F) -> PolarsResult<DataFrame>
804    where
805        F: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
806    {
807        let df = self.prepare_apply()?;
808        let dfs = self
809            .get_groups()
810            .par_iter()
811            .map(|g| {
812                // SAFETY:
813                // groups are in bounds
814                let sub_df = unsafe { take_df(&df, g) };
815                f(sub_df)
816            })
817            .collect::<PolarsResult<Vec<_>>>()?;
818
819        let mut df = accumulate_dataframes_vertical(dfs)?;
820        df.rechunk_mut_par();
821        Ok(df)
822    }
823
824    /// Apply a closure over the groups as a new [`DataFrame`].
825    pub fn apply<F>(&self, f: F) -> PolarsResult<DataFrame>
826    where
827        F: FnMut(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
828    {
829        self.apply_sliced(None, f)
830    }
831
832    pub fn apply_sliced<F>(&self, slice: Option<(i64, usize)>, mut f: F) -> PolarsResult<DataFrame>
833    where
834        F: FnMut(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
835    {
836        let df = self.prepare_apply()?;
837        let max_height = if let Some((offset, len)) = slice {
838            offset.try_into().unwrap_or(usize::MAX).saturating_add(len)
839        } else {
840            usize::MAX
841        };
842        let mut height = 0;
843        let mut dfs = Vec::with_capacity(self.get_groups().len());
844        for g in self.get_groups().iter() {
845            // SAFETY: groups are in bounds.
846            let sub_df = unsafe { take_df(&df, g) };
847            let df = f(sub_df)?;
848            height += df.height();
849            dfs.push(df);
850
851            // Even if max_height is zero we need at least one df, so check
852            // after first push.
853            if height >= max_height {
854                break;
855            }
856        }
857
858        let mut df = accumulate_dataframes_vertical(dfs)?;
859        if let Some((offset, len)) = slice {
860            df = df.slice(offset, len);
861        }
862        Ok(df)
863    }
864
865    pub fn sliced(mut self, slice: Option<(i64, usize)>) -> Self {
866        match slice {
867            None => self,
868            Some((offset, length)) => {
869                self.groups = self.groups.slice(offset, length);
870                self.selected_keys = self.keys_sliced(slice);
871                self
872            },
873        }
874    }
875}
876
877unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame {
878    match g {
879        GroupsIndicator::Idx(idx) => df.take_slice_unchecked(idx.1),
880        GroupsIndicator::Slice([first, len]) => df.slice(first as i64, len as usize),
881    }
882}
883
/// An aggregation that can be applied to every group of a group-by.
///
/// The [`Display`] impl renders a short lowercase name for each variant, and
/// `fmt_group_by_column` uses the variant to name the aggregated output column.
#[derive(Copy, Clone, Debug)]
pub enum GroupByMethod {
    Min,
    NanMin,
    Max,
    NanMax,
    Median,
    Mean,
    First,
    FirstNonNull,
    Last,
    LastNonNull,
    /// NOTE(review): `allow_empty` presumably decides whether an empty group is
    /// accepted rather than rejected. Confirm against the aggregation code.
    Item { allow_empty: bool },
    Sum,
    /// The output column is always named `groups`, regardless of the input column.
    Groups,
    NUnique,
    /// The quantile to compute and the [`QuantileMethod`] to use. Output column
    /// names embed the quantile formatted with two decimals.
    Quantile(f64, QuantileMethod),
    /// NOTE(review): `include_nulls` presumably controls whether nulls are
    /// counted. Confirm.
    Count { include_nulls: bool },
    /// Displayed as `list`. The output column suffix is `_agg_list`.
    Implode,
    /// The `u8` is presumably the delta degrees of freedom (`ddof`). Confirm.
    Std(u8),
    /// The `u8` is presumably the delta degrees of freedom (`ddof`). Confirm.
    Var(u8),
    ArgMin,
    ArgMax,
}
908
909impl Display for GroupByMethod {
910    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
911        use GroupByMethod::*;
912        let s = match self {
913            Min => "min",
914            NanMin => "nan_min",
915            Max => "max",
916            NanMax => "nan_max",
917            Median => "median",
918            Mean => "mean",
919            First => "first",
920            FirstNonNull => "first_non_null",
921            Last => "last",
922            LastNonNull => "last_non_null",
923            Item { .. } => "item",
924            Sum => "sum",
925            Groups => "groups",
926            NUnique => "n_unique",
927            Quantile(_, _) => "quantile",
928            Count { .. } => "count",
929            Implode => "list",
930            Std(_) => "std",
931            Var(_) => "var",
932            ArgMin => "arg_min",
933            ArgMax => "arg_max",
934        };
935        write!(f, "{s}")
936    }
937}
938
939// Formatting functions used in eager and lazy code for renaming grouped columns
940pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr {
941    use GroupByMethod::*;
942    match method {
943        Min => format_pl_smallstr!("{name}_min"),
944        Max => format_pl_smallstr!("{name}_max"),
945        NanMin => format_pl_smallstr!("{name}_nan_min"),
946        NanMax => format_pl_smallstr!("{name}_nan_max"),
947        Median => format_pl_smallstr!("{name}_median"),
948        Mean => format_pl_smallstr!("{name}_mean"),
949        First => format_pl_smallstr!("{name}_first"),
950        FirstNonNull => format_pl_smallstr!("{name}_first_non_null"),
951        Last => format_pl_smallstr!("{name}_last"),
952        LastNonNull => format_pl_smallstr!("{name}_last_non_null"),
953        Item { .. } => format_pl_smallstr!("{name}_item"),
954        Sum => format_pl_smallstr!("{name}_sum"),
955        Groups => PlSmallStr::from_static("groups"),
956        NUnique => format_pl_smallstr!("{name}_n_unique"),
957        Count { .. } => format_pl_smallstr!("{name}_count"),
958        Implode => format_pl_smallstr!("{name}_agg_list"),
959        Quantile(quantile, _interpol) => format_pl_smallstr!("{name}_quantile_{quantile:.2}"),
960        Std(_) => format_pl_smallstr!("{name}_agg_std"),
961        Var(_) => format_pl_smallstr!("{name}_agg_var"),
962        ArgMin => format_pl_smallstr!("{name}_arg_min"),
963        ArgMax => format_pl_smallstr!("{name}_arg_max"),
964    }
965}
966
// Unit tests for the eager `group_by` API: single and many keys, float,
// categorical and null keys, and the naming of aggregated output columns.
#[cfg(test)]
mod test {
    use num_traits::FloatConst;

    use crate::prelude::*;

    #[test]
    #[cfg(feature = "dtype-date")]
    #[cfg_attr(miri, ignore)]
    fn test_group_by() -> PolarsResult<()> {
        let s0 = Column::new(
            PlSmallStr::from_static("date"),
            &[
                "2020-08-21",
                "2020-08-21",
                "2020-08-22",
                "2020-08-23",
                "2020-08-22",
            ],
        );
        let s1 = Column::new(PlSmallStr::from_static("temp"), [20, 10, 7, 9, 1]);
        let s2 = Column::new(PlSmallStr::from_static("rain"), [0.2, 0.1, 0.3, 0.1, 0.01]);
        let df = DataFrame::new_infer_height(vec![s0, s1, s2]).unwrap();

        // A stable group-by keeps first-occurrence order of the keys:
        // 2020-08-21 (2 rows), 2020-08-22 (2 rows), 2020-08-23 (1 row).
        let out = df.group_by_stable(["date"])?.select(["temp"]).count()?;
        assert_eq!(
            out.column("temp_count")?,
            &Column::new(PlSmallStr::from_static("temp_count"), [2 as IdxSize, 2, 1])
        );

        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        // Select multiple
        let out = df
            .group_by_stable(["date"])?
            .select(["temp", "rain"])
            .mean()?;
        assert_eq!(
            out.column("temp_mean")?,
            &Column::new(PlSmallStr::from_static("temp_mean"), [15.0f64, 4.0, 9.0])
        );

        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        // Group by multiple
        let out = df
            .group_by_stable(["date", "temp"])?
            .select(["rain"])
            .mean()?;
        assert!(out.column("rain_mean").is_ok());

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["date"])?.select(["temp"]).sum()?;
        assert_eq!(
            out.column("temp_sum")?,
            &Column::new(PlSmallStr::from_static("temp_sum"), [30, 8, 9])
        );

        // Use of deprecated `n_unique()` for testing purposes
        #[allow(deprecated)]
        // implicit select all and only aggregate on methods that support that aggregation
        let gb = df.group_by(["date"]).unwrap().n_unique().unwrap();
        // check the group by column is filtered out: 1 key column plus one
        // aggregated column for each of `temp` and `rain`.
        assert_eq!(gb.width(), 3);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_static_group_by_by_12_columns() {
        // Build GroupBy DataFrame.
        // Rows 3 and 4 agree on every key column, so they form one group
        // (N = 2 + 4); every other row is a group of its own.
        let s0 = Column::new("G1".into(), ["A", "A", "B", "B", "C"].as_ref());
        let s1 = Column::new("N".into(), [1, 2, 2, 4, 2].as_ref());
        let s2 = Column::new("G2".into(), ["k", "l", "m", "m", "l"].as_ref());
        let s3 = Column::new("G3".into(), ["a", "b", "c", "c", "d"].as_ref());
        let s4 = Column::new("G4".into(), ["1", "2", "3", "3", "4"].as_ref());
        let s5 = Column::new("G5".into(), ["X", "Y", "Z", "Z", "W"].as_ref());
        let s6 = Column::new("G6".into(), [false, true, true, true, false].as_ref());
        let s7 = Column::new("G7".into(), ["r", "x", "q", "q", "o"].as_ref());
        let s8 = Column::new("G8".into(), ["R", "X", "Q", "Q", "O"].as_ref());
        let s9 = Column::new("G9".into(), [1, 2, 3, 3, 4].as_ref());
        let s10 = Column::new("G10".into(), [".", "!", "?", "?", "/"].as_ref());
        let s11 = Column::new("G11".into(), ["(", ")", "@", "@", "$"].as_ref());
        let s12 = Column::new("G12".into(), ["-", "_", ";", ";", ","].as_ref());

        let df = DataFrame::new_infer_height(vec![
            s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
        ])
        .unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let adf = df
            .group_by([
                "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12",
            ])
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        // Group order is not guaranteed by `group_by`, so compare sorted sums.
        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(1), Some(2), Some(2), Some(6)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_dynamic_group_by_by_13_columns() {
        // The content for every group_by series.
        let series_content = ["A", "A", "B", "B", "C"];

        // The name of every group_by series.
        let series_names = [
            "G1", "G2", "G3", "G4", "G5", "G6", "G7", "G8", "G9", "G10", "G11", "G12", "G13",
        ];

        // Vector to contain every series: 13 key columns plus the value column.
        let mut columns = Vec::with_capacity(14);

        // Create a series for every group name.
        for series_name in series_names {
            let group_columns = Column::new(series_name.into(), series_content.as_ref());
            columns.push(group_columns);
        }

        // Create a series for the aggregation column.
        let agg_series = Column::new("N".into(), [1, 2, 3, 3, 4].as_ref());
        columns.push(agg_series);

        // Create the dataframe with the computed series.
        let df = DataFrame::new_infer_height(columns).unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        // Compute the aggregated DataFrame by the 13 columns defined in `series_names`.
        let adf = df
            .group_by(series_names)
            .unwrap()
            .select(["N"])
            .sum()
            .unwrap();

        // Check that the results of the group-by are correct. The content of every column
        // is equal, then, the grouped columns shall be equal and in the same order.
        for series_name in &series_names {
            assert_eq!(
                Vec::from(&adf.column(series_name).unwrap().str().unwrap().sort(false)),
                &[Some("A"), Some("B"), Some("C")]
            );
        }

        // Check the aggregated column is the expected one.
        assert_eq!(
            Vec::from(&adf.column("N_sum").unwrap().i32().unwrap().sort(false)),
            &[Some(3), Some(4), Some(6)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_floats() {
        // Float keys: 1.0 and 2.0 occur twice, 3.0 once.
        let df = df! {"flt" => [1., 1., 2., 2., 3.],
                    "val" => [1, 1, 1, 1, 1]
        }
        .unwrap();
        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let res = df.group_by(["flt"]).unwrap().sum().unwrap();
        let res = res.sort(["flt"], SortMultipleOptions::default()).unwrap();
        assert_eq!(
            Vec::from(res.column("val_sum").unwrap().i32().unwrap()),
            &[Some(2), Some(2), Some(1)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_categorical() {
        let mut df = df! {"foo" => ["a", "a", "b", "b", "c"],
                    "ham" => ["a", "a", "b", "b", "c"],
                    "bar" => [1, 1, 1, 1, 1]
        }
        .unwrap();

        // Turn `foo` into a categorical so the keys mix categorical and string.
        df.apply("foo", |s| {
            s.cast(&DataType::from_categories(Categories::global()))
                .unwrap()
        })
        .unwrap();

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        // check multiple keys and categorical
        let res = df
            .group_by_stable(["foo", "ham"])
            .unwrap()
            .select(["bar"])
            .sum()
            .unwrap();

        assert_eq!(
            Vec::from(
                res.column("bar_sum")
                    .unwrap()
                    .as_materialized_series()
                    .i32()
                    .unwrap()
            ),
            &[Some(2), Some(2), Some(1)]
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_null_handling() -> PolarsResult<()> {
        // Nulls in the value column are skipped by `mean`:
        // group "a" -> (1 + 2) / 2 = 1.5, group "b" -> 1 / 1 = 1.0.
        let df = df!(
            "a" => ["a", "a", "a", "b", "b"],
            "b" => [Some(1), Some(2), None, None, Some(1)]
        )?;
        // Use of deprecated `mean()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["a"])?.mean()?;

        assert_eq!(
            Vec::from(out.column("b_mean")?.as_materialized_series().f64()?),
            &[Some(1.5), Some(1.0)]
        );
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_group_by_var() -> PolarsResult<()> {
        // check variance and proper coercion to f64
        // Group "foo" has ints [1, 2]: sample variance (ddof = 1) is 0.5 and the
        // standard deviation is sqrt(0.5) = 1 / sqrt(2).
        let df = df![
            "g" => ["foo", "foo", "bar"],
            "flt" => [1.0, 2.0, 3.0],
            "int" => [1, 2, 3]
        ]?;

        // Use of deprecated `var()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).var(1)?;

        assert_eq!(out.column("int_agg_var")?.f64()?.get(0), Some(0.5));
        // Use of deprecated `std()` for testing purposes
        #[allow(deprecated)]
        let out = df.group_by_stable(["g"])?.select(["int"]).std(1)?;
        let val = out.column("int_agg_std")?.f64()?.get(0).unwrap();
        let expected = f64::FRAC_1_SQRT_2();
        assert!((val - expected).abs() < 0.000001);
        Ok(())
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "dtype-categorical")]
    fn test_group_by_null_group() -> PolarsResult<()> {
        // check if null is own group
        // NOTE(review): this only checks that grouping a categorical key with
        // nulls succeeds; it does not assert the contents of the null group.
        let mut df = df![
            "g" => [Some("foo"), Some("foo"), Some("bar"), None, None],
            "flt" => [1.0, 2.0, 3.0, 1.0, 1.0],
            "int" => [1, 2, 3, 1, 1]
        ]?;

        df.try_apply("g", |s| {
            s.cast(&DataType::from_categories(Categories::global()))
        })?;

        // Use of deprecated `sum()` for testing purposes
        #[allow(deprecated)]
        let _ = df.group_by(["g"])?.sum()?;
        Ok(())
    }
}