polars_core/frame/group_by/hashing.rs

use hashbrown::hash_map::Entry;
use polars_utils::hashing::{DirtyHash, hash_to_partition};
use polars_utils::idx_vec::IdxVec;
use polars_utils::itertools::Itertools;
use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::{ToTotalOrd, TotalHash, TotalOrdWrap};
use polars_utils::unitvec;
use rayon::prelude::*;

use crate::POOL;
use crate::hashing::*;
use crate::prelude::*;
use crate::utils::flatten;

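/// Initial capacity for the group-by hash tables and group vectors.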
fn get_init_size() -> usize {
    // Check whether we are running on a rayon worker thread. If `group_tuples`
    // is executed from within a parallel iterator we don't want to pre-allocate
    // this much per thread, as that explodes allocation.
    if POOL.current_thread_index().is_none() {
        _HASHMAP_INIT_SIZE
    } else {
        0
    }
}

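/// Merge the per-partition `(first_idx, all_idxs)` lists into a single
/// `GroupsType::Idx`. When `sorted` is requested, the groups are ordered by
/// their first (lowest) index, i.e. by the order in which the keys first appear.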
fn finish_group_order(mut out: Vec<Vec<IdxItem>>, sorted: bool) -> GroupsType {
    if sorted {
        // If there is only a single partition we can take it as-is; no need to flatten.
        let mut out = if out.len() == 1 {
            out.pop().unwrap()
        } else {
            let (cap, offsets) = flatten::cap_and_offsets(&out);
            // We write the (first, all) tuples so we can sort by the first index below.
            let mut items = Vec::with_capacity(cap);
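            // `SyncPtr` makes the raw destination pointer shareable across the
            // rayon tasks below; each task writes to its own disjoint range
            // starting at `offset`, so the writes never overlap.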
            let items_ptr = unsafe { SyncPtr::new(items.as_mut_ptr()) };

            POOL.install(|| {
                out.into_par_iter()
                    .zip(offsets)
                    .for_each(|(mut g, offset)| {
                        // Pre-sort every partition; this makes the final
                        // single-threaded sort much faster.
                        g.sort_unstable_by_key(|g| g.0);

                        unsafe {
                            let mut items_ptr: *mut (IdxSize, IdxVec) = items_ptr.get();
                            items_ptr = items_ptr.add(offset);

                            for (i, g) in g.into_iter().enumerate() {
                                std::ptr::write(items_ptr.add(i), g)
                            }
                        }
                    });
            });
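            // SAFETY: the parallel loop above wrote every slot in `0..cap`
            // exactly once (the offsets partition the output range), so the
            // full length is initialized.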
            unsafe {
                items.set_len(cap);
            }
            items
        };
        out.sort_unstable_by_key(|g| g.0);
        let mut idx = GroupsIdx::from_iter(out);
        idx.sorted = true;
        GroupsType::Idx(idx)
    } else {
        // If there is only a single partition we can take it as-is; no need to flatten.
        if out.len() == 1 {
            GroupsType::Idx(GroupsIdx::from(out.pop().unwrap()))
        } else {
            // This flattens the partitions.
            GroupsType::Idx(GroupsIdx::from(out))
        }
    }
}

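/// Compute groups on a single thread, keeping the index of the first
/// occurrence of every key.
///
/// A minimal usage sketch (assuming the key type implements `TotalHash` and
/// `TotalEq`, as the primitive types do):
///
/// ```ignore
/// // Keys 1, 2, 1, 3: key 1 maps to row indices [0, 2],
/// // key 2 to [1], and key 3 to [3].
/// let groups = group_by([1u32, 2, 1, 3].iter().copied(), false);
/// ```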
pub(crate) fn group_by<K>(keys: impl Iterator<Item = K>, sorted: bool) -> GroupsType
where
    K: TotalHash + TotalEq,
{
    let init_size = get_init_size();
    let (mut first, mut groups);
    if sorted {
        groups = Vec::with_capacity(init_size);
        first = Vec::with_capacity(init_size);
        let mut hash_tbl = PlHashMap::with_capacity(init_size);
        for (idx, k) in keys.enumerate_idx() {
            match hash_tbl.entry(TotalOrdWrap(k)) {
                Entry::Vacant(entry) => {
                    let group_idx = groups.len() as IdxSize;
                    entry.insert(group_idx);
                    groups.push(unitvec![idx]);
                    first.push(idx);
                },
                Entry::Occupied(entry) => unsafe {
                    groups.get_unchecked_mut(*entry.get() as usize).push(idx)
                },
            }
        }
    } else {
        let mut hash_tbl = PlHashMap::with_capacity(init_size);
        for (idx, k) in keys.enumerate_idx() {
            match hash_tbl.entry(TotalOrdWrap(k)) {
                Entry::Vacant(entry) => {
                    entry.insert((idx, unitvec![idx]));
                },
                Entry::Occupied(mut entry) => entry.get_mut().1.push(idx),
            }
        }
        (first, groups) = hash_tbl.into_values().unzip();
    }
    GroupsType::Idx(GroupsIdx::new(first, groups, sorted))
}

// Giving the compiler the slice information is much faster than
// using an iterator; that's why we have the code duplication.
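/// Compute groups in parallel from materialized slices of keys.
///
/// A minimal usage sketch (assuming a primitive key type and one slice per
/// chunk; `n_partitions` would normally be the number of threads):
///
/// ```ignore
/// let chunk_a: Vec<u32> = vec![1, 2, 1];
/// let chunk_b: Vec<u32> = vec![3, 1];
/// // One entry per chunk; `Vec<u32>` implements `AsRef<[u32]>`.
/// let groups = group_by_threaded_slice(vec![chunk_a, chunk_b], 4, false);
/// ```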
pub(crate) fn group_by_threaded_slice<T, IntoSlice>(
    keys: Vec<IntoSlice>,
    n_partitions: usize,
    sorted: bool,
) -> GroupsType
where
    T: ToTotalOrd,
    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + DirtyHash,
    IntoSlice: AsRef<[T]> + Send + Sync,
{
    let init_size = get_init_size();

    // We will create a hash table in every thread.
    // We use the hash to partition the keys to the matching hash table.
    // Every thread traverses all keys/hashes and ignores the ones that don't fall into its partition.
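    // (`hash_to_partition` maps a hash onto `0..n_partitions`, so every key is
    // handled by exactly one thread and the per-thread results can simply be
    // concatenated afterwards.)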
    let out = POOL.install(|| {
        (0..n_partitions)
            .into_par_iter()
            .map(|thread_no| {
                let mut hash_tbl = PlHashMap::with_capacity(init_size);

                let mut offset = 0;
                for keys in &keys {
                    let keys = keys.as_ref();
                    let len = keys.len() as IdxSize;

                    for (key_idx, k) in keys.iter().enumerate_idx() {
                        let k = k.to_total_ord();
                        let idx = key_idx + offset;

                        if thread_no == hash_to_partition(k.dirty_hash(), n_partitions) {
                            match hash_tbl.entry(k) {
                                Entry::Vacant(entry) => {
                                    entry.insert((idx, unitvec![idx]));
                                },
                                Entry::Occupied(mut entry) => {
                                    entry.get_mut().1.push(idx);
                                },
                            }
                        }
                    }
                    offset += len;
                }
                hash_tbl
                    .into_iter()
                    .map(|(_k, v)| v)
                    .collect_trusted::<Vec<_>>()
            })
            .collect::<Vec<_>>()
    });
    finish_group_order(out, sorted)
}

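/// Compute groups in parallel from per-chunk key iterators, avoiding the need
/// to materialize the keys as slices.
///
/// Unlike the slice version, only the index vector is stored per group; the
/// first index of each group is recovered afterwards from the front of its
/// `IdxVec`.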
pub(crate) fn group_by_threaded_iter<T, I>(
    keys: &[I],
    n_partitions: usize,
    sorted: bool,
) -> GroupsType
where
    I: IntoIterator<Item = T> + Send + Sync + Clone,
    I::IntoIter: ExactSizeIterator,
    T: ToTotalOrd,
    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + DirtyHash,
{
    let init_size = get_init_size();

    // We will create a hash table in every thread.
    // We use the hash to partition the keys to the matching hash table.
    // Every thread traverses all keys/hashes and ignores the ones that don't fall into its partition.
    let out = POOL.install(|| {
        (0..n_partitions)
            .into_par_iter()
            .map(|thread_no| {
                let mut hash_tbl: PlHashMap<T::TotalOrdItem, IdxVec> =
                    PlHashMap::with_capacity(init_size);

                let mut offset = 0;
                for keys in keys {
                    let keys = keys.clone().into_iter();
                    let len = keys.len() as IdxSize;

                    for (key_idx, k) in keys.enumerate_idx() {
                        let k = k.to_total_ord();
                        let idx = key_idx + offset;

                        if thread_no == hash_to_partition(k.dirty_hash(), n_partitions) {
                            match hash_tbl.entry(k) {
                                Entry::Vacant(entry) => {
                                    entry.insert(unitvec![idx]);
                                },
                                Entry::Occupied(mut entry) => {
                                    entry.get_mut().push(idx);
                                },
                            }
                        }
                    }
                    offset += len;
                }
                // Iterating the hash table locally was faster than iterating it
                // directly in the materialization phase into the final vec. I
                // believe this is because the hash table is still local to the
                // thread and thus hot in cache. So we first collect into a tight
                // vec and then do a second materialization run. This is also
                // faster than the index-map approach, where we store directly
                // into a local vec at the cost of an extra indirection.
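                // SAFETY: every `IdxVec` in the table was created via
                // `unitvec![idx]` and only grows, so it is never empty and
                // `first()` always returns `Some`.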
                hash_tbl
                    .into_iter()
                    .map(|(_k, v)| (unsafe { *v.first().unwrap_unchecked() }, v))
                    .collect_trusted::<Vec<_>>()
            })
            .collect::<Vec<_>>()
    });
    finish_group_order(out, sorted)
}