Skip to main content

polars_ops/frame/join/hash_join/
sort_merge.rs

1#[cfg(feature = "performant")]
2use arrow::legacy::kernels::sorted_join;
3#[cfg(feature = "performant")]
4use polars_core::utils::_split_offsets;
5#[cfg(feature = "performant")]
6use polars_core::utils::flatten::flatten_par;
7
8use super::*;
9
/// Runs a sorted-merge left join over two numeric key columns in parallel.
///
/// The left column is divided into `(offset, len)` ranges by `_split_offsets`, sized by the
/// thread count of the `RAYON` pool. Each range of the left column is joined against the
/// complete right column with `sorted_join::left::join`, and the per-range outputs are
/// combined with `flatten_par`.
///
/// Returns the `(left, right)` index vectors produced by the merge kernel.
///
/// # Panics
/// Panics when `cont_slice` fails on either rechunked column; callers are expected to pass
/// columns without nulls.
#[cfg(feature = "performant")]
fn par_sorted_merge_left_impl<T>(
    s_left: &ChunkedArray<T>,
    s_right: &ChunkedArray<T>,
) -> (Vec<IdxSize>, Vec<NullableIdxSize>)
where
    T: PolarsNumericType,
{
    let ranges = _split_offsets(s_left.len(), RAYON.current_num_threads());

    // Rechunk first so each column can be viewed as a single slice.
    let left_ca = s_left.rechunk();
    let right_ca = s_right.rechunk();

    // Unwrapping is fine here: the inputs are not supposed to contain nulls.
    let left_values = left_ca.cont_slice().unwrap();
    let right_values = right_ca.cont_slice().unwrap();

    let per_range = RAYON.install(|| {
        ranges
            .into_par_iter()
            .map(|(start, len)| {
                // The range start is handed to the kernel as the left offset.
                let left_part = &left_values[start..start + len];
                sorted_join::left::join(left_part, right_values, start as IdxSize)
            })
            .collect::<Vec<_>>()
    });

    // Split the per-range pairs into two lists of borrowed buffers.
    let (left_parts, right_parts): (Vec<_>, Vec<_>) =
        per_range.iter().map(|(left, right)| (left, right)).unzip();

    (flatten_par(&left_parts), flatten_par(&right_parts))
}
37
/// Dispatches a parallel sorted-merge left join on the physical dtype of the keys.
///
/// Both series are converted with `to_physical_repr` and then downcast to the matching
/// numeric `ChunkedArray` before being handed to `par_sorted_merge_left_impl`.
///
/// Returns the `(left, right)` index vectors from the typed implementation.
///
/// # Panics
/// Panics with the debug representation of the dtype when the physical dtype is not one of
/// the numeric types handled below (or its feature is disabled). In debug builds it also
/// asserts that both inputs share the same dtype.
#[cfg(feature = "performant")]
pub(super) fn par_sorted_merge_left(
    s_left: &Series,
    s_right: &Series,
) -> (Vec<IdxSize>, Vec<NullableIdxSize>) {
    debug_assert_eq!(s_left.dtype(), s_right.dtype());

    // Use the physical representation, not `bit_repr`: the latter messes up sortedness.
    let lhs = s_left.to_physical_repr();
    let rhs = s_right.to_physical_repr();

    // Downcasts both sides with the same accessor and runs the typed merge.
    macro_rules! merge_as {
        ($downcast:ident) => {
            par_sorted_merge_left_impl(lhs.$downcast().unwrap(), rhs.$downcast().unwrap())
        };
    }

    match lhs.dtype() {
        #[cfg(feature = "dtype-i8")]
        DataType::Int8 => merge_as!(i8),
        #[cfg(feature = "dtype-u8")]
        DataType::UInt8 => merge_as!(u8),
        #[cfg(feature = "dtype-u16")]
        DataType::UInt16 => merge_as!(u16),
        #[cfg(feature = "dtype-i16")]
        DataType::Int16 => merge_as!(i16),
        DataType::UInt32 => merge_as!(u32),
        DataType::Int32 => merge_as!(i32),
        DataType::UInt64 => merge_as!(u64),
        DataType::Int64 => merge_as!(i64),
        #[cfg(feature = "dtype-u128")]
        DataType::UInt128 => merge_as!(u128),
        #[cfg(feature = "dtype-i128")]
        DataType::Int128 => merge_as!(i128),
        #[cfg(feature = "dtype-f16")]
        DataType::Float16 => merge_as!(f16),
        DataType::Float32 => merge_as!(f32),
        DataType::Float64 => merge_as!(f64),
        dt => panic!("{dt:?}"),
    }
}
/// Runs a sorted-merge inner join over two numeric key columns in parallel.
///
/// The left column is divided into `(offset, len)` ranges by `_split_offsets`, sized by the
/// thread count of the `RAYON` pool. Each range of the left column is joined against the
/// complete right column with `sorted_join::inner::join`, and the per-range outputs are
/// combined with `flatten_par`.
///
/// Returns the `(left, right)` index vectors produced by the merge kernel.
///
/// # Panics
/// Panics when `cont_slice` fails on either rechunked column; callers are expected to pass
/// columns without nulls.
#[cfg(feature = "performant")]
fn par_sorted_merge_inner_impl<T>(
    s_left: &ChunkedArray<T>,
    s_right: &ChunkedArray<T>,
) -> (Vec<IdxSize>, Vec<IdxSize>)
where
    T: PolarsNumericType,
{
    let ranges = _split_offsets(s_left.len(), RAYON.current_num_threads());

    // Rechunk first so each column can be viewed as a single slice.
    let left_ca = s_left.rechunk();
    let right_ca = s_right.rechunk();

    // Unwrapping is fine here: the inputs are not supposed to contain nulls.
    let left_values = left_ca.cont_slice().unwrap();
    let right_values = right_ca.cont_slice().unwrap();

    let per_range = RAYON.install(|| {
        ranges
            .into_par_iter()
            .map(|(start, len)| {
                // The range start is handed to the kernel as the left offset.
                let left_part = &left_values[start..start + len];
                sorted_join::inner::join(left_part, right_values, start as IdxSize)
            })
            .collect::<Vec<_>>()
    });

    // Split the per-range pairs into two lists of borrowed buffers.
    let (left_parts, right_parts): (Vec<_>, Vec<_>) =
        per_range.iter().map(|(left, right)| (left, right)).unzip();

    (flatten_par(&left_parts), flatten_par(&right_parts))
}
121
/// Dispatches a parallel sorted-merge inner join on the physical dtype of the keys.
///
/// Both series are converted with `to_physical_repr` and then downcast to the matching
/// numeric `ChunkedArray` before being handed to `par_sorted_merge_inner_impl`. As the name
/// says, the inputs are expected to be free of nulls.
///
/// Returns the `(left, right)` index vectors from the typed implementation.
///
/// # Panics
/// Hits `unreachable!()` when the physical dtype is not one of the numeric types handled
/// below (or its feature is disabled). In debug builds it also asserts that both inputs
/// share the same dtype.
#[cfg(feature = "performant")]
pub(super) fn par_sorted_merge_inner_no_nulls(
    s_left: &Series,
    s_right: &Series,
) -> (Vec<IdxSize>, Vec<IdxSize>) {
    debug_assert_eq!(s_left.dtype(), s_right.dtype());

    // Use the physical representation, not `bit_repr`: the latter messes up sortedness.
    let lhs = s_left.to_physical_repr();
    let rhs = s_right.to_physical_repr();

    // Downcasts both sides with the same accessor and runs the typed merge.
    macro_rules! merge_as {
        ($downcast:ident) => {
            par_sorted_merge_inner_impl(lhs.$downcast().unwrap(), rhs.$downcast().unwrap())
        };
    }

    match lhs.dtype() {
        #[cfg(feature = "dtype-i8")]
        DataType::Int8 => merge_as!(i8),
        #[cfg(feature = "dtype-u8")]
        DataType::UInt8 => merge_as!(u8),
        #[cfg(feature = "dtype-u16")]
        DataType::UInt16 => merge_as!(u16),
        #[cfg(feature = "dtype-i16")]
        DataType::Int16 => merge_as!(i16),
        DataType::UInt32 => merge_as!(u32),
        DataType::Int32 => merge_as!(i32),
        DataType::UInt64 => merge_as!(u64),
        DataType::Int64 => merge_as!(i64),
        #[cfg(feature = "dtype-u128")]
        DataType::UInt128 => merge_as!(u128),
        #[cfg(feature = "dtype-i128")]
        DataType::Int128 => merge_as!(i128),
        #[cfg(feature = "dtype-f16")]
        DataType::Float16 => merge_as!(f16),
        DataType::Float32 => merge_as!(f32),
        DataType::Float64 => merge_as!(f64),
        _ => unreachable!(),
    }
}
178
179pub(crate) fn to_left_join_ids(
180    left_idx: Vec<IdxSize>,
181    right_idx: Vec<NullableIdxSize>,
182) -> LeftJoinIds {
183    #[cfg(feature = "chunked_ids")]
184    {
185        (Either::Left(left_idx), Either::Left(right_idx))
186    }
187
188    #[cfg(not(feature = "chunked_ids"))]
189    {
190        (left_idx, right_idx)
191    }
192}
193
/// Turns an owned arg-sort result into a plain `Vec<IdxSize>` lookup table.
///
/// The last chunk of `arg_sort` is popped and converted with `primitive_to_vec`. Callers index
/// the returned vector with a position in the sorted key to get the stored original index.
///
/// # Panics
/// Panics if `arg_sort` has no chunks, or if `primitive_to_vec` returns `None`.
#[cfg(feature = "performant")]
fn create_reverse_map_from_arg_sort(mut arg_sort: IdxCa) -> Vec<IdxSize> {
    // NOTE(review): only the last chunk is used, so this presumably expects `arg_sort` to
    // consist of a single chunk - confirm against `arg_sort`'s output.
    // NOTE(review): `chunks_mut` is unsafe; `arg_sort` is owned here and not read again after
    // the pop, which is presumably what makes this acceptable - confirm against its contract.
    let last_chunk = unsafe { arg_sort.chunks_mut() }.pop();
    let values = last_chunk.unwrap();
    let sorted_to_original = primitive_to_vec::<IdxSize>(values);
    sorted_to_original.unwrap()
}
199
/// Inner-join entry point compiled when the `performant` feature is disabled.
///
/// Always delegates to `s_left.hash_join_inner(s_right, validate, nulls_equal)` and returns
/// its result unchanged. `_verbose` is accepted so the signature matches the `performant`
/// variant, but it is not used.
200#[cfg(not(feature = "performant"))]
201pub(crate) fn _sort_or_hash_inner(
202    s_left: &Series,
203    s_right: &Series,
204    _verbose: bool,
205    validate: JoinValidation,
206    nulls_equal: bool,
207) -> PolarsResult<(InnerJoinIds, bool)> {
208    s_left.hash_join_inner(s_right, validate, nulls_equal)
209}
210
/// Computes inner-join indices for one key column, preferring a sorted-merge join when the
/// sortedness flags allow it and falling back to a hash join otherwise.
///
/// Strategy, checked in this order:
/// 1. `validate.needs_checks()`: hash join.
/// 2. Both keys flagged ascending, numeric and null-free: parallel sorted merge.
/// 3. Left key flagged ascending, numeric, null-free and `right.len() / left.len()` below the
///    sort factor: arg-sort the right key, merge, then map the right indices back.
/// 4. Right key flagged ascending, numeric, null-free and `left.len() / right.len()` below the
///    sort factor: arg-sort the left key, merge, then map the left indices back.
/// 5. Anything else: hash join.
///
/// The sort factor defaults to `1.0` and can be overridden through the
/// `POLARS_JOIN_SORT_FACTOR` environment variable. A value that does not parse as `f32` is
/// ignored (with a message on stderr when `verbose` is set) instead of panicking.
///
/// The `bool` in the result is `true` for paths 2 and 3, `false` for path 4, and whatever
/// `hash_join_inner` reports for the hash-join paths.
#[cfg(feature = "performant")]
pub(crate) fn _sort_or_hash_inner(
    s_left: &Series,
    s_right: &Series,
    verbose: bool,
    validate: JoinValidation,
    nulls_equal: bool,
) -> PolarsResult<(InnerJoinIds, bool)> {
    // Validation is only done by the hash join; decide that before any other work.
    if validate.needs_checks() {
        return s_left.hash_join_inner(s_right, validate, nulls_equal);
    }

    // If only one key is sorted it can still be faster to sort the other key and use the
    // `arg_sort` indices to revert the sort once the join indices are known. The size
    // factors bound how large the key that needs sorting may be relative to the other one.
    let size_factor_rhs = s_right.len() as f32 / s_left.len() as f32;
    let size_factor_lhs = s_left.len() as f32 / s_right.len() as f32;
    let size_factor_acceptable = match std::env::var("POLARS_JOIN_SORT_FACTOR") {
        Ok(raw) => raw.parse::<f32>().unwrap_or_else(|_| {
            // A malformed tuning knob must not take the whole join down.
            if verbose {
                eprintln!(
                    "POLARS_JOIN_SORT_FACTOR={raw:?} is not a valid float; using the default 1.0"
                );
            }
            1.0
        }),
        Err(_) => 1.0,
    };
    let is_numeric = s_left.dtype().to_physical().is_primitive_numeric();
    let no_nulls = s_left.null_count() == 0 && s_right.null_count() == 0;

    // Ascending arg-sort shared by both "sort one side" paths.
    let arg_sort_ascending = |key: &Series| {
        key.arg_sort(SortOptions {
            descending: false,
            nulls_last: false,
            multithreaded: true,
            maintain_order: false,
            limit: None,
        })
    };

    match (s_left.is_sorted_flag(), s_right.is_sorted_flag(), no_nulls) {
        (IsSorted::Ascending, IsSorted::Ascending, true) if is_numeric => {
            if verbose {
                eprintln!("inner join: keys are sorted: use sorted merge join");
            }
            Ok((par_sorted_merge_inner_no_nulls(s_left, s_right), true))
        },
        (IsSorted::Ascending, _, true)
            if is_numeric && size_factor_rhs < size_factor_acceptable =>
        {
            // NOTE(review): the message says "descending" but the sort below is ascending;
            // the text is kept as-is because it is runtime output.
            if verbose {
                eprintln!("right key will be descending sorted in inner join operation.")
            }

            let sort_idx = arg_sort_ascending(s_right);
            // SAFETY: `sort_idx` was produced by arg-sorting `s_right` itself, so it is
            // expected to hold only valid row positions of `s_right`.
            let s_right = unsafe { s_right.take_unchecked(&sort_idx) };
            let (left, mut right) = par_sorted_merge_inner_no_nulls(s_left, &s_right);
            let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx);

            // Translate positions in the sorted right key back to original right rows.
            RAYON.install(|| {
                right.par_iter_mut().for_each(|idx| {
                    // SAFETY: relies on the merge join returning positions within the sorted
                    // right key, which has one entry per element of `reverse_idx_map`.
                    *idx = unsafe { *reverse_idx_map.get_unchecked(*idx as usize) };
                });
            });

            Ok(((left, right), true))
        },
        (_, IsSorted::Ascending, true)
            if is_numeric && size_factor_lhs < size_factor_acceptable =>
        {
            // NOTE(review): the message says "descending" but the sort below is ascending;
            // the text is kept as-is because it is runtime output.
            if verbose {
                eprintln!("left key will be descending sorted in inner join operation.")
            }

            let sort_idx = arg_sort_ascending(s_left);
            // SAFETY: `sort_idx` was produced by arg-sorting `s_left` itself, so it is
            // expected to hold only valid row positions of `s_left`.
            let s_left = unsafe { s_left.take_unchecked(&sort_idx) };
            let (mut left, right) = par_sorted_merge_inner_no_nulls(&s_left, s_right);
            let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx);

            // Translate positions in the sorted left key back to original left rows.
            RAYON.install(|| {
                left.par_iter_mut().for_each(|idx| {
                    // SAFETY: relies on the merge join returning positions within the sorted
                    // left key, which has one entry per element of `reverse_idx_map`.
                    *idx = unsafe { *reverse_idx_map.get_unchecked(*idx as usize) };
                });
            });

            // Report `false`: the left key was re-sorted, so after remapping the left
            // indices follow the sorted key order rather than the original row order.
            Ok(((left, right), false))
        },
        _ => s_left.hash_join_inner(s_right, validate, nulls_equal),
    }
}
302
/// Left-join entry point compiled when the `performant` feature is disabled.
///
/// Always delegates to `s_left.hash_join_left(s_right, validate, nulls_equal)` and returns
/// its result unchanged. `_verbose` is accepted so the signature matches the `performant`
/// variant, but it is not used.
303#[cfg(not(feature = "performant"))]
304pub(crate) fn sort_or_hash_left(
305    s_left: &Series,
306    s_right: &Series,
307    _verbose: bool,
308    validate: JoinValidation,
309    nulls_equal: bool,
310) -> PolarsResult<LeftJoinIds> {
311    s_left.hash_join_left(s_right, validate, nulls_equal)
312}
313
/// Computes left-join indices for one key column, preferring a sorted-merge join when the
/// sortedness flags allow it and falling back to a hash join otherwise.
///
/// Strategy, checked in this order:
/// 1. `validate.needs_checks()`: hash join.
/// 2. Both keys flagged ascending, numeric and null-free: parallel sorted merge.
/// 3. Left key flagged ascending, numeric, null-free and `right.len() / left.len()` below the
///    sort factor: arg-sort the right key, merge, then map the non-null right indices back.
/// 4. Anything else: hash join. The left key is never re-sorted here.
///
/// The sort factor defaults to `1.0` and can be overridden through the
/// `POLARS_JOIN_SORT_FACTOR` environment variable. A value that does not parse as `f32` is
/// ignored (with a message on stderr when `verbose` is set) instead of panicking.
#[cfg(feature = "performant")]
pub(crate) fn sort_or_hash_left(
    s_left: &Series,
    s_right: &Series,
    verbose: bool,
    validate: JoinValidation,
    nulls_equal: bool,
) -> PolarsResult<LeftJoinIds> {
    // Validation is only done by the hash join.
    if validate.needs_checks() {
        return s_left.hash_join_left(s_right, validate, nulls_equal);
    }

    // Bounds how large the right key may be, relative to the left, for sorting it to pay off.
    let size_factor_rhs = s_right.len() as f32 / s_left.len() as f32;
    let size_factor_acceptable = match std::env::var("POLARS_JOIN_SORT_FACTOR") {
        Ok(raw) => raw.parse::<f32>().unwrap_or_else(|_| {
            // A malformed tuning knob must not take the whole join down.
            if verbose {
                eprintln!(
                    "POLARS_JOIN_SORT_FACTOR={raw:?} is not a valid float; using the default 1.0"
                );
            }
            1.0
        }),
        Err(_) => 1.0,
    };
    let is_numeric = s_left.dtype().to_physical().is_primitive_numeric();
    let no_nulls = s_left.null_count() == 0 && s_right.null_count() == 0;

    match (s_left.is_sorted_flag(), s_right.is_sorted_flag(), no_nulls) {
        (IsSorted::Ascending, IsSorted::Ascending, true) if is_numeric => {
            if verbose {
                eprintln!("left join: keys are sorted: use sorted merge join");
            }
            let (left_idx, right_idx) = par_sorted_merge_left(s_left, s_right);
            Ok(to_left_join_ids(left_idx, right_idx))
        },
        (IsSorted::Ascending, _, true)
            if is_numeric && size_factor_rhs < size_factor_acceptable =>
        {
            // NOTE(review): the message says "reverse sorted" but the sort below is
            // ascending; the text is kept as-is because it is runtime output.
            if verbose {
                eprintln!("right key will be reverse sorted in left join operation.")
            }

            let sort_idx = s_right.arg_sort(SortOptions {
                descending: false,
                nulls_last: false,
                multithreaded: true,
                maintain_order: false,
                limit: None,
            });
            // SAFETY: `sort_idx` was produced by arg-sorting `s_right` itself, so it is
            // expected to hold only valid row positions of `s_right`.
            let s_right = unsafe { s_right.take_unchecked(&sort_idx) };

            let (left, mut right) = par_sorted_merge_left(s_left, &s_right);
            let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx);

            // Translate positions in the sorted right key back to original right rows;
            // null entries (left rows without a match) stay untouched.
            RAYON.install(|| {
                right.par_iter_mut().for_each(|opt_idx| {
                    if !opt_idx.is_null_idx() {
                        // SAFETY: relies on the merge join returning positions within the
                        // sorted right key, which has one entry per element of
                        // `reverse_idx_map`.
                        *opt_idx =
                            unsafe { *reverse_idx_map.get_unchecked(opt_idx.idx() as usize) }
                                .into();
                    }
                });
            });

            Ok(to_left_join_ids(left, right))
        },
        // Don't re-sort a left join key yet: it is still open how to set the sorted flag.
        _ => s_left.hash_join_left(s_right, validate, nulls_equal),
    }
}
377}