polars_ops/frame/join/hash_join/
single_keys.rs1use polars_utils::hashing::{DirtyHash, hash_to_partition};
2use polars_utils::idx_vec::IdxVec;
3use polars_utils::nulls::IsNull;
4use polars_utils::sync::SyncPtr;
5use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
6use polars_utils::unitvec;
7
8use super::*;
9
/// Per-thread element threshold used by `build_tables`: inputs whose estimated
/// key count is below `2 * MIN_ELEMS_PER_THREAD` are processed serially.
/// Debug builds use 1 (presumably so that small test inputs still exercise the
/// parallel path); release builds use 128.
#[cfg(debug_assertions)]
const MIN_ELEMS_PER_THREAD: usize = 1;
/// Per-thread element threshold used by `build_tables`: inputs whose estimated
/// key count is below `2 * MIN_ELEMS_PER_THREAD` are processed serially.
/// Debug builds use 1 (presumably so that small test inputs still exercise the
/// parallel path); release builds use 128.
#[cfg(not(debug_assertions))]
const MIN_ELEMS_PER_THREAD: usize = 128;
15
16pub(crate) fn build_tables<T, I>(
17 keys: Vec<I>,
18 nulls_equal: bool,
19) -> Vec<PlHashMap<<T as ToTotalOrd>::TotalOrdItem, IdxVec>>
20where
21 T: TotalHash + TotalEq + ToTotalOrd,
22 <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
23 I: IntoIterator<Item = T> + Send + Sync + Clone,
24{
25 let n_partitions = keys.len();
28 let n_threads = n_partitions;
29 let num_keys_est: usize = keys
30 .iter()
31 .map(|k| k.clone().into_iter().size_hint().0)
32 .sum();
33
34 if num_keys_est < 2 * MIN_ELEMS_PER_THREAD {
36 let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> = PlHashMap::new();
37 let mut offset = 0;
38 for it in keys {
39 for k in it {
40 let k = k.to_total_ord();
41 if !k.is_null() || nulls_equal {
42 hm.entry(k).or_default().push(offset);
43 }
44 offset += 1;
45 }
46 }
47 return vec![hm];
48 }
49
50 RAYON.install(|| {
51 let per_thread_partition_sizes: Vec<Vec<usize>> = keys
53 .par_iter()
54 .with_max_len(1)
55 .map(|key_portion| {
56 let mut partition_sizes = vec![0; n_partitions];
57 for key in key_portion.clone() {
58 let key = key.to_total_ord();
59 let p = hash_to_partition(key.dirty_hash(), n_partitions);
60 unsafe {
61 *partition_sizes.get_unchecked_mut(p) += 1;
62 }
63 }
64 partition_sizes
65 })
66 .collect();
67
68 let mut per_thread_partition_offsets = vec![0; n_partitions * n_threads + 1];
70 let mut partition_offsets = vec![0; n_partitions + 1];
71 let mut cum_offset = 0;
72 for p in 0..n_partitions {
73 partition_offsets[p] = cum_offset;
74 for t in 0..n_threads {
75 per_thread_partition_offsets[t * n_partitions + p] = cum_offset;
76 cum_offset += per_thread_partition_sizes[t][p];
77 }
78 }
79 let num_keys = cum_offset;
80 per_thread_partition_offsets[n_threads * n_partitions] = num_keys;
81 partition_offsets[n_partitions] = num_keys;
82
83 let mut per_thread_input_offsets = vec![0; n_partitions];
86 cum_offset = 0;
87 for t in 0..n_threads {
88 per_thread_input_offsets[t] = cum_offset;
89 cum_offset += per_thread_partition_sizes[t]
90 .iter()
91 .take(n_partitions)
92 .sum::<usize>();
93 }
94
95 let mut scatter_keys: Vec<T::TotalOrdItem> = Vec::with_capacity(num_keys);
97 let mut scatter_idxs: Vec<IdxSize> = Vec::with_capacity(num_keys);
98 let scatter_keys_ptr = unsafe { SyncPtr::new(scatter_keys.as_mut_ptr()) };
99 let scatter_idxs_ptr = unsafe { SyncPtr::new(scatter_idxs.as_mut_ptr()) };
100 keys.into_par_iter()
101 .with_max_len(1)
102 .enumerate()
103 .for_each(|(t, key_portion)| {
104 let mut partition_offsets =
105 per_thread_partition_offsets[t * n_partitions..(t + 1) * n_partitions].to_vec();
106 for (i, key) in key_portion.into_iter().enumerate() {
107 let key = key.to_total_ord();
108 unsafe {
109 let p = hash_to_partition(key.dirty_hash(), n_partitions);
110 let off = partition_offsets.get_unchecked_mut(p);
111 *scatter_keys_ptr.get().add(*off) = key;
112 *scatter_idxs_ptr.get().add(*off) =
113 (per_thread_input_offsets[t] + i) as IdxSize;
114 *off += 1;
115 }
116 }
117 });
118 unsafe {
119 scatter_keys.set_len(num_keys);
120 scatter_idxs.set_len(num_keys);
121 }
122
123 (0..n_partitions)
125 .into_par_iter()
126 .with_max_len(1)
127 .map(|p| {
128 let partition_range = partition_offsets[p]..partition_offsets[p + 1];
134 let full_size = partition_range.len();
135 let mut conservative_size = _HASHMAP_INIT_SIZE.max(full_size / 64);
136 let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> =
137 PlHashMap::with_capacity(conservative_size);
138
139 unsafe {
140 for i in partition_range {
141 if hm.len() == conservative_size {
142 hm.reserve(full_size - conservative_size);
143 conservative_size = 0; }
145
146 let key = *scatter_keys.get_unchecked(i);
147
148 if !key.is_null() || nulls_equal {
149 let idx = *scatter_idxs.get_unchecked(i);
150 match hm.entry(key) {
151 Entry::Occupied(mut o) => {
152 o.get_mut().push(idx as IdxSize);
153 },
154 Entry::Vacant(v) => {
155 let iv = unitvec![idx as IdxSize];
156 v.insert(iv);
157 },
158 };
159 }
160 }
161 }
162
163 hm
164 })
165 .collect()
166 })
167}
168
169pub(super) fn probe_to_offsets<T, I>(probe: &[I]) -> Vec<usize>
171where
172 I: IntoIterator<Item = T> + Clone,
173{
174 probe
175 .iter()
176 .map(|ph| ph.clone().into_iter().size_hint().1.unwrap())
177 .scan(0, |state, val| {
178 let out = *state;
179 *state += val;
180 Some(out)
181 })
182 .collect()
183}