polars_ops/frame/join/hash_join/
single_keys_outer.rs

1use std::hash::BuildHasher;
2
3use arrow::array::{MutablePrimitiveArray, PrimitiveArray};
4use arrow::legacy::utils::CustomIterTools;
5use polars_utils::hashing::hash_to_partition;
6use polars_utils::idx_vec::IdxVec;
7use polars_utils::nulls::IsNull;
8use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
9use polars_utils::unitvec;
10
11use super::*;
12
13pub(crate) fn create_hash_and_keys_threaded_vectorized<I, T>(
14    iters: Vec<I>,
15    build_hasher: Option<PlRandomState>,
16) -> (Vec<Vec<(u64, T)>>, PlRandomState)
17where
18    I: IntoIterator<Item = T> + Send,
19    I::IntoIter: TrustedLen,
20    T: TotalHash + TotalEq + Send + ToTotalOrd,
21    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq,
22{
23    let build_hasher = build_hasher.unwrap_or_default();
24    let hashes = POOL.install(|| {
25        iters
26            .into_par_iter()
27            .map(|iter| {
28                // create hashes and keys
29                #[allow(clippy::needless_borrows_for_generic_args)]
30                iter.into_iter()
31                    .map(|val| (build_hasher.hash_one(&val.to_total_ord()), val))
32                    .collect_trusted::<Vec<_>>()
33            })
34            .collect()
35    });
36    (hashes, build_hasher)
37}
38
39pub(crate) fn prepare_hashed_relation_threaded<T, I>(
40    iters: Vec<I>,
41) -> Vec<PlHashMap<<T as ToTotalOrd>::TotalOrdItem, (bool, IdxVec)>>
42where
43    I: Iterator<Item = T> + Send + TrustedLen,
44    T: Send + Sync + TotalHash + TotalEq + ToTotalOrd,
45    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Hash + Eq,
46{
47    let n_partitions = _set_partition_size();
48    let (hashes_and_keys, build_hasher) = create_hash_and_keys_threaded_vectorized(iters, None);
49
50    // We will create a hashtable in every thread.
51    // We use the hash to partition the keys to the matching hashtable.
52    // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition.
53    POOL.install(|| {
54        (0..n_partitions)
55            .into_par_iter()
56            .map(|partition_no| {
57                let hashes_and_keys = &hashes_and_keys;
58                let mut hash_tbl: PlHashMap<T::TotalOrdItem, (bool, IdxVec)> =
59                    PlHashMap::with_hasher(build_hasher);
60
61                let mut offset = 0;
62                for hashes_and_keys in hashes_and_keys {
63                    let len = hashes_and_keys.len();
64                    hashes_and_keys
65                        .iter()
66                        .enumerate()
67                        .for_each(|(idx, (h, k))| {
68                            let k = k.to_total_ord();
69                            let idx = idx as IdxSize;
70                            // partition hashes by thread no.
71                            // So only a part of the hashes go to this hashmap
72                            if partition_no == hash_to_partition(*h, n_partitions) {
73                                let idx = idx + offset;
74                                let entry = hash_tbl
75                                    .raw_entry_mut()
76                                    // uses the key to check equality to find and entry
77                                    .from_key_hashed_nocheck(*h, &k);
78
79                                match entry {
80                                    RawEntryMut::Vacant(entry) => {
81                                        entry.insert_hashed_nocheck(*h, k, (false, unitvec![idx]));
82                                    },
83                                    RawEntryMut::Occupied(mut entry) => {
84                                        let (_k, v) = entry.get_key_value_mut();
85                                        v.1.push(idx);
86                                    },
87                                }
88                            }
89                        });
90
91                    offset += len as IdxSize;
92                }
93                hash_tbl
94            })
95            .collect()
96    })
97}
98
99/// Probe the build table and add tuples to the results.
100#[allow(clippy::too_many_arguments)]
101fn probe_outer<T, F, G, H>(
102    probe_hashes: &[Vec<(u64, T)>],
103    hash_tbls: &mut [PlHashMap<<T as ToTotalOrd>::TotalOrdItem, (bool, IdxVec)>],
104    results: &mut (
105        MutablePrimitiveArray<IdxSize>,
106        MutablePrimitiveArray<IdxSize>,
107    ),
108    n_tables: usize,
109    // Function that get index_a, index_b when there is a match and pushes to result
110    swap_fn_match: F,
111    // Function that get index_a when there is no match and pushes to result
112    swap_fn_no_match: G,
113    // Function that get index_b from the build table that did not match any in A and pushes to result
114    swap_fn_drain: H,
115    nulls_equal: bool,
116) where
117    T: TotalHash + TotalEq + ToTotalOrd,
118    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq + IsNull,
119    // idx_a, idx_b -> ...
120    F: Fn(IdxSize, IdxSize) -> (Option<IdxSize>, Option<IdxSize>),
121    // idx_a -> ...
122    G: Fn(IdxSize) -> (Option<IdxSize>, Option<IdxSize>),
123    // idx_b -> ...
124    H: Fn(IdxSize) -> (Option<IdxSize>, Option<IdxSize>),
125{
126    // needed for the partition shift instead of modulo to make sense
127    let mut idx_a = 0;
128    for probe_hashes in probe_hashes {
129        for (h, key) in probe_hashes {
130            let key = key.to_total_ord();
131            let h = *h;
132            // probe table that contains the hashed value
133            let current_probe_table =
134                unsafe { hash_tbls.get_unchecked_mut(hash_to_partition(h, n_tables)) };
135
136            let entry = current_probe_table
137                .raw_entry_mut()
138                .from_key_hashed_nocheck(h, &key);
139
140            match entry {
141                // match and remove
142                RawEntryMut::Occupied(mut occupied) => {
143                    if key.is_null() && !nulls_equal {
144                        let (l, r) = swap_fn_no_match(idx_a);
145                        results.0.push(l);
146                        results.1.push(r);
147                    } else {
148                        let (tracker, indexes_b) = occupied.get_mut();
149                        *tracker = true;
150                        for (l, r) in indexes_b.iter().map(|&idx_b| swap_fn_match(idx_a, idx_b)) {
151                            results.0.push(l);
152                            results.1.push(r);
153                        }
154                    }
155                },
156                // no match
157                RawEntryMut::Vacant(_) => {
158                    let (l, r) = swap_fn_no_match(idx_a);
159                    results.0.push(l);
160                    results.1.push(r);
161                },
162            }
163            idx_a += 1;
164        }
165    }
166
167    for hash_tbl in hash_tbls {
168        hash_tbl.iter().for_each(|(_k, (tracker, indexes_b))| {
169            // remaining joined values from the right table
170            if !*tracker {
171                for (l, r) in indexes_b.iter().map(|&idx_b| swap_fn_drain(idx_b)) {
172                    results.0.push(l);
173                    results.1.push(r);
174                }
175            }
176        });
177    }
178}
179
180/// Hash join outer. Both left and right can have no match so Options
181pub(super) fn hash_join_tuples_outer<T, I, J>(
182    probe: Vec<I>,
183    build: Vec<J>,
184    swapped: bool,
185    validate: JoinValidation,
186    nulls_equal: bool,
187) -> PolarsResult<(PrimitiveArray<IdxSize>, PrimitiveArray<IdxSize>)>
188where
189    I: IntoIterator<Item = T>,
190    J: IntoIterator<Item = T>,
191    <J as IntoIterator>::IntoIter: TrustedLen + Send,
192    <I as IntoIterator>::IntoIter: TrustedLen + Send,
193    T: Send + Sync + TotalHash + TotalEq + IsNull + ToTotalOrd,
194    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Hash + Eq + IsNull,
195{
196    let probe = probe.into_iter().map(|i| i.into_iter()).collect::<Vec<_>>();
197    let build = build.into_iter().map(|i| i.into_iter()).collect::<Vec<_>>();
198    // This function is partially multi-threaded.
199    // Parts that are done in parallel:
200    //  - creation of the probe tables
201    //  - creation of the hashes
202
203    // during the probe phase values are removed from the tables, that's done single threaded to
204    // keep it lock free.
205
206    let size = probe
207        .iter()
208        .map(|a| a.size_hint().1.unwrap())
209        .sum::<usize>()
210        + build
211            .iter()
212            .map(|b| b.size_hint().1.unwrap())
213            .sum::<usize>();
214    let mut results = (
215        MutablePrimitiveArray::with_capacity(size),
216        MutablePrimitiveArray::with_capacity(size),
217    );
218
219    // prepare hash table
220    let mut hash_tbls = if validate.needs_checks() {
221        let expected_size = build.iter().map(|i| i.size_hint().0).sum();
222        let hash_tbls = prepare_hashed_relation_threaded(build);
223        let build_size = hash_tbls.iter().map(|m| m.len()).sum();
224        validate.validate_build(build_size, expected_size, swapped)?;
225        hash_tbls
226    } else {
227        prepare_hashed_relation_threaded(build)
228    };
229    let random_state = hash_tbls[0].hasher();
230
231    // we pre hash the probing values
232    let (probe_hashes, _) = create_hash_and_keys_threaded_vectorized(probe, Some(*random_state));
233
234    let n_tables = hash_tbls.len();
235    try_raise_keyboard_interrupt();
236
237    // probe the hash table.
238    // Note: indexes from b that are not matched will be None, Some(idx_b)
239    // Therefore we remove the matches and the remaining will be joined from the right
240
241    // branch is because we want to only do the swap check once
242    if swapped {
243        probe_outer(
244            &probe_hashes,
245            &mut hash_tbls,
246            &mut results,
247            n_tables,
248            |idx_a, idx_b| (Some(idx_b), Some(idx_a)),
249            |idx_a| (None, Some(idx_a)),
250            |idx_b| (Some(idx_b), None),
251            nulls_equal,
252        )
253    } else {
254        probe_outer(
255            &probe_hashes,
256            &mut hash_tbls,
257            &mut results,
258            n_tables,
259            |idx_a, idx_b| (Some(idx_a), Some(idx_b)),
260            |idx_a| (Some(idx_a), None),
261            |idx_b| (None, Some(idx_b)),
262            nulls_equal,
263        )
264    }
265    Ok((results.0.into(), results.1.into()))
266}