polars_core/frame/group_by/perfect.rs

use std::fmt::Debug;
use std::mem::MaybeUninit;

use num_traits::{FromPrimitive, ToPrimitive};
use polars_utils::idx_vec::IdxVec;
use polars_utils::sync::SyncPtr;
use rayon::prelude::*;

use crate::POOL;
#[cfg(all(feature = "dtype-categorical", feature = "performant"))]
use crate::config::verbose;
use crate::datatypes::*;
use crate::prelude::*;
impl<T> ChunkedArray<T>
where
    T: PolarsIntegerType,
    T::Native: ToPrimitive + FromPrimitive + Debug,
{
    /// Use the indexes as perfect groups.
    ///
    /// # Safety
    /// This ChunkedArray must contain each value in [0..num_groups) at least
    /// once, and nothing outside this range.
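    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled; the constructor call and literal values
    /// are illustrative, assuming a `UInt32Chunked` whose values are exactly the
    /// group indices):
    ///
    /// ```ignore
    /// // Values 0, 1 and 2 each occur at least once => 3 perfect groups.
    /// let ca = UInt32Chunked::from_vec("g".into(), vec![0u32, 1, 2, 1, 0]);
    /// let groups = unsafe { ca.group_tuples_perfect(3, false, 2) };
    /// // Row indices end up grouped as [[0, 4], [1, 3], [2]],
    /// // with first-occurrence indices [0, 1, 2].
    /// ```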
    pub unsafe fn group_tuples_perfect(
        &self,
        num_groups: usize,
        mut multithreaded: bool,
        group_capacity: usize,
    ) -> GroupsType {
        multithreaded &= POOL.current_num_threads() > 1;
        // The last index will be used for the null sentinel.
        let len = if self.null_count() > 0 {
            // We add one to store the null sentinel group.
            num_groups + 1
        } else {
            num_groups
        };
        let null_idx = len.saturating_sub(1);

        let n_threads = POOL.current_num_threads();
        let chunk_size = len / n_threads;
        let (groups, first) = if multithreaded && chunk_size > 1 {
            let mut groups: Vec<IdxVec> = Vec::new();
            groups.resize_with(len, || IdxVec::with_capacity(group_capacity));
            let mut first: Vec<IdxSize> = Vec::with_capacity(len);

            // Round up offsets to nearest cache line for groups to reduce false sharing.
            let groups_start = groups.as_ptr();
            let mut per_thread_offsets = Vec::with_capacity(n_threads + 1);
            per_thread_offsets.push(0);
            for t in 0..n_threads {
                let ideal_offset = (t + 1) * chunk_size;
                let cache_aligned_offset =
                    ideal_offset + groups_start.wrapping_add(ideal_offset).align_offset(128);
                if t == n_threads - 1 {
                    per_thread_offsets.push(len);
                } else {
                    per_thread_offsets.push(std::cmp::min(cache_aligned_offset, len));
                }
            }
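            // Illustrative example (assumed numbers): with `len == 1000` and 4
            // threads, `chunk_size == 250`, so the ideal boundaries 250/500/750
            // are each nudged forward by `align_offset(128)` to start on a
            // 128-byte boundary within `groups`; the last boundary is always `len`.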

            let groups_ptr = unsafe { SyncPtr::new(groups.as_mut_ptr()) };
            let first_ptr = unsafe { SyncPtr::new(first.as_mut_ptr()) };
            POOL.install(|| {
                (0..n_threads).into_par_iter().for_each(|thread_no| {
                    // We use raw pointers because the slices would overlap.
                    // However, each thread has its own range it is responsible for.
                    let groups = groups_ptr.get();
                    let first = first_ptr.get();
                    let start = per_thread_offsets[thread_no];
                    let start = T::Native::from_usize(start).unwrap();
                    let end = per_thread_offsets[thread_no + 1];
                    let end = T::Native::from_usize(end).unwrap();

                    if start == end && thread_no != n_threads - 1 {
                        return;
                    };

                    let push_to_group = |cat, row_nr| unsafe {
                        debug_assert!(cat < len);
                        let buf = &mut *groups.add(cat);
                        buf.push(row_nr);
                        if buf.len() == 1 {
                            *first.add(cat) = row_nr;
                        }
                    };

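                    // Every thread scans all chunks, but only records rows whose
                    // value falls in its own `[start, end)` slice of `groups`, so
                    // each group index is written by exactly one thread.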
                    let mut row_nr = 0 as IdxSize;
                    for arr in self.downcast_iter() {
                        if arr.null_count() == 0 {
                            for &cat in arr.values().as_slice() {
                                if cat >= start && cat < end {
                                    push_to_group(cat.to_usize().unwrap(), row_nr);
                                }

                                row_nr += 1;
                            }
                        } else {
                            for opt_cat in arr.iter() {
                                if let Some(&cat) = opt_cat {
                                    if cat >= start && cat < end {
                                        push_to_group(cat.to_usize().unwrap(), row_nr);
                                    }
                                } else if thread_no == n_threads - 1 {
                                    // Last thread handles null values.
                                    push_to_group(null_idx, row_nr);
                                }

                                row_nr += 1;
                            }
                        }
                    }
                });
            });
            unsafe {
                first.set_len(len);
            }
            (groups, first)
        } else {
            let mut groups = Vec::with_capacity(len);
            let mut first = Vec::with_capacity(len);
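            // `first` is filled out of order (one slot per group), so we write
            // into its spare capacity and only `set_len` after the scan, when the
            // safety contract guarantees every slot has been written at least once.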
            let first_out = first.spare_capacity_mut();
            groups.resize_with(len, || IdxVec::with_capacity(group_capacity));

            let mut push_to_group = |cat, row_nr| unsafe {
                let buf: &mut IdxVec = groups.get_unchecked_mut(cat);
                buf.push(row_nr);
                if buf.len() == 1 {
                    *first_out.get_unchecked_mut(cat) = MaybeUninit::new(row_nr);
                }
            };

            let mut row_nr = 0 as IdxSize;
            for arr in self.downcast_iter() {
                for opt_cat in arr.iter() {
                    if let Some(cat) = opt_cat {
                        push_to_group(cat.to_usize().unwrap(), row_nr);
                    } else {
                        push_to_group(null_idx, row_nr);
                    }

                    row_nr += 1;
                }
            }
            unsafe {
                first.set_len(len);
            }
            (groups, first)
        };

        // NOTE: we mark the groups as sorted here; this happens to be true for
        // `fast_unique` categoricals.
        GroupsType::Idx(GroupsIdx::new(first, groups, true))
    }
}

// Special implementation so that categoricals can be processed in a single pass.
#[cfg(all(feature = "dtype-categorical", feature = "performant"))]
impl CategoricalChunked {
    /// Use the indexes as perfect groups.
    pub fn group_tuples_perfect(&self, multithreaded: bool, sorted: bool) -> GroupsType {
        let rev_map = self.get_rev_map();
        if self.is_empty() {
            return GroupsType::Idx(GroupsIdx::new(vec![], vec![], true));
        }
        let cats = self.physical();

        let mut out = match &**rev_map {
            RevMapping::Local(cached, _) => {
                if self._can_fast_unique() {
                    assert!(cached.len() <= self.len(), "invalid invariant");
                    if verbose() {
                        eprintln!("grouping categoricals, run perfect hash function");
                    }
                    // On relatively small tables this isn't much faster than the
                    // default strategy, but on huge tables it can be > 2x faster.
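                    // SAFETY (assumed from the fast-unique invariant): the physical
                    // values are exactly the indices 0..cached.len(), which is the
                    // precondition of `group_tuples_perfect`.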
                    unsafe { cats.group_tuples_perfect(cached.len(), multithreaded, 0) }
                } else {
                    self.physical().group_tuples(multithreaded, sorted).unwrap()
                }
            },
            RevMapping::Global(_mapping, _cached, _) => {
                // TODO: see if we can optimize this. The problem is that the global
                // categories are not guaranteed to be packed together, so we might
                // need to dereference them to local ones first, but that might be
                // more expensive than just hashing (benchmark first).
                self.physical().group_tuples(multithreaded, sorted).unwrap()
            },
        };
        if sorted {
            out.sort()
        }
        out
    }
}