// polars_ops/frame/join/hash_join/single_keys_inner.rs
1use polars_core::utils::flatten;
2use polars_utils::hashing::{DirtyHash, hash_to_partition};
3use polars_utils::idx_vec::IdxVec;
4use polars_utils::itertools::Itertools;
5use polars_utils::nulls::IsNull;
6use polars_utils::sync::SyncPtr;
7use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
8
9use super::*;
10
11pub(super) fn probe_inner<T, F, I>(
12 probe: I,
13 hash_tbls: &[PlHashMap<<T as ToTotalOrd>::TotalOrdItem, IdxVec>],
14 results: &mut Vec<(IdxSize, IdxSize)>,
15 local_offset: IdxSize,
16 n_tables: usize,
17 swap_fn: F,
18) where
19 T: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
20 <T as ToTotalOrd>::TotalOrdItem: Hash + Eq + DirtyHash,
21 I: IntoIterator<Item = T>,
22 F: Fn(IdxSize, IdxSize) -> (IdxSize, IdxSize),
23{
24 probe.into_iter().enumerate_idx().for_each(|(idx_a, k)| {
25 let k = k.to_total_ord();
26 let idx_a = idx_a + local_offset;
27 let current_probe_table =
29 unsafe { hash_tbls.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables)) };
30
31 let value = current_probe_table.get(&k);
32
33 if let Some(indexes_b) = value {
34 let tuples = indexes_b.iter().map(|&idx_b| swap_fn(idx_a, idx_b));
35 results.extend(tuples);
36 }
37 });
38}
39
/// Hash join on a single key for an inner join.
///
/// `probe` and `build` hold the per-thread partitions of the probe- and
/// build-side keys. Returns the matching row indices as two parallel
/// vectors `(left, right)`.
///
/// * `swapped` - whether probe/build were swapped relative to left/right;
///   if so, each emitted pair is flipped back so the output stays
///   `(left, right)`.
/// * `validate` - join-validation mode; when it needs checks, the build
///   side is verified to contain only unique keys.
/// * `nulls_equal` - if `true`, null keys participate in the join.
/// * `build_null_count` - number of null keys on the build side; subtracted
///   from the expected build size when nulls are excluded.
pub(super) fn hash_join_tuples_inner<T, I>(
    probe: Vec<I>,
    build: Vec<I>,
    swapped: bool,
    validate: JoinValidation,
    nulls_equal: bool,
    build_null_count: usize,
) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)>
where
    I: IntoIterator<Item = T> + Send + Sync + Clone,
    T: Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + ToTotalOrd,
    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
{
    let hash_tbls = if validate.needs_checks() {
        // Expected number of distinct build keys; relies on the key iterators
        // reporting an exact upper size hint (panics otherwise).
        let mut expected_size = build
            .iter()
            .map(|v| v.clone().into_iter().size_hint().1.unwrap())
            .sum();
        if !nulls_equal {
            // Nulls are dropped from the tables, so they don't count
            // towards the expected unique-key total.
            expected_size -= build_null_count;
        }
        let hash_tbls = build_tables(build, nulls_equal);
        // Sum of table entries == number of unique build keys; validation
        // compares this against the expected row count.
        let build_size = hash_tbls.iter().map(|m| m.len()).sum();
        validate.validate_build(build_size, expected_size, swapped)?;
        hash_tbls
    } else {
        build_tables(build, nulls_equal)
    };
    try_raise_keyboard_interrupt();

    let n_tables = hash_tbls.len();
    // Cumulative offsets so each probe chunk can emit global row indices.
    let offsets = probe_to_offsets(&probe);
    let out = POOL.install(|| {
        // Probe every chunk in parallel; each task produces its own
        // (probe_idx, build_idx) tuple vector.
        let tuples = probe
            .into_par_iter()
            .zip(offsets)
            .map(|(probe, offset)| {
                let probe = probe.into_iter();
                let hash_tbls = &hash_tbls;
                // Pre-size for at least one match per probe row; requires an
                // exact upper size hint from the probe iterator.
                let mut results = Vec::with_capacity(probe.size_hint().1.unwrap());
                let local_offset = offset as IdxSize;

                // Flip the pair orientation when probe/build were swapped so
                // the output is always (left, right).
                if swapped {
                    probe_inner(
                        probe,
                        hash_tbls,
                        &mut results,
                        local_offset,
                        n_tables,
                        |idx_a, idx_b| (idx_b, idx_a),
                    )
                } else {
                    probe_inner(
                        probe,
                        hash_tbls,
                        &mut results,
                        local_offset,
                        n_tables,
                        |idx_a, idx_b| (idx_a, idx_b),
                    )
                }

                results
            })
            .collect::<Vec<_>>();

        // Flatten the per-chunk tuple vectors into two index vectors in
        // parallel: each task writes into a disjoint region of the output.
        let (cap, offsets) = flatten::cap_and_offsets(&tuples);
        let mut left = Vec::with_capacity(cap);
        let mut right = Vec::with_capacity(cap);

        // SAFETY: raw pointers are only used to write disjoint,
        // non-overlapping regions below; the allocations outlive the writes.
        let left_ptr = unsafe { SyncPtr::new(left.as_mut_ptr()) };
        let right_ptr = unsafe { SyncPtr::new(right.as_mut_ptr()) };

        tuples
            .into_par_iter()
            .zip(offsets)
            .for_each(|(tuples, offset)| unsafe {
                // SAFETY: `offset..offset + tuples.len()` is this chunk's
                // exclusive region within the `cap`-sized allocations, per
                // `cap_and_offsets`, so writes from different tasks never
                // overlap and stay in bounds.
                let left_ptr: *mut IdxSize = left_ptr.get();
                let left_ptr = left_ptr.add(offset);
                let right_ptr: *mut IdxSize = right_ptr.get();
                let right_ptr = right_ptr.add(offset);

                for i in 0..tuples.len() {
                    // SAFETY: i < tuples.len().
                    let tuple = tuples.get_unchecked(i);
                    let left_row_idx = tuple.0;
                    let right_row_idx = tuple.1;

                    std::ptr::write(left_ptr.add(i), left_row_idx);
                    std::ptr::write(right_ptr.add(i), right_row_idx);
                }
            });
        // SAFETY: all `cap` elements of both vectors were initialized by the
        // writes above (the chunk regions partition `0..cap`).
        unsafe {
            left.set_len(cap);
            right.set_len(cap);
        }

        (left, right)
    });
    Ok(out)
}