polars_ops/frame/join/hash_join/single_keys.rs

use polars_utils::hashing::{DirtyHash, hash_to_partition};
use polars_utils::idx_vec::IdxVec;
use polars_utils::nulls::IsNull;
use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
use polars_utils::unitvec;

use super::*;

// FIXME: we should compute the number of threads / partition size we'll use.
// let avail_threads = POOL.current_num_threads();
// let n_threads = (num_keys / MIN_ELEMS_PER_THREAD).clamp(1, avail_threads);
// Use a small element per thread threshold for debugging/testing purposes.
const MIN_ELEMS_PER_THREAD: usize = if cfg!(debug_assertions) { 1 } else { 128 };

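/// Builds one hash table per partition for the build side of a hash join.
///
/// `keys` holds the build-side keys pre-split into one chunk per thread. Each
/// key is hashed to a partition and recorded in that partition's map together
/// with its global row offset, so every returned `PlHashMap` maps a key to the
/// `IdxVec` of row indices at which it occurs. Null keys are skipped unless
/// `nulls_equal` is set. Small inputs are handled serially and return a single
/// map.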
pub(crate) fn build_tables<T, I>(
    keys: Vec<I>,
    nulls_equal: bool,
) -> Vec<PlHashMap<<T as ToTotalOrd>::TotalOrdItem, IdxVec>>
where
    T: TotalHash + TotalEq + ToTotalOrd,
    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
    I: IntoIterator<Item = T> + Send + Sync + Clone,
{
    // FIXME: change interface to split the input here, instead of taking
    // pre-split input iterators.
    let n_partitions = keys.len();
    let n_threads = n_partitions;
    let num_keys_est: usize = keys
        .iter()
        .map(|k| k.clone().into_iter().size_hint().0)
        .sum();

    // Don't bother parallelizing anything for small inputs.
    if num_keys_est < 2 * MIN_ELEMS_PER_THREAD {
        let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> = PlHashMap::new();
        let mut offset = 0;
        for it in keys {
            for k in it {
                let k = k.to_total_ord();
                if !k.is_null() || nulls_equal {
                    hm.entry(k).or_default().push(offset);
                }
                offset += 1;
            }
        }
        return vec![hm];
    }

    POOL.install(|| {
        // Compute the number of elements in each partition for each portion.
        let per_thread_partition_sizes: Vec<Vec<usize>> = keys
            .par_iter()
            .with_max_len(1)
            .map(|key_portion| {
                let mut partition_sizes = vec![0; n_partitions];
                for key in key_portion.clone() {
                    let key = key.to_total_ord();
                    let p = hash_to_partition(key.dirty_hash(), n_partitions);
                    unsafe {
                        *partition_sizes.get_unchecked_mut(p) += 1;
                    }
                }
                partition_sizes
            })
            .collect();

        // Compute output offsets with a cumulative sum.
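        // These offsets define a layout in which keys are grouped first by
        // partition and, within each partition, by the thread they came from:
        // per_thread_partition_offsets[t * n_partitions + p] is where thread t
        // starts writing its share of partition p, while
        // partition_offsets[p]..partition_offsets[p + 1] spans partition p as a whole.
        // For example, with 2 threads and per-thread sizes [[3, 1], [2, 2]],
        // partition 0 occupies indices 0..5 (3 from thread 0, then 2 from thread 1)
        // and partition 1 occupies indices 5..8.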
        let mut per_thread_partition_offsets = vec![0; n_partitions * n_threads + 1];
        let mut partition_offsets = vec![0; n_partitions + 1];
        let mut cum_offset = 0;
        for p in 0..n_partitions {
            partition_offsets[p] = cum_offset;
            for t in 0..n_threads {
                per_thread_partition_offsets[t * n_partitions + p] = cum_offset;
                cum_offset += per_thread_partition_sizes[t][p];
            }
        }
        let num_keys = cum_offset;
        per_thread_partition_offsets[n_threads * n_partitions] = num_keys;
        partition_offsets[n_partitions] = num_keys;

        // FIXME: we wouldn't need this if we changed our interface to split the
        // input in this function, instead of taking a vec of iterators.
        let mut per_thread_input_offsets = vec![0; n_partitions];
        cum_offset = 0;
        for t in 0..n_threads {
            per_thread_input_offsets[t] = cum_offset;
            for p in 0..n_partitions {
                cum_offset += per_thread_partition_sizes[t][p];
            }
        }

        // Scatter values into partitions.
        let mut scatter_keys: Vec<T::TotalOrdItem> = Vec::with_capacity(num_keys);
        let mut scatter_idxs: Vec<IdxSize> = Vec::with_capacity(num_keys);
        let scatter_keys_ptr = unsafe { SyncPtr::new(scatter_keys.as_mut_ptr()) };
        let scatter_idxs_ptr = unsafe { SyncPtr::new(scatter_idxs.as_mut_ptr()) };
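        // SAFETY of the raw pointer writes below: each thread only writes to the
        // index ranges reserved for it by per_thread_partition_offsets and
        // per_thread_partition_sizes; those ranges are disjoint across threads
        // and stay below num_keys (the reserved capacity), so the unsynchronized
        // writes never alias.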
        keys.into_par_iter()
            .with_max_len(1)
            .enumerate()
            .for_each(|(t, key_portion)| {
                let mut partition_offsets =
                    per_thread_partition_offsets[t * n_partitions..(t + 1) * n_partitions].to_vec();
                for (i, key) in key_portion.into_iter().enumerate() {
                    let key = key.to_total_ord();
                    unsafe {
                        let p = hash_to_partition(key.dirty_hash(), n_partitions);
                        let off = partition_offsets.get_unchecked_mut(p);
                        *scatter_keys_ptr.get().add(*off) = key;
                        *scatter_idxs_ptr.get().add(*off) =
                            (per_thread_input_offsets[t] + i) as IdxSize;
                        *off += 1;
                    }
                }
            });
        unsafe {
            scatter_keys.set_len(num_keys);
            scatter_idxs.set_len(num_keys);
        }

        // Build tables.
        (0..n_partitions)
            .into_par_iter()
            .with_max_len(1)
            .map(|p| {
                // Resizing the hash map is very expensive, so we use a hybrid
                // strategy: start with a small hash map, which is enough for a
                // highly skewed relation; the moment it fills up, reserve enough
                // capacity for a full-cardinality data set in one go.
                let partition_range = partition_offsets[p]..partition_offsets[p + 1];
                let full_size = partition_range.len();
                let mut conservative_size = _HASHMAP_INIT_SIZE.max(full_size / 64);
                let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> =
                    PlHashMap::with_capacity(conservative_size);

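                // SAFETY: partition_range is a sub-range of 0..num_keys, and both
                // scatter vectors were filled and had their length set to num_keys,
                // so the unchecked indexing below stays in bounds.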
                unsafe {
                    for i in partition_range {
                        if hm.len() == conservative_size {
                            hm.reserve(full_size - conservative_size);
                            conservative_size = 0; // Hack to ensure we never hit this branch again.
                        }

                        let key = *scatter_keys.get_unchecked(i);

                        if !key.is_null() || nulls_equal {
                            let idx = *scatter_idxs.get_unchecked(i);
                            match hm.entry(key) {
                                Entry::Occupied(mut o) => {
                                    o.get_mut().push(idx as IdxSize);
                                },
                                Entry::Vacant(v) => {
                                    let iv = unitvec![idx as IdxSize];
                                    v.insert(iv);
                                },
                            };
                        }
                    }
                }

                hm
            })
            .collect()
    })
}

/// Determine the starting offset of each probe chunk so that, when probing,
/// local indices can be translated back into the global row indices stored in
/// the join tuples.
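///
/// This relies on each chunk's iterator reporting an upper size hint that is
/// exact (it is unwrapped below). For example, probe chunks of lengths
/// [3, 2, 4] yield the offsets [0, 3, 5].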
pub(super) fn probe_to_offsets<T, I>(probe: &[I]) -> Vec<usize>
where
    I: IntoIterator<Item = T> + Clone,
{
    probe
        .iter()
        .map(|ph| ph.clone().into_iter().size_hint().1.unwrap())
        .scan(0, |state, val| {
            let out = *state;
            *state += val;
            Some(out)
        })
        .collect()
}
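
// A minimal test sketch (not part of the original file), assuming primitive u32
// keys satisfy the trait bounds of `build_tables` (identity `ToTotalOrd`, plus
// `DirtyHash` and `IsNull` impls in polars_utils). It checks that the reported
// row offsets are global across the pre-split key chunks.
#[cfg(test)]
mod build_tables_sketch {
    use super::*;

    #[test]
    fn groups_row_indices_by_key() {
        // Keys pre-split into one chunk per thread, as callers hand them over.
        let keys: Vec<Vec<u32>> = vec![vec![1, 2, 2], vec![2, 3]];
        let tables = build_tables::<u32, _>(keys, false);

        // Key 2 lives in exactly one partition and maps to global offsets 1, 2, 3.
        let mut idxs: Vec<IdxSize> = tables
            .iter()
            .filter_map(|t| t.get(&2))
            .flat_map(|v| v.iter().copied())
            .collect();
        idxs.sort_unstable();
        assert_eq!(idxs, vec![1 as IdxSize, 2, 3]);
    }
}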