polars_ops/frame/join/hash_join/
single_keys.rs1use polars_utils::hashing::{DirtyHash, hash_to_partition};
2use polars_utils::idx_vec::IdxVec;
3use polars_utils::nulls::IsNull;
4use polars_utils::sync::SyncPtr;
5use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
6use polars_utils::unitvec;
7
8use super::*;
9
/// Per-thread key threshold for `build_tables`: inputs whose estimated key
/// count is below twice this value are hashed sequentially into one table.
///
/// Debug builds use 1, so even tiny inputs go through the partitioned
/// parallel path there; release builds use 128.
#[cfg(debug_assertions)]
const MIN_ELEMS_PER_THREAD: usize = 1;
/// Per-thread key threshold for `build_tables` (release value; see the
/// debug-build definition for details).
#[cfg(not(debug_assertions))]
const MIN_ELEMS_PER_THREAD: usize = 128;
15
16pub(crate) fn build_tables<T, I>(
17 keys: Vec<I>,
18 nulls_equal: bool,
19) -> Vec<PlHashMap<<T as ToTotalOrd>::TotalOrdItem, IdxVec>>
20where
21 T: TotalHash + TotalEq + ToTotalOrd,
22 <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
23 I: IntoIterator<Item = T> + Send + Sync + Clone,
24{
25 let n_partitions = keys.len();
28 let n_threads = n_partitions;
29 let num_keys_est: usize = keys
30 .iter()
31 .map(|k| k.clone().into_iter().size_hint().0)
32 .sum();
33
34 if num_keys_est < 2 * MIN_ELEMS_PER_THREAD {
36 let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> = PlHashMap::new();
37 let mut offset = 0;
38 for it in keys {
39 for k in it {
40 let k = k.to_total_ord();
41 if !k.is_null() || nulls_equal {
42 hm.entry(k).or_default().push(offset);
43 }
44 offset += 1;
45 }
46 }
47 return vec![hm];
48 }
49
50 POOL.install(|| {
51 let per_thread_partition_sizes: Vec<Vec<usize>> = keys
53 .par_iter()
54 .with_max_len(1)
55 .map(|key_portion| {
56 let mut partition_sizes = vec![0; n_partitions];
57 for key in key_portion.clone() {
58 let key = key.to_total_ord();
59 let p = hash_to_partition(key.dirty_hash(), n_partitions);
60 unsafe {
61 *partition_sizes.get_unchecked_mut(p) += 1;
62 }
63 }
64 partition_sizes
65 })
66 .collect();
67
68 let mut per_thread_partition_offsets = vec![0; n_partitions * n_threads + 1];
70 let mut partition_offsets = vec![0; n_partitions + 1];
71 let mut cum_offset = 0;
72 for p in 0..n_partitions {
73 partition_offsets[p] = cum_offset;
74 for t in 0..n_threads {
75 per_thread_partition_offsets[t * n_partitions + p] = cum_offset;
76 cum_offset += per_thread_partition_sizes[t][p];
77 }
78 }
79 let num_keys = cum_offset;
80 per_thread_partition_offsets[n_threads * n_partitions] = num_keys;
81 partition_offsets[n_partitions] = num_keys;
82
83 let mut per_thread_input_offsets = vec![0; n_partitions];
86 cum_offset = 0;
87 for t in 0..n_threads {
88 per_thread_input_offsets[t] = cum_offset;
89 for p in 0..n_partitions {
90 cum_offset += per_thread_partition_sizes[t][p];
91 }
92 }
93
94 let mut scatter_keys: Vec<T::TotalOrdItem> = Vec::with_capacity(num_keys);
96 let mut scatter_idxs: Vec<IdxSize> = Vec::with_capacity(num_keys);
97 let scatter_keys_ptr = unsafe { SyncPtr::new(scatter_keys.as_mut_ptr()) };
98 let scatter_idxs_ptr = unsafe { SyncPtr::new(scatter_idxs.as_mut_ptr()) };
99 keys.into_par_iter()
100 .with_max_len(1)
101 .enumerate()
102 .for_each(|(t, key_portion)| {
103 let mut partition_offsets =
104 per_thread_partition_offsets[t * n_partitions..(t + 1) * n_partitions].to_vec();
105 for (i, key) in key_portion.into_iter().enumerate() {
106 let key = key.to_total_ord();
107 unsafe {
108 let p = hash_to_partition(key.dirty_hash(), n_partitions);
109 let off = partition_offsets.get_unchecked_mut(p);
110 *scatter_keys_ptr.get().add(*off) = key;
111 *scatter_idxs_ptr.get().add(*off) =
112 (per_thread_input_offsets[t] + i) as IdxSize;
113 *off += 1;
114 }
115 }
116 });
117 unsafe {
118 scatter_keys.set_len(num_keys);
119 scatter_idxs.set_len(num_keys);
120 }
121
122 (0..n_partitions)
124 .into_par_iter()
125 .with_max_len(1)
126 .map(|p| {
127 let partition_range = partition_offsets[p]..partition_offsets[p + 1];
133 let full_size = partition_range.len();
134 let mut conservative_size = _HASHMAP_INIT_SIZE.max(full_size / 64);
135 let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> =
136 PlHashMap::with_capacity(conservative_size);
137
138 unsafe {
139 for i in partition_range {
140 if hm.len() == conservative_size {
141 hm.reserve(full_size - conservative_size);
142 conservative_size = 0; }
144
145 let key = *scatter_keys.get_unchecked(i);
146
147 if !key.is_null() || nulls_equal {
148 let idx = *scatter_idxs.get_unchecked(i);
149 match hm.entry(key) {
150 Entry::Occupied(mut o) => {
151 o.get_mut().push(idx as IdxSize);
152 },
153 Entry::Vacant(v) => {
154 let iv = unitvec![idx as IdxSize];
155 v.insert(iv);
156 },
157 };
158 }
159 }
160 }
161
162 hm
163 })
164 .collect()
165 })
166}
167
168pub(super) fn probe_to_offsets<T, I>(probe: &[I]) -> Vec<usize>
170where
171 I: IntoIterator<Item = T> + Clone,
172{
173 probe
174 .iter()
175 .map(|ph| ph.clone().into_iter().size_hint().1.unwrap())
176 .scan(0, |state, val| {
177 let out = *state;
178 *state += val;
179 Some(out)
180 })
181 .collect()
182}