polars_ops/frame/join/hash_join/single_keys_left.rs

1use polars_core::utils::flatten::flatten_par;
2use polars_utils::hashing::{DirtyHash, hash_to_partition};
3use polars_utils::nulls::IsNull;
4use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
5
6use super::*;
7
/// Translates global row indices into [`ChunkId`]s by looking each one up in `chunk_mapping`.
///
/// # Safety
/// Every value in `idx` must be a valid index into `chunk_mapping`; no bounds checks are done.
#[cfg(feature = "chunked_ids")]
unsafe fn apply_mapping(idx: Vec<IdxSize>, chunk_mapping: &[ChunkId]) -> Vec<ChunkId> {
    idx.into_iter()
        // SAFETY: the caller guarantees every index is in bounds for `chunk_mapping`.
        .map(|row| unsafe { *chunk_mapping.get_unchecked(row as usize) })
        .collect()
}
14
/// Like `apply_mapping`, but for nullable indices: a null index becomes `ChunkId::null()`,
/// every other index is looked up in `chunk_mapping`.
///
/// # Safety
/// Every non-null index in `idx` must be a valid index into `chunk_mapping`; no bounds checks
/// are done.
#[cfg(feature = "chunked_ids")]
unsafe fn apply_opt_mapping(idx: Vec<NullableIdxSize>, chunk_mapping: &[ChunkId]) -> Vec<ChunkId> {
    let to_chunk_id = |opt_idx: NullableIdxSize| -> ChunkId {
        if opt_idx.is_null_idx() {
            return ChunkId::null();
        }
        // SAFETY: the caller guarantees non-null indices are in bounds for `chunk_mapping`.
        unsafe { *chunk_mapping.get_unchecked(opt_idx.idx() as usize) }
    };
    idx.into_iter().map(to_chunk_id).collect()
}
27
/// Converts one partition's raw left-join row indices into the final id representation.
///
/// Each side is handled independently: without a chunk mapping the plain indices are kept
/// (the `Left` variant); with a mapping every index is translated to its `ChunkId` (the `Right`
/// variant), where null right indices become `ChunkId::null()`.
#[cfg(feature = "chunked_ids")]
pub(super) fn finish_left_join_mappings(
    result_idx_left: Vec<IdxSize>,
    result_idx_right: Vec<NullableIdxSize>,
    chunk_mapping_left: Option<&[ChunkId]>,
    chunk_mapping_right: Option<&[ChunkId]>,
) -> LeftJoinIds {
    let left = if let Some(mapping) = chunk_mapping_left {
        // SAFETY: assumes the join only emitted indices of rows that `mapping` covers.
        ChunkJoinIds::Right(unsafe { apply_mapping(result_idx_left, mapping) })
    } else {
        ChunkJoinIds::Left(result_idx_left)
    };

    let right = if let Some(mapping) = chunk_mapping_right {
        // SAFETY: assumes every non-null right index refers to a row that `mapping` covers.
        ChunkJoinOptIds::Right(unsafe { apply_opt_mapping(result_idx_right, mapping) })
    } else {
        ChunkJoinOptIds::Left(result_idx_right)
    };

    (left, right)
}
48
49#[cfg(not(feature = "chunked_ids"))]
50pub(super) fn finish_left_join_mappings(
51    _result_idx_left: Vec<IdxSize>,
52    _result_idx_right: Vec<NullableIdxSize>,
53    _chunk_mapping_left: Option<&[ChunkId]>,
54    _chunk_mapping_right: Option<&[ChunkId]>,
55) -> LeftJoinIds {
56    (_result_idx_left, _result_idx_right)
57}
58
59pub(super) fn flatten_left_join_ids(result: Vec<LeftJoinIds>) -> LeftJoinIds {
60    #[cfg(feature = "chunked_ids")]
61    {
62        let left = if result[0].0.is_left() {
63            let lefts = result
64                .iter()
65                .map(|join_id| join_id.0.as_ref().left().unwrap())
66                .collect::<Vec<_>>();
67            let lefts = flatten_par(&lefts);
68            ChunkJoinIds::Left(lefts)
69        } else {
70            let lefts = result
71                .iter()
72                .map(|join_id| join_id.0.as_ref().right().unwrap())
73                .collect::<Vec<_>>();
74            let lefts = flatten_par(&lefts);
75            ChunkJoinIds::Right(lefts)
76        };
77
78        let right = if result[0].1.is_left() {
79            let rights = result
80                .iter()
81                .map(|join_id| join_id.1.as_ref().left().unwrap())
82                .collect::<Vec<_>>();
83            let rights = flatten_par(&rights);
84            ChunkJoinOptIds::Left(rights)
85        } else {
86            let rights = result
87                .iter()
88                .map(|join_id| join_id.1.as_ref().right().unwrap())
89                .collect::<Vec<_>>();
90            let rights = flatten_par(&rights);
91            ChunkJoinOptIds::Right(rights)
92        };
93
94        (left, right)
95    }
96    #[cfg(not(feature = "chunked_ids"))]
97    {
98        let lefts = result.iter().map(|join_id| &join_id.0).collect::<Vec<_>>();
99        let rights = result.iter().map(|join_id| &join_id.1).collect::<Vec<_>>();
100        let lefts = flatten_par(&lefts);
101        let rights = flatten_par(&rights);
102        (lefts, rights)
103    }
104}
105
106pub(super) fn hash_join_tuples_left<T, I>(
107    probe: Vec<I>,
108    build: Vec<I>,
109    // map the global indices to [chunk_idx, array_idx]
110    // only needed if we have non contiguous memory
111    chunk_mapping_left: Option<&[ChunkId]>,
112    chunk_mapping_right: Option<&[ChunkId]>,
113    validate: JoinValidation,
114    nulls_equal: bool,
115    // We should know the number of nulls to avoid extra calculation
116    build_null_count: usize,
117) -> PolarsResult<LeftJoinIds>
118where
119    I: IntoIterator<Item = T>,
120    <I as IntoIterator>::IntoIter: Send + Sync + Clone,
121    T: Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd,
122    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
123{
124    let probe = probe.into_iter().map(|i| i.into_iter()).collect::<Vec<_>>();
125    let build = build.into_iter().map(|i| i.into_iter()).collect::<Vec<_>>();
126    // first we hash one relation
127    let hash_tbls = if validate.needs_checks() {
128        let mut expected_size = build.iter().map(|v| v.size_hint().1.unwrap()).sum();
129        if !nulls_equal {
130            expected_size -= build_null_count;
131        }
132        let hash_tbls = build_tables(build, nulls_equal);
133        let build_size = hash_tbls.iter().map(|m| m.len()).sum();
134        validate.validate_build(build_size, expected_size, false)?;
135        hash_tbls
136    } else {
137        build_tables(build, nulls_equal)
138    };
139    try_raise_keyboard_interrupt();
140    let n_tables = hash_tbls.len();
141
142    // we determine the offset so that we later know which index to store in the join tuples
143    let offsets = probe_to_offsets(&probe);
144
145    // next we probe the other relation
146    let result: Vec<LeftJoinIds> = POOL.install(move || {
147        probe
148            .into_par_iter()
149            .zip(offsets)
150            // probes_hashes: Vec<u64> processed by this thread
151            // offset: offset index
152            .map(move |(probe, offset)| {
153                // local reference
154                let hash_tbls = &hash_tbls;
155
156                // assume the result tuples equal length of the no. of hashes processed by this thread.
157                let mut result_idx_left = Vec::with_capacity(probe.size_hint().1.unwrap());
158                let mut result_idx_right = Vec::with_capacity(probe.size_hint().1.unwrap());
159
160                probe.enumerate().for_each(|(idx_a, k)| {
161                    let k = k.to_total_ord();
162                    let idx_a = (idx_a + offset) as IdxSize;
163                    // probe table that contains the hashed value
164                    let current_probe_table = unsafe {
165                        hash_tbls.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables))
166                    };
167
168                    // we already hashed, so we don't have to hash again.
169                    let value = current_probe_table.get(&k);
170
171                    match value {
172                        // left and right matches
173                        Some(indexes_b) => {
174                            result_idx_left.extend(std::iter::repeat_n(idx_a, indexes_b.len()));
175                            result_idx_right.extend_from_slice(bytemuck::cast_slice(indexes_b));
176                        },
177                        // only left values, right = null
178                        None => {
179                            result_idx_left.push(idx_a);
180                            result_idx_right.push(NullableIdxSize::null());
181                        },
182                    }
183                });
184                finish_left_join_mappings(
185                    result_idx_left,
186                    result_idx_right,
187                    chunk_mapping_left,
188                    chunk_mapping_right,
189                )
190            })
191            .collect()
192    });
193
194    Ok(flatten_left_join_ids(result))
195}