polars_core/frame/group_by/into_groups.rs

use arrow::legacy::kernels::sort_partition::{
    create_clean_partitions, partition_to_groups, partition_to_groups_amortized_varsize,
};
use polars_error::signals::try_raise_keyboard_interrupt;
use polars_utils::total_ord::{ToTotalOrd, TotalHash};

use super::*;
use crate::chunked_array::cast::CastOptions;
use crate::chunked_array::ops::row_encode::_get_rows_encoded_ca_unordered;
use crate::config::verbose;
use crate::series::BitRepr;
use crate::utils::Container;
use crate::utils::flatten::flatten_par;

/// Used to create the tuples for a group_by operation.
pub trait IntoGroupsType {
    /// Create the tuples needed for a group_by operation.
    ///     * The first value in the tuple is the first index of the group.
    ///     * The second value in the tuple holds the indexes of the group, including the first value.
    fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult<GroupsType> {
        unimplemented!()
    }
}
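
// A minimal usage sketch (illustrative only; assumes a `UInt32Chunked` named `keys`
// with values `[1, 1, 2]` is in scope):
//
//     let groups = keys.group_tuples(true, false)?;
//
// This produces two groups: one with first index 0 and indexes `[0, 1]`, and one
// with first index 2 and indexes `[2]`.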

fn group_multithreaded<T: PolarsDataType>(ca: &ChunkedArray<T>) -> bool {
    // TODO! change to something sensible
    ca.len() > 1000 && POOL.current_num_threads() > 1
}

fn num_groups_proxy<T>(ca: &ChunkedArray<T>, multithreaded: bool, sorted: bool) -> GroupsType
where
    T: PolarsNumericType,
    T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
    <T::Native as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash,
{
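    // Hash-based grouping of numeric keys. In the threaded case the keys are split into
    // hash partitions so each thread can build the groups for its partition independently;
    // otherwise a single hash table is built over the whole column.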
    if multithreaded && group_multithreaded(ca) {
        let n_partitions = _set_partition_size();

        // Use the underlying arrays directly: value slices when there are no nulls,
        // otherwise iterators over optional values.
        if ca.null_count() == 0 {
            let keys = ca
                .downcast_iter()
                .map(|arr| arr.values().as_slice())
                .collect::<Vec<_>>();
            group_by_threaded_slice(keys, n_partitions, sorted)
        } else {
            let keys = ca
                .downcast_iter()
                .map(|arr| arr.iter().map(|o| o.copied()))
                .collect::<Vec<_>>();
            group_by_threaded_iter(&keys, n_partitions, sorted)
        }
    } else if !ca.has_nulls() {
        group_by(ca.into_no_null_iter(), sorted)
    } else {
        group_by(ca.iter(), sorted)
    }
}

impl<T> ChunkedArray<T>
where
    T: PolarsNumericType,
    T::Native: NumCast,
{
    fn create_groups_from_sorted(&self, multithreaded: bool) -> GroupsSlice {
        if verbose() {
            eprintln!("group_by keys are sorted; running sorted key fast path");
        }
        let arr = self.downcast_iter().next().unwrap();
        if arr.is_empty() {
            return GroupsSlice::default();
        }
        let mut values = arr.values().as_slice();
        let null_count = arr.null_count();
        let length = values.len();

        // all nulls
        if null_count == length {
            return vec![[0, length as IdxSize]];
        }

        let mut nulls_first = false;
        if null_count > 0 {
            nulls_first = arr.get(0).is_none()
        }

        if nulls_first {
            values = &values[null_count..];
        } else {
            values = &values[..length - null_count];
        };

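        // With multiple threads, split the sorted (non-null) values into partitions whose
        // boundaries never fall inside a run of equal keys, so each partition can be turned
        // into groups independently and the results concatenated.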
        let n_threads = POOL.current_num_threads();
        if multithreaded && n_threads > 1 {
            let parts =
                create_clean_partitions(values, n_threads, self.is_sorted_descending_flag());
            let n_parts = parts.len();

            let first_ptr = &values[0] as *const T::Native as usize;
            let groups = parts.par_iter().enumerate().map(|(i, part)| {
                // We pass the pointer as a usize because `*const` is not `Send`.
                let first_ptr = first_ptr as *const T::Native;

                let part_first_ptr = &part[0] as *const T::Native;
                let mut offset = unsafe { part_first_ptr.offset_from(first_ptr) } as IdxSize;

                // nulls first: only add the nulls at the first partition
                if nulls_first && i == 0 {
                    partition_to_groups(part, null_count as IdxSize, true, offset)
                }
                // nulls last: only add the nulls at the last partition
                else if !nulls_first && i == n_parts - 1 {
                    partition_to_groups(part, null_count as IdxSize, false, offset)
                }
                // other partitions
                else {
                    if nulls_first {
                        offset += null_count as IdxSize;
                    };

                    partition_to_groups(part, 0, false, offset)
                }
            });
            let groups = POOL.install(|| groups.collect::<Vec<_>>());
            flatten_par(&groups)
        } else {
            partition_to_groups(values, null_count as IdxSize, nulls_first, 0)
        }
    }
}

#[cfg(all(feature = "dtype-categorical", feature = "performant"))]
impl IntoGroupsType for CategoricalChunked {
    fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
        Ok(self.group_tuples_perfect(multithreaded, sorted))
    }
}

impl<T> IntoGroupsType for ChunkedArray<T>
where
    T: PolarsNumericType,
    T::Native: NumCast,
{
    fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
        // sorted path
        if self.is_sorted_ascending_flag() || self.is_sorted_descending_flag() {
            // No need to pass the `sorted` arg; a GroupsSlice is always sorted.
            return Ok(GroupsType::Slice {
                groups: self.rechunk().create_groups_from_sorted(multithreaded),
                rolling: false,
            });
        }

        let out = match self.dtype() {
            DataType::UInt64 => {
                // convince the compiler that we are this type.
                let ca: &UInt64Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<UInt64Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            DataType::UInt32 => {
                // convince the compiler that we are this type.
                let ca: &UInt32Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<UInt32Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            DataType::Int64 => {
                let BitRepr::Large(ca) = self.to_bit_repr() else {
                    unreachable!()
                };
                num_groups_proxy(&ca, multithreaded, sorted)
            },
            DataType::Int32 => {
                let BitRepr::Small(ca) = self.to_bit_repr() else {
                    unreachable!()
                };
                num_groups_proxy(&ca, multithreaded, sorted)
            },
            DataType::Float64 => {
                // convince the compiler that we are this type.
                let ca: &Float64Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<Float64Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            DataType::Float32 => {
                // convince the compiler that we are this type.
                let ca: &Float32Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<Float32Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            #[cfg(feature = "dtype-decimal")]
            DataType::Decimal(_, _) => {
                // convince the compiler that we are this type.
                let ca: &Int128Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<Int128Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            #[cfg(all(feature = "performant", feature = "dtype-i8", feature = "dtype-u8"))]
            DataType::Int8 => {
                // convince the compiler that we are this type.
                let ca: &Int8Chunked =
                    unsafe { &*(self as *const ChunkedArray<T> as *const ChunkedArray<Int8Type>) };
                let s = ca.reinterpret_unsigned();
                return s.group_tuples(multithreaded, sorted);
            },
            #[cfg(all(feature = "performant", feature = "dtype-i8", feature = "dtype-u8"))]
            DataType::UInt8 => {
                // convince the compiler that we are this type.
                let ca: &UInt8Chunked =
                    unsafe { &*(self as *const ChunkedArray<T> as *const ChunkedArray<UInt8Type>) };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            #[cfg(all(feature = "performant", feature = "dtype-i16", feature = "dtype-u16"))]
            DataType::Int16 => {
                // convince the compiler that we are this type.
                let ca: &Int16Chunked =
                    unsafe { &*(self as *const ChunkedArray<T> as *const ChunkedArray<Int16Type>) };
                let s = ca.reinterpret_unsigned();
                return s.group_tuples(multithreaded, sorted);
            },
            #[cfg(all(feature = "performant", feature = "dtype-i16", feature = "dtype-u16"))]
            DataType::UInt16 => {
                // convince the compiler that we are this type.
                let ca: &UInt16Chunked = unsafe {
                    &*(self as *const ChunkedArray<T> as *const ChunkedArray<UInt16Type>)
                };
                num_groups_proxy(ca, multithreaded, sorted)
            },
            _ => {
                let ca = unsafe { self.cast_unchecked(&DataType::UInt32).unwrap() };
                let ca = ca.u32().unwrap();
                num_groups_proxy(ca, multithreaded, sorted)
            },
        };
        try_raise_keyboard_interrupt();
        Ok(out)
    }
}

impl IntoGroupsType for BooleanChunked {
    fn group_tuples(&self, mut multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
        multithreaded &= POOL.current_num_threads() > 1;

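        // Booleans have no dedicated hash path: cast to u8 keys when the `performant`
        // feature is enabled, otherwise to u32, and group via the integer implementation.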
        #[cfg(feature = "performant")]
        {
            let ca = self
                .cast_with_options(&DataType::UInt8, CastOptions::Overflowing)
                .unwrap();
            let ca = ca.u8().unwrap();
            ca.group_tuples(multithreaded, sorted)
        }
        #[cfg(not(feature = "performant"))]
        {
            let ca = self
                .cast_with_options(&DataType::UInt32, CastOptions::Overflowing)
                .unwrap();
            let ca = ca.u32().unwrap();
            ca.group_tuples(multithreaded, sorted)
        }
    }
}

impl IntoGroupsType for StringChunked {
    #[allow(clippy::needless_lifetimes)]
    fn group_tuples<'a>(&'a self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
        self.as_binary().group_tuples(multithreaded, sorted)
    }
}

impl IntoGroupsType for BinaryChunked {
    #[allow(clippy::needless_lifetimes)]
    fn group_tuples<'a>(
        &'a self,
        mut multithreaded: bool,
        sorted: bool,
    ) -> PolarsResult<GroupsType> {
        if self.is_sorted_any() && !self.has_nulls() && self.n_chunks() == 1 {
            let arr = self.downcast_get(0).unwrap();
            let values = arr.values_iter();
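            // Rough pre-allocation: assume, on average, about one group per 30 values.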
            let mut out = Vec::with_capacity(values.len() / 30);
            partition_to_groups_amortized_varsize(values, arr.len() as _, 0, false, 0, &mut out);
            return Ok(GroupsType::Slice {
                groups: out,
                rolling: false,
            });
        }

        multithreaded &= POOL.current_num_threads() > 1;
        let bh = self.to_bytes_hashes(multithreaded, Default::default());

        let out = if multithreaded {
            let n_partitions = bh.len();
            // Take slices so that the vecs are not cloned.
            let bh = bh.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
            group_by_threaded_slice(bh, n_partitions, sorted)
        } else {
            group_by(bh[0].iter(), sorted)
        };
        try_raise_keyboard_interrupt();
        Ok(out)
    }
}

impl IntoGroupsType for BinaryOffsetChunked {
    #[allow(clippy::needless_lifetimes)]
    fn group_tuples<'a>(
        &'a self,
        mut multithreaded: bool,
        sorted: bool,
    ) -> PolarsResult<GroupsType> {
        if self.is_sorted_any() && !self.has_nulls() && self.n_chunks() == 1 {
            let arr = self.downcast_get(0).unwrap();
            let values = arr.values_iter();
            let mut out = Vec::with_capacity(values.len() / 30);
            partition_to_groups_amortized_varsize(values, arr.len() as _, 0, false, 0, &mut out);
            return Ok(GroupsType::Slice {
                groups: out,
                rolling: false,
            });
        }
        multithreaded &= POOL.current_num_threads() > 1;
        let bh = self.to_bytes_hashes(multithreaded, Default::default());

        let out = if multithreaded {
            let n_partitions = bh.len();
            // Take slices so that the vecs are not cloned.
            let bh = bh.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
            group_by_threaded_slice(bh, n_partitions, sorted)
        } else {
            group_by(bh[0].iter(), sorted)
        };
        Ok(out)
    }
}

impl IntoGroupsType for ListChunked {
    #[allow(clippy::needless_lifetimes)]
    #[allow(unused_variables)]
    fn group_tuples<'a>(
        &'a self,
        mut multithreaded: bool,
        sorted: bool,
    ) -> PolarsResult<GroupsType> {
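        // Nested list keys are grouped via row encoding: the column is first encoded into a
        // binary representation, which is then grouped like any other binary key.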
        multithreaded &= POOL.current_num_threads() > 1;
        let by = &[self.clone().into_column()];
        let ca = if multithreaded {
            encode_rows_vertical_par_unordered(by).unwrap()
        } else {
            _get_rows_encoded_ca_unordered(PlSmallStr::EMPTY, by).unwrap()
        };

        ca.group_tuples(multithreaded, sorted)
    }
}

#[cfg(feature = "dtype-array")]
impl IntoGroupsType for ArrayChunked {
    #[allow(clippy::needless_lifetimes)]
    #[allow(unused_variables)]
    fn group_tuples<'a>(&'a self, _multithreaded: bool, _sorted: bool) -> PolarsResult<GroupsType> {
        todo!("grouping FixedSizeList not yet supported")
    }
}

#[cfg(feature = "object")]
impl<T> IntoGroupsType for ObjectChunked<T>
where
    T: PolarsObject,
{
    fn group_tuples(&self, _multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
        Ok(group_by(self.into_iter(), sorted))
    }
}