1use std::borrow::Cow;
2use std::cmp::Ordering;
3use std::iter::repeat_n;
4
5use arrow::array::Array;
6use arrow::array::builder::ShareStrategy;
7use polars_core::frame::builder::DataFrameBuilder;
8use polars_core::prelude::*;
9use polars_core::with_match_physical_numeric_polars_type;
10use polars_utils::itertools::Itertools;
11use polars_utils::total_ord::TotalOrd;
12use polars_utils::{IdxSize, format_pl_smallstr};
13
14use crate::frame::{JoinArgs, JoinType};
15use crate::series::coalesce_columns;
16
17#[allow(clippy::too_many_arguments)]
18pub fn match_keys(
19 build_keys: &Series,
20 probe_keys: &Series,
21 gather_build: &mut Vec<IdxSize>,
22 gather_probe: &mut Vec<IdxSize>,
23 gather_probe_unmatched: Option<&mut Vec<IdxSize>>,
24 build_emit_unmatched: bool,
25 descending: bool,
26 nulls_equal: bool,
27 limit_results: usize,
28 build_row_offset: &mut usize,
29 probe_row_offset: &mut usize,
30 probe_last_matched: &mut usize,
31) {
32 macro_rules! dispatch {
33 ($build_keys_ca:expr) => {
34 match_keys_impl(
35 $build_keys_ca,
36 probe_keys.as_ref().as_ref(),
37 gather_build,
38 gather_probe,
39 gather_probe_unmatched,
40 build_emit_unmatched,
41 descending,
42 nulls_equal,
43 limit_results,
44 build_row_offset,
45 probe_row_offset,
46 probe_last_matched,
47 )
48 };
49 }
50
51 assert_eq!(build_keys.dtype(), probe_keys.dtype());
52 match build_keys.dtype() {
53 dt if dt.is_primitive_numeric() => {
54 with_match_physical_numeric_polars_type!(dt, |$T| {
55 type PhysCa = ChunkedArray<$T>;
56 let build_keys_ca: &PhysCa = build_keys.as_ref().as_ref();
57 dispatch!(build_keys_ca)
58 })
59 },
60 DataType::Boolean => dispatch!(build_keys.bool().unwrap()),
61 DataType::String => dispatch!(build_keys.str().unwrap()),
62 DataType::Binary => dispatch!(build_keys.binary().unwrap()),
63 DataType::BinaryOffset => dispatch!(build_keys.binary_offset().unwrap()),
64 #[cfg(feature = "dtype-categorical")]
65 DataType::Enum(cats, _) => with_match_categorical_physical_type!(cats.physical(), |$C| {
66 type PhysCa = ChunkedArray<<$C as PolarsCategoricalType>::PolarsPhysical>;
67 let build_keys_ca: &PhysCa = build_keys.as_ref().as_ref();
68 dispatch!(build_keys_ca)
69 }),
70 DataType::Null => match_null_keys_impl(
71 build_keys.len(),
72 probe_keys.len(),
73 gather_build,
74 gather_probe,
75 gather_probe_unmatched,
76 build_emit_unmatched,
77 descending,
78 nulls_equal,
79 limit_results,
80 build_row_offset,
81 probe_row_offset,
82 probe_last_matched,
83 ),
84 dt => unimplemented!("merge-join kernel not implemented for {:?}", dt),
85 }
86}
87
88#[allow(clippy::mut_range_bound, clippy::too_many_arguments)]
89fn match_keys_impl<'a, T: PolarsDataType>(
90 build_keys: &'a ChunkedArray<T>,
91 probe_keys: &'a ChunkedArray<T>,
92 gather_build: &mut Vec<IdxSize>,
93 gather_probe: &mut Vec<IdxSize>,
94 mut gather_probe_unmatched: Option<&mut Vec<IdxSize>>,
95 build_emit_unmatched: bool,
96 descending: bool,
97 nulls_equal: bool,
98 limit_results: usize,
99 build_row_offset: &mut usize,
100 probe_row_offset: &mut usize,
101 probe_first_unmatched: &mut usize,
102) where
103 T::Physical<'a>: TotalOrd,
104{
105 assert!(gather_build.is_empty());
106 assert!(gather_probe.is_empty());
107
108 let build_key = build_keys.downcast_as_array();
109 let probe_key = probe_keys.downcast_as_array();
110
111 while *build_row_offset < build_key.len() {
112 if gather_build.len() >= limit_results {
113 return;
114 }
115
116 let build_keyval = unsafe { build_key.get_unchecked(*build_row_offset) };
117 let build_keyval = build_keyval.as_ref();
118 let mut build_keyval_matched = false;
119
120 if nulls_equal || build_keyval.is_some() {
121 for probe_idx in *probe_row_offset..probe_key.len() {
122 let probe_keyval = unsafe { probe_key.get_unchecked(probe_idx) };
123 let probe_keyval = probe_keyval.as_ref();
124
125 let mut ord: Ordering = match (&build_keyval, &probe_keyval) {
126 (None, None) if nulls_equal => Ordering::Equal,
127 (Some(l), Some(r)) => TotalOrd::tot_cmp(*l, *r),
128 _ => continue,
129 };
130 if descending {
131 ord = ord.reverse();
132 }
133
134 match ord {
135 Ordering::Equal => {
136 if let Some(probe_unmatched) = gather_probe_unmatched.as_mut() {
137 probe_unmatched
139 .extend(*probe_first_unmatched as IdxSize..probe_idx as IdxSize);
140 *probe_first_unmatched = (*probe_first_unmatched).max(probe_idx + 1);
141 }
142 gather_build.push(*build_row_offset as IdxSize);
143 gather_probe.push(probe_idx as IdxSize);
144 build_keyval_matched = true;
145 },
146 Ordering::Greater => {
147 if let Some(probe_unmatched) = gather_probe_unmatched.as_mut() {
148 probe_unmatched
150 .extend(*probe_first_unmatched as IdxSize..=probe_idx as IdxSize);
151 *probe_first_unmatched = (*probe_first_unmatched).max(probe_idx + 1);
152 }
153 *probe_row_offset = probe_idx + 1;
154 },
155 Ordering::Less => {
156 break;
157 },
158 }
159 }
160 }
161 if build_emit_unmatched && !build_keyval_matched {
162 gather_build.push(*build_row_offset as IdxSize);
163 gather_probe.push(IdxSize::MAX);
164 }
165 *build_row_offset += 1;
166 }
167 if let Some(probe_unmatched) = gather_probe_unmatched {
168 probe_unmatched.extend(*probe_first_unmatched as IdxSize..probe_key.len() as IdxSize);
169 *probe_first_unmatched = probe_key.len();
170 }
171 *probe_row_offset = probe_key.len();
172}
173
174#[allow(clippy::mut_range_bound, clippy::too_many_arguments)]
175fn match_null_keys_impl(
176 build_n: usize,
177 probe_n: usize,
178 gather_build: &mut Vec<IdxSize>,
179 gather_probe: &mut Vec<IdxSize>,
180 gather_probe_unmatched: Option<&mut Vec<IdxSize>>,
181 build_emit_unmatched: bool,
182 _descending: bool,
183 nulls_equal: bool,
184 limit_results: usize,
185 build_row_offset: &mut usize,
186 probe_row_offset: &mut usize,
187 probe_last_matched: &mut usize,
188) {
189 assert!(gather_build.is_empty());
190 assert!(gather_probe.is_empty());
191
192 if nulls_equal {
193 while *build_row_offset < build_n {
195 if gather_build.len() >= limit_results {
196 return;
197 }
198 for probe_idx in *probe_row_offset..probe_n {
199 gather_build.push(*build_row_offset as IdxSize);
200 gather_probe.push(probe_idx as IdxSize);
201 }
202 *build_row_offset += 1;
203 }
204 } else {
205 if build_emit_unmatched {
208 gather_build.extend(0..build_n as IdxSize);
209 gather_probe.extend(repeat_n(IdxSize::MAX, build_n));
210 }
211 if let Some(probe_unmatched) = gather_probe_unmatched {
212 probe_unmatched.extend(*probe_last_matched as IdxSize..probe_n as IdxSize);
213 *probe_last_matched = probe_n;
214 }
215 }
216 *build_row_offset = build_n;
217 *probe_row_offset = probe_n;
218}
219
220#[allow(clippy::too_many_arguments)]
221pub fn gather_and_postprocess(
222 build: DataFrame,
223 probe: DataFrame,
224 gather_build: Option<&[IdxSize]>,
225 gather_probe: Option<&[IdxSize]>,
226 df_builders: &mut Option<(DataFrameBuilder, DataFrameBuilder)>,
227 args: &JoinArgs,
228 left_on: &[PlSmallStr],
229 right_on: &[PlSmallStr],
230 left_is_build: bool,
231 output_schema: &Schema,
232) -> PolarsResult<DataFrame> {
233 let should_coalesce = args.should_coalesce();
234 let left_emit_unmatched = matches!(args.how, JoinType::Left | JoinType::Full);
235 let right_emit_unmatched = matches!(args.how, JoinType::Right | JoinType::Full);
236
237 let (mut left, mut right);
238 let (gather_left, gather_right);
239 if left_is_build {
240 (left, right) = (build, probe);
241 (gather_left, gather_right) = (gather_build, gather_probe);
242 } else {
243 (left, right) = (probe, build);
244 (gather_left, gather_right) = (gather_probe, gather_build);
245 }
246
247 for col in left
249 .columns()
250 .iter()
251 .map(Column::name)
252 .cloned()
253 .collect_vec()
254 {
255 if left_on.contains(&col) && should_coalesce {
256 continue;
257 }
258 if !output_schema.contains(&col) {
259 left.drop_in_place(&col).unwrap();
260 }
261 }
262 for col in right
263 .columns()
264 .iter()
265 .map(Column::name)
266 .cloned()
267 .collect_vec()
268 {
269 if left_on.contains(&col) && should_coalesce {
270 continue;
271 }
272 let renamed = match left.schema().contains(&col) {
273 true => Cow::Owned(format_pl_smallstr!("{}{}", col, args.suffix())),
274 false => Cow::Borrowed(&col),
275 };
276 if !output_schema.contains(&renamed) {
277 right.drop_in_place(&col).unwrap();
278 }
279 }
280
281 if df_builders.is_none() {
282 *df_builders = Some((
283 DataFrameBuilder::new(left.schema().clone()),
284 DataFrameBuilder::new(right.schema().clone()),
285 ));
286 }
287
288 let (left_build, right_build) = df_builders.as_mut().unwrap();
289 let mut left = match gather_left {
290 Some(gather_left) if right_emit_unmatched => {
291 left_build.opt_gather_extend(&left, gather_left, ShareStrategy::Never);
292 left_build.freeze_reset()
293 },
294 Some(gather_left) => unsafe {
295 left_build.gather_extend(&left, gather_left, ShareStrategy::Never);
296 left_build.freeze_reset()
297 },
298 None => DataFrame::full_null(left.schema(), gather_right.unwrap().len()),
299 };
300 let mut right = match gather_right {
301 Some(gather_right) if left_emit_unmatched => {
302 right_build.opt_gather_extend(&right, gather_right, ShareStrategy::Never);
303 right_build.freeze_reset()
304 },
305 Some(gather_right) => unsafe {
306 right_build.gather_extend(&right, gather_right, ShareStrategy::Never);
307 right_build.freeze_reset()
308 },
309 None => DataFrame::full_null(right.schema(), gather_left.unwrap().len()),
310 };
311
312 if args.how == JoinType::Left && should_coalesce {
314 for c in left_on {
315 if right.schema().contains(c) {
316 right.drop_in_place(c.as_str())?;
317 }
318 }
319 } else if args.how == JoinType::Right && should_coalesce {
320 for c in right_on {
321 if left.schema().contains(c) {
322 left.drop_in_place(c.as_str())?;
323 }
324 }
325 }
326
327 let left_cols: PlHashSet<_> = left.columns().iter().map(Column::name).cloned().collect();
329 let right_cols_vec = right.get_column_names_owned();
330 let renames = right_cols_vec
331 .iter()
332 .filter(|c| left_cols.contains(*c))
333 .map(|c| {
334 let renamed = format_pl_smallstr!("{}{}", c, args.suffix());
335 (c.as_str(), renamed)
336 });
337 right.rename_many(renames).unwrap();
338
339 left.hstack_mut(right.columns())?;
340
341 if args.how == JoinType::Full && should_coalesce {
342 for (left_keycol, right_keycol) in Iterator::zip(left_on.iter(), right_on.iter()) {
344 let right_keycol = format_pl_smallstr!("{}{}", right_keycol, args.suffix());
345 let left_col = left.column(left_keycol).unwrap();
346 let right_col = left.column(&right_keycol).unwrap();
347 let coalesced = coalesce_columns(&[left_col.clone(), right_col.clone()]).unwrap();
348 left.replace(left_keycol, coalesced)
349 .unwrap()
350 .drop_in_place(&right_keycol)
351 .unwrap();
352 }
353 }
354
355 if should_coalesce {
356 for col in left_on {
357 if left.schema().contains(col) && !output_schema.contains(col) {
358 left.drop_in_place(col).unwrap();
359 }
360 }
361 for col in right_on {
362 let renamed = match left.schema().contains(col) {
363 true => Cow::Owned(format_pl_smallstr!("{}{}", col, args.suffix())),
364 false => Cow::Borrowed(col),
365 };
366 if left.schema().contains(&renamed) && !output_schema.contains(&renamed) {
367 left.drop_in_place(&renamed).unwrap();
368 }
369 }
370 }
371
372 debug_assert_eq!(**left.schema(), *output_schema);
373 Ok(left)
374}