polars_ops/frame/join/hash_join/
single_keys_inner.rs

1use polars_core::utils::flatten;
2use polars_utils::hashing::{DirtyHash, hash_to_partition};
3use polars_utils::idx_vec::IdxVec;
4use polars_utils::itertools::Itertools;
5use polars_utils::nulls::IsNull;
6use polars_utils::sync::SyncPtr;
7use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
8
9use super::*;
10
11pub(super) fn probe_inner<T, F, I>(
12    probe: I,
13    hash_tbls: &[PlHashMap<<T as ToTotalOrd>::TotalOrdItem, IdxVec>],
14    results: &mut Vec<(IdxSize, IdxSize)>,
15    local_offset: IdxSize,
16    n_tables: usize,
17    swap_fn: F,
18) where
19    T: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
20    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq + DirtyHash,
21    I: IntoIterator<Item = T>,
22    F: Fn(IdxSize, IdxSize) -> (IdxSize, IdxSize),
23{
24    probe.into_iter().enumerate_idx().for_each(|(idx_a, k)| {
25        let k = k.to_total_ord();
26        let idx_a = idx_a + local_offset;
27        // probe table that contains the hashed value
28        let current_probe_table =
29            unsafe { hash_tbls.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables)) };
30
31        let value = current_probe_table.get(&k);
32
33        if let Some(indexes_b) = value {
34            let tuples = indexes_b.iter().map(|&idx_b| swap_fn(idx_a, idx_b));
35            results.extend(tuples);
36        }
37    });
38}
39
/// Compute the row-index pairs of an inner join on a single key column.
///
/// `build` is hashed into per-partition tables; each chunk of `probe` is then
/// probed in parallel. Returns the matching `(left, right)` row indices as two
/// flat vectors of equal length.
///
/// # Errors
/// Returns an error if `validate` requires uniqueness checks on the build side
/// and they fail (see `validate_build`).
pub(super) fn hash_join_tuples_inner<T, I>(
    probe: Vec<I>,
    build: Vec<I>,
    // Because b should be the shorter relation we could need to swap to keep left left and right right.
    swapped: bool,
    validate: JoinValidation,
    nulls_equal: bool,
    // Null count is required for join validation
    build_null_count: usize,
) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)>
where
    I: IntoIterator<Item = T> + Send + Sync + Clone,
    T: Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + ToTotalOrd,
    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
{
    // NOTE: see the left join for more elaborate comments
    // first we hash one relation
    let hash_tbls = if validate.needs_checks() {
        // Total number of build rows; size_hint upper bound is assumed exact
        // here (the iterators come from materialized columns).
        let mut expected_size = build
            .iter()
            .map(|v| v.clone().into_iter().size_hint().1.unwrap())
            .sum();
        if !nulls_equal {
            // Nulls are not inserted into the tables when they don't compare
            // equal, so they must not count towards the expected key count.
            expected_size -= build_null_count;
        }
        let hash_tbls = build_tables(build, nulls_equal);
        // Number of distinct keys actually stored across all partitions.
        let build_size = hash_tbls.iter().map(|m| m.len()).sum();
        validate.validate_build(build_size, expected_size, swapped)?;
        hash_tbls
    } else {
        build_tables(build, nulls_equal)
    };
    try_raise_keyboard_interrupt();

    let n_tables = hash_tbls.len();
    // Global row offset of each probe chunk, so local enumeration indices can
    // be translated to global row indices.
    let offsets = probe_to_offsets(&probe);
    // next we probe the other relation
    // code duplication is because we want to only do the swap check once
    let out = POOL.install(|| {
        let tuples = probe
            .into_par_iter()
            .zip(offsets)
            .map(|(probe, offset)| {
                let probe = probe.into_iter();
                // local reference
                let hash_tbls = &hash_tbls;
                // Lower bound: at least one result per probe row is common.
                let mut results = Vec::with_capacity(probe.size_hint().1.unwrap());
                let local_offset = offset as IdxSize;

                // branch is to hoist swap out of the inner loop.
                if swapped {
                    probe_inner(
                        probe,
                        hash_tbls,
                        &mut results,
                        local_offset,
                        n_tables,
                        |idx_a, idx_b| (idx_b, idx_a),
                    )
                } else {
                    probe_inner(
                        probe,
                        hash_tbls,
                        &mut results,
                        local_offset,
                        n_tables,
                        |idx_a, idx_b| (idx_a, idx_b),
                    )
                }

                results
            })
            .collect::<Vec<_>>();

        // parallel materialization
        // `cap` is the total tuple count; `offsets` the exclusive prefix sums,
        // i.e. the write position of each chunk in the flat output vectors.
        let (cap, offsets) = flatten::cap_and_offsets(&tuples);
        let mut left = Vec::with_capacity(cap);
        let mut right = Vec::with_capacity(cap);

        // SAFETY: raw pointers to the (uninitialized) output buffers, shared
        // across threads. Sound because each chunk writes to a disjoint
        // [offset, offset + len) range.
        let left_ptr = unsafe { SyncPtr::new(left.as_mut_ptr()) };
        let right_ptr = unsafe { SyncPtr::new(right.as_mut_ptr()) };

        tuples
            .into_par_iter()
            .zip(offsets)
            .for_each(|(tuples, offset)| unsafe {
                let left_ptr: *mut IdxSize = left_ptr.get();
                let left_ptr = left_ptr.add(offset);
                let right_ptr: *mut IdxSize = right_ptr.get();
                let right_ptr = right_ptr.add(offset);

                // amortize loop counter
                for i in 0..tuples.len() {
                    let tuple = tuples.get_unchecked(i);
                    let left_row_idx = tuple.0;
                    let right_row_idx = tuple.1;

                    // SAFETY: offset + i < cap, and the ranges written by the
                    // parallel chunks are pairwise disjoint.
                    std::ptr::write(left_ptr.add(i), left_row_idx);
                    std::ptr::write(right_ptr.add(i), right_row_idx);
                }
            });
        // SAFETY: exactly `cap` elements were initialized above, and both
        // vectors were allocated with capacity `cap`.
        unsafe {
            left.set_len(cap);
            right.set_len(cap);
        }

        (left, right)
    });
    Ok(out)
}