polars_ops/frame/join/merge_join.rs

1use std::borrow::Cow;
2use std::cmp::Ordering;
3use std::iter::repeat_n;
4
5use arrow::array::Array;
6use arrow::array::builder::ShareStrategy;
7use polars_core::frame::builder::DataFrameBuilder;
8use polars_core::prelude::*;
9use polars_core::with_match_physical_numeric_polars_type;
10use polars_utils::itertools::Itertools;
11use polars_utils::total_ord::TotalOrd;
12use polars_utils::{IdxSize, format_pl_smallstr};
13
14use crate::frame::{JoinArgs, JoinType};
15use crate::series::coalesce_columns;
16
17#[allow(clippy::too_many_arguments)]
18pub fn match_keys(
19    build_keys: &Series,
20    probe_keys: &Series,
21    gather_build: &mut Vec<IdxSize>,
22    gather_probe: &mut Vec<IdxSize>,
23    gather_probe_unmatched: Option<&mut Vec<IdxSize>>,
24    build_emit_unmatched: bool,
25    descending: bool,
26    nulls_equal: bool,
27    limit_results: usize,
28    build_row_offset: &mut usize,
29    probe_row_offset: &mut usize,
30    probe_last_matched: &mut usize,
31) {
32    macro_rules! dispatch {
33        ($build_keys_ca:expr) => {
34            match_keys_impl(
35                $build_keys_ca,
36                probe_keys.as_ref().as_ref(),
37                gather_build,
38                gather_probe,
39                gather_probe_unmatched,
40                build_emit_unmatched,
41                descending,
42                nulls_equal,
43                limit_results,
44                build_row_offset,
45                probe_row_offset,
46                probe_last_matched,
47            )
48        };
49    }
50
51    assert_eq!(build_keys.dtype(), probe_keys.dtype());
52    match build_keys.dtype() {
53        dt if dt.is_primitive_numeric() => {
54            with_match_physical_numeric_polars_type!(dt, |$T| {
55                type PhysCa = ChunkedArray<$T>;
56                let build_keys_ca: &PhysCa  = build_keys.as_ref().as_ref();
57                dispatch!(build_keys_ca)
58            })
59        },
60        DataType::Boolean => dispatch!(build_keys.bool().unwrap()),
61        DataType::String => dispatch!(build_keys.str().unwrap()),
62        DataType::Binary => dispatch!(build_keys.binary().unwrap()),
63        DataType::BinaryOffset => dispatch!(build_keys.binary_offset().unwrap()),
64        #[cfg(feature = "dtype-categorical")]
65        DataType::Enum(cats, _) => with_match_categorical_physical_type!(cats.physical(), |$C| {
66            type PhysCa = ChunkedArray<<$C as PolarsCategoricalType>::PolarsPhysical>;
67            let build_keys_ca: &PhysCa = build_keys.as_ref().as_ref();
68            dispatch!(build_keys_ca)
69        }),
70        DataType::Null => match_null_keys_impl(
71            build_keys.len(),
72            probe_keys.len(),
73            gather_build,
74            gather_probe,
75            gather_probe_unmatched,
76            build_emit_unmatched,
77            descending,
78            nulls_equal,
79            limit_results,
80            build_row_offset,
81            probe_row_offset,
82            probe_last_matched,
83        ),
84        dt => unimplemented!("merge-join kernel not implemented for {:?}", dt),
85    }
86}
87
/// Generic merge-join kernel over two key arrays of the same physical type.
///
/// Walks the build keys starting at `*build_row_offset`. For each build key it scans the
/// probe keys starting at `*probe_row_offset` and, after reversing the comparison when
/// `descending` is set:
/// - on an equal key, appends `(build_row, probe_row)` to `gather_build` / `gather_probe`;
/// - on a probe key that sorts before the build key, classifies that probe row as
///   unmatched and advances `*probe_row_offset` past it;
/// - on a probe key that sorts after the build key, stops scanning for this build key.
///
/// Because of that early stop, both inputs are assumed to be sorted in the same direction.
/// NOTE(review): null placement is not checked here; presumably both sides put nulls
/// in the same place. Confirm with the caller.
///
/// Nulls: a null build key is only compared when `nulls_equal`, and then only matches
/// null probe keys. Null-vs-non-null pairs are skipped (`continue`) without advancing
/// `*probe_row_offset`.
///
/// Outputs:
/// - If `build_emit_unmatched`, a build row that matched nothing is appended with the
///   probe index `IdxSize::MAX`.
/// - If `gather_probe_unmatched` is given, probe rows that were not matched are appended
///   to it in increasing order.
///
/// Resumable state:
/// - `build_row_offset`: the next build row to process.
/// - `probe_row_offset`: the first probe row that may still match the current or a later
///   build key.
/// - `probe_first_unmatched`: the first probe row not yet matched or reported unmatched.
///
/// `limit_results` is checked before each build row. Once `gather_build` holds at least
/// that many entries, the function returns early with the state pointing at the next
/// build row, so that a subsequent call continues where this one stopped. A single build
/// row may push `gather_build` past the limit.
///
/// On completion, any remaining probe rows are reported unmatched (if requested), and both
/// offsets point at the end of their arrays.
///
/// NOTE(review): `downcast_as_array` presumably requires single-chunk inputs. Confirm.
///
/// # Panics
/// Panics if `gather_build` or `gather_probe` is not empty on entry.
// `mut_range_bound`: `*probe_row_offset` is updated inside the `for` loop whose range was
// built from it; the range is intentionally fixed at loop entry.
#[allow(clippy::mut_range_bound, clippy::too_many_arguments)]
fn match_keys_impl<'a, T: PolarsDataType>(
    build_keys: &'a ChunkedArray<T>,
    probe_keys: &'a ChunkedArray<T>,
    gather_build: &mut Vec<IdxSize>,
    gather_probe: &mut Vec<IdxSize>,
    mut gather_probe_unmatched: Option<&mut Vec<IdxSize>>,
    build_emit_unmatched: bool,
    descending: bool,
    nulls_equal: bool,
    limit_results: usize,
    build_row_offset: &mut usize,
    probe_row_offset: &mut usize,
    probe_first_unmatched: &mut usize,
) where
    T::Physical<'a>: TotalOrd,
{
    assert!(gather_build.is_empty());
    assert!(gather_probe.is_empty());

    let build_key = build_keys.downcast_as_array();
    let probe_key = probe_keys.downcast_as_array();

    while *build_row_offset < build_key.len() {
        // Soft output limit: stop before starting a new build row; state is preserved.
        if gather_build.len() >= limit_results {
            return;
        }

        // SAFETY: the loop condition guarantees `*build_row_offset < build_key.len()`.
        let build_keyval = unsafe { build_key.get_unchecked(*build_row_offset) };
        let build_keyval = build_keyval.as_ref();
        let mut build_keyval_matched = false;

        // A null build key can only match anything when nulls compare equal.
        if nulls_equal || build_keyval.is_some() {
            for probe_idx in *probe_row_offset..probe_key.len() {
                // SAFETY: `probe_idx < probe_key.len()` by the range bound.
                let probe_keyval = unsafe { probe_key.get_unchecked(probe_idx) };
                let probe_keyval = probe_keyval.as_ref();

                // Orientation: `Greater` means the build key sorts after this probe key.
                let mut ord: Ordering = match (&build_keyval, &probe_keyval) {
                    (None, None) if nulls_equal => Ordering::Equal,
                    (Some(l), Some(r)) => TotalOrd::tot_cmp(*l, *r),
                    _ => continue,
                };
                if descending {
                    ord = ord.reverse();
                }

                match ord {
                    Ordering::Equal => {
                        if let Some(probe_unmatched) = gather_probe_unmatched.as_mut() {
                            // Not-yet-classified probe rows *before* this matched one (rows
                            // skipped above via `continue`) are reported as unmatched.
                            probe_unmatched
                                .extend(*probe_first_unmatched as IdxSize..probe_idx as IdxSize);
                            *probe_first_unmatched = (*probe_first_unmatched).max(probe_idx + 1);
                        }
                        gather_build.push(*build_row_offset as IdxSize);
                        gather_probe.push(probe_idx as IdxSize);
                        build_keyval_matched = true;
                    },
                    Ordering::Greater => {
                        if let Some(probe_unmatched) = gather_probe_unmatched.as_mut() {
                            // This probe key sorts before the build key, so it is unmatched;
                            // report it together with any not-yet-classified rows before it.
                            probe_unmatched
                                .extend(*probe_first_unmatched as IdxSize..=probe_idx as IdxSize);
                            *probe_first_unmatched = (*probe_first_unmatched).max(probe_idx + 1);
                        }
                        // Later build keys need not revisit this probe row.
                        *probe_row_offset = probe_idx + 1;
                    },
                    Ordering::Less => {
                        // Probe key sorts after the build key: nothing further can match it.
                        break;
                    },
                }
            }
        }
        if build_emit_unmatched && !build_keyval_matched {
            // Unmatched build row: `IdxSize::MAX` marks the missing probe side.
            gather_build.push(*build_row_offset as IdxSize);
            gather_probe.push(IdxSize::MAX);
        }
        *build_row_offset += 1;
    }
    // All build rows are done: every probe row not yet classified is unmatched.
    if let Some(probe_unmatched) = gather_probe_unmatched {
        probe_unmatched.extend(*probe_first_unmatched as IdxSize..probe_key.len() as IdxSize);
        *probe_first_unmatched = probe_key.len();
    }
    *probe_row_offset = probe_key.len();
}
173
174#[allow(clippy::mut_range_bound, clippy::too_many_arguments)]
175fn match_null_keys_impl(
176    build_n: usize,
177    probe_n: usize,
178    gather_build: &mut Vec<IdxSize>,
179    gather_probe: &mut Vec<IdxSize>,
180    gather_probe_unmatched: Option<&mut Vec<IdxSize>>,
181    build_emit_unmatched: bool,
182    _descending: bool,
183    nulls_equal: bool,
184    limit_results: usize,
185    build_row_offset: &mut usize,
186    probe_row_offset: &mut usize,
187    probe_last_matched: &mut usize,
188) {
189    assert!(gather_build.is_empty());
190    assert!(gather_probe.is_empty());
191
192    if nulls_equal {
193        // All keys will match all other keys, so just emit the Cartesian product
194        while *build_row_offset < build_n {
195            if gather_build.len() >= limit_results {
196                return;
197            }
198            for probe_idx in *probe_row_offset..probe_n {
199                gather_build.push(*build_row_offset as IdxSize);
200                gather_probe.push(probe_idx as IdxSize);
201            }
202            *build_row_offset += 1;
203        }
204    } else {
205        // No keys can ever match, so just emit all build keys into gather_build
206        // and all probe keys into gather_probe_unmatched.
207        if build_emit_unmatched {
208            gather_build.extend(0..build_n as IdxSize);
209            gather_probe.extend(repeat_n(IdxSize::MAX, build_n));
210        }
211        if let Some(probe_unmatched) = gather_probe_unmatched {
212            probe_unmatched.extend(*probe_last_matched as IdxSize..probe_n as IdxSize);
213            *probe_last_matched = probe_n;
214        }
215    }
216    *build_row_offset = build_n;
217    *probe_row_offset = probe_n;
218}
219
220#[allow(clippy::too_many_arguments)]
221pub fn gather_and_postprocess(
222    build: DataFrame,
223    probe: DataFrame,
224    gather_build: Option<&[IdxSize]>,
225    gather_probe: Option<&[IdxSize]>,
226    df_builders: &mut Option<(DataFrameBuilder, DataFrameBuilder)>,
227    args: &JoinArgs,
228    left_on: &[PlSmallStr],
229    right_on: &[PlSmallStr],
230    left_is_build: bool,
231    output_schema: &Schema,
232) -> PolarsResult<DataFrame> {
233    let should_coalesce = args.should_coalesce();
234    let left_emit_unmatched = matches!(args.how, JoinType::Left | JoinType::Full);
235    let right_emit_unmatched = matches!(args.how, JoinType::Right | JoinType::Full);
236
237    let (mut left, mut right);
238    let (gather_left, gather_right);
239    if left_is_build {
240        (left, right) = (build, probe);
241        (gather_left, gather_right) = (gather_build, gather_probe);
242    } else {
243        (left, right) = (probe, build);
244        (gather_left, gather_right) = (gather_probe, gather_build);
245    }
246
247    // Remove non-payload columns
248    for col in left
249        .columns()
250        .iter()
251        .map(Column::name)
252        .cloned()
253        .collect_vec()
254    {
255        if left_on.contains(&col) && should_coalesce {
256            continue;
257        }
258        if !output_schema.contains(&col) {
259            left.drop_in_place(&col).unwrap();
260        }
261    }
262    for col in right
263        .columns()
264        .iter()
265        .map(Column::name)
266        .cloned()
267        .collect_vec()
268    {
269        if left_on.contains(&col) && should_coalesce {
270            continue;
271        }
272        let renamed = match left.schema().contains(&col) {
273            true => Cow::Owned(format_pl_smallstr!("{}{}", col, args.suffix())),
274            false => Cow::Borrowed(&col),
275        };
276        if !output_schema.contains(&renamed) {
277            right.drop_in_place(&col).unwrap();
278        }
279    }
280
281    if df_builders.is_none() {
282        *df_builders = Some((
283            DataFrameBuilder::new(left.schema().clone()),
284            DataFrameBuilder::new(right.schema().clone()),
285        ));
286    }
287
288    let (left_build, right_build) = df_builders.as_mut().unwrap();
289    let mut left = match gather_left {
290        Some(gather_left) if right_emit_unmatched => {
291            left_build.opt_gather_extend(&left, gather_left, ShareStrategy::Never);
292            left_build.freeze_reset()
293        },
294        Some(gather_left) => unsafe {
295            left_build.gather_extend(&left, gather_left, ShareStrategy::Never);
296            left_build.freeze_reset()
297        },
298        None => DataFrame::full_null(left.schema(), gather_right.unwrap().len()),
299    };
300    let mut right = match gather_right {
301        Some(gather_right) if left_emit_unmatched => {
302            right_build.opt_gather_extend(&right, gather_right, ShareStrategy::Never);
303            right_build.freeze_reset()
304        },
305        Some(gather_right) => unsafe {
306            right_build.gather_extend(&right, gather_right, ShareStrategy::Never);
307            right_build.freeze_reset()
308        },
309        None => DataFrame::full_null(right.schema(), gather_left.unwrap().len()),
310    };
311
312    // Coalsesce the key columns
313    if args.how == JoinType::Left && should_coalesce {
314        for c in left_on {
315            if right.schema().contains(c) {
316                right.drop_in_place(c.as_str())?;
317            }
318        }
319    } else if args.how == JoinType::Right && should_coalesce {
320        for c in right_on {
321            if left.schema().contains(c) {
322                left.drop_in_place(c.as_str())?;
323            }
324        }
325    }
326
327    // Rename any right columns to "{}_right"
328    let left_cols: PlHashSet<_> = left.columns().iter().map(Column::name).cloned().collect();
329    let right_cols_vec = right.get_column_names_owned();
330    let renames = right_cols_vec
331        .iter()
332        .filter(|c| left_cols.contains(*c))
333        .map(|c| {
334            let renamed = format_pl_smallstr!("{}{}", c, args.suffix());
335            (c.as_str(), renamed)
336        });
337    right.rename_many(renames).unwrap();
338
339    left.hstack_mut(right.columns())?;
340
341    if args.how == JoinType::Full && should_coalesce {
342        // Coalesce key columns
343        for (left_keycol, right_keycol) in Iterator::zip(left_on.iter(), right_on.iter()) {
344            let right_keycol = format_pl_smallstr!("{}{}", right_keycol, args.suffix());
345            let left_col = left.column(left_keycol).unwrap();
346            let right_col = left.column(&right_keycol).unwrap();
347            let coalesced = coalesce_columns(&[left_col.clone(), right_col.clone()]).unwrap();
348            left.replace(left_keycol, coalesced)
349                .unwrap()
350                .drop_in_place(&right_keycol)
351                .unwrap();
352        }
353    }
354
355    if should_coalesce {
356        for col in left_on {
357            if left.schema().contains(col) && !output_schema.contains(col) {
358                left.drop_in_place(col).unwrap();
359            }
360        }
361        for col in right_on {
362            let renamed = match left.schema().contains(col) {
363                true => Cow::Owned(format_pl_smallstr!("{}{}", col, args.suffix())),
364                false => Cow::Borrowed(col),
365            };
366            if left.schema().contains(&renamed) && !output_schema.contains(&renamed) {
367                left.drop_in_place(&renamed).unwrap();
368            }
369        }
370    }
371
372    debug_assert_eq!(**left.schema(), *output_schema);
373    Ok(left)
374}