1use super::*;
2
/// Flat vector of `IdxSize` row indices.
pub(super) type JoinIds = Vec<IdxSize>;
/// Pair made of a `ChunkJoinIds` and a `ChunkJoinOptIds` (nullable indices).
/// NOTE(review): presumably ordered as (left, right) — confirm against callers.
pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
/// Pair of `JoinIds` vectors.
/// NOTE(review): presumably ordered as (left, right) — confirm against callers.
pub type InnerJoinIds = (JoinIds, JoinIds);

// With the `chunked_ids` feature enabled, index collections are `Either` a
// vector of plain indices or a vector of `ChunkId`s.
// NOTE(review): `ChunkJoinIds` is `pub(super)` here but `pub` in the
// non-`chunked_ids` configuration below — confirm the difference is intended.
#[cfg(feature = "chunked_ids")]
pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
#[cfg(feature = "chunked_ids")]
pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;

// Without the `chunked_ids` feature, both aliases are plain vectors.
#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinOptIds = Vec<NullableIdxSize>;

#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinIds = Vec<IdxSize>;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use strum_macros::IntoStaticStr;
21
/// Choice of which join input is used as the build side.
///
/// Stored as the optional `build_side` field of `JoinArgs`.
///
/// NOTE(review): the precise difference between the `Prefer*` and `Force*`
/// variants is decided by the code that consumes this enum; presumably
/// `Prefer*` is a hint and `Force*` is binding — confirm against the join
/// implementation.
///
/// `Eq` is derived alongside `PartialEq` and `Hash` so the type can be used as
/// a hash-map/set key and embedded in types that derive `Eq`.
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinBuildSide {
    /// Left input, as a preference.
    PreferLeft,
    /// Left input, forced.
    ForceLeft,

    /// Right input, as a preference.
    PreferRight,
    /// Right input, forced.
    ForceRight,
}
38
/// Bundle of options describing how a join should be performed.
///
/// `JoinArgs::new` builds a value for a given `JoinType` with every other
/// field at its default; the `with_*` methods override individual fields.
#[derive(Clone, PartialEq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct JoinArgs {
    /// The kind of join.
    pub how: JoinType,
    /// Validation to apply to the join keys (see `JoinValidation`).
    pub validation: JoinValidation,
    /// Optional suffix; `JoinArgs::suffix` falls back to `"_right"` when unset.
    pub suffix: Option<PlSmallStr>,
    /// Optional slice of the result.
    /// NOTE(review): presumably `(offset, length)` — confirm against callers.
    pub slice: Option<(i64, usize)>,
    /// Forwarded to key validation, where it controls whether nulls in the
    /// probe keys are excluded from the uniqueness check.
    /// NOTE(review): presumably also decides whether null keys match each
    /// other during the join itself — confirm.
    pub nulls_equal: bool,
    /// Coalescing policy; resolved against `how` by `should_coalesce`.
    pub coalesce: JoinCoalesce,
    /// Requested row-order preservation (see `MaintainOrderJoin`).
    pub maintain_order: MaintainOrderJoin,
    /// Optional choice of the build side (see `JoinBuildSide`).
    pub build_side: Option<JoinBuildSide>,
}
52
53impl JoinArgs {
54 pub fn should_coalesce(&self) -> bool {
55 self.coalesce.coalesce(&self.how)
56 }
57}
58
/// The kind of join to perform.
///
/// `Display` renders each variant as an upper-case label (e.g. `"INNER"`),
/// and the hand-written `Debug` impl produces the same text. Several variants
/// only exist when the corresponding cargo feature is enabled.
#[derive(Clone, PartialEq, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinType {
    /// Inner join; the default variant.
    #[default]
    Inner,
    /// Left join.
    Left,
    /// Right join.
    Right,
    /// Full join.
    Full,
    /// As-of join carrying its boxed options (`asof_join` feature).
    #[cfg(feature = "asof_join")]
    AsOf(Box<AsOfOptions>),
    /// Semi join (`semi_anti_join` feature).
    #[cfg(feature = "semi_anti_join")]
    Semi,
    /// Anti join (`semi_anti_join` feature).
    #[cfg(feature = "semi_anti_join")]
    Anti,
    /// `IEJoin` variant (`iejoin` feature); not reported by `is_equi`.
    #[cfg(feature = "iejoin")]
    IEJoin,
    /// `Range` variant (`iejoin` feature); not reported by `is_equi`.
    #[cfg(feature = "iejoin")]
    Range,
    /// Cross join.
    Cross,
}
86
/// Policy that decides whether a join coalesces; resolved per join type by
/// `JoinCoalesce::coalesce`.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinCoalesce {
    /// Default. Coalesces for inner, left and right joins (and as-of joins
    /// when that feature is enabled); does not coalesce for any other type.
    #[default]
    JoinSpecific,
    /// Coalesces for inner, left, right and full joins (and as-of joins when
    /// that feature is enabled); never for the remaining join types.
    CoalesceColumns,
    /// Never coalesces.
    KeepColumns,
}
96
97impl JoinCoalesce {
98 pub fn coalesce(&self, join_type: &JoinType) -> bool {
99 use JoinCoalesce::*;
100 use JoinType::*;
101 match join_type {
102 Left | Inner | Right => {
103 matches!(self, JoinSpecific | CoalesceColumns)
104 },
105 Full => {
106 matches!(self, CoalesceColumns)
107 },
108 #[cfg(feature = "asof_join")]
109 AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
110 #[cfg(feature = "iejoin")]
111 IEJoin | Range => false,
112 Cross => false,
113 #[cfg(feature = "semi_anti_join")]
114 Semi | Anti => false,
115 }
116 }
117}
118
/// Row-order preservation requested for a join result.
///
/// The `strum` attribute makes the static-string conversion use snake_case
/// variant names.
///
/// NOTE(review): the exact ordering guarantee of each variant is implemented
/// by the join code, not here; presumably `LeftRight`/`RightLeft` name the
/// primary then secondary side whose order is kept — confirm.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "snake_case")]
pub enum MaintainOrderJoin {
    /// Default variant.
    #[default]
    None,
    /// Counterpart of `Right` under `flip`.
    Left,
    /// Counterpart of `Left` under `flip`.
    Right,
    /// Counterpart of `RightLeft` under `flip`.
    LeftRight,
    /// Counterpart of `LeftRight` under `flip`.
    RightLeft,
}
131
132impl MaintainOrderJoin {
133 pub(super) fn flip(&self) -> Self {
134 match self {
135 MaintainOrderJoin::None => MaintainOrderJoin::None,
136 MaintainOrderJoin::Left => MaintainOrderJoin::Right,
137 MaintainOrderJoin::Right => MaintainOrderJoin::Left,
138 MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
139 MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
140 }
141 }
142}
143
144impl JoinArgs {
145 pub fn new(how: JoinType) -> Self {
146 Self {
147 how,
148 validation: Default::default(),
149 suffix: None,
150 slice: None,
151 nulls_equal: false,
152 coalesce: Default::default(),
153 maintain_order: Default::default(),
154 build_side: None,
155 }
156 }
157
158 pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
159 self.coalesce = coalesce;
160 self
161 }
162
163 pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
164 self.suffix = suffix;
165 self
166 }
167
168 pub fn with_build_side(mut self, build_side: Option<JoinBuildSide>) -> Self {
169 self.build_side = build_side;
170 self
171 }
172
173 pub fn suffix(&self) -> &PlSmallStr {
174 const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
175 self.suffix.as_ref().unwrap_or(DEFAULT)
176 }
177}
178
179impl From<JoinType> for JoinArgs {
180 fn from(value: JoinType) -> Self {
181 JoinArgs::new(value)
182 }
183}
184
/// A thread-safe (`Send + Sync`) callback that takes a `DataFrame` by value
/// and returns a `DataFrame` or an error.
///
/// Stored behind an `Arc` as the `predicate` of `CrossJoinOptions`. Any
/// `Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync` implements this
/// trait through the blanket impl in this file.
pub trait CrossJoinFilter: Send + Sync {
    /// Applies the filter to `df`.
    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
}
188
189impl<T> CrossJoinFilter for T
190where
191 T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
192{
193 fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
194 self(df)
195 }
196}
197
/// Options for a cross join: a shared, type-erased filter callback.
///
/// Cloning copies the `Arc`, so clones share the same predicate. The
/// `PartialEq` impl in this file compares by the address of the shared
/// predicate, not by its behavior.
#[derive(Clone)]
pub struct CrossJoinOptions {
    /// The filter callback (see `CrossJoinFilter`).
    pub predicate: Arc<dyn CrossJoinFilter>,
}
202
203impl CrossJoinOptions {
204 fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
205 Arc::as_ptr(&self.predicate)
206 }
207}
208
209impl Eq for CrossJoinOptions {}
210
211impl PartialEq for CrossJoinOptions {
212 fn eq(&self, other: &Self) -> bool {
213 std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
214 }
215}
216
217impl Hash for CrossJoinOptions {
218 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
219 self.as_ptr_ref().hash(state);
220 }
221}
222
223impl Debug for CrossJoinOptions {
224 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
225 write!(f, "CrossJoinOptions",)
226 }
227}
228
/// Extra, join-type-specific options.
///
/// The `strum` attribute makes the static-string conversion use snake_case
/// variant names.
#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
#[strum(serialize_all = "snake_case")]
pub enum JoinTypeOptions {
    /// Options wrapped in `IEJoinOptions` (`iejoin` feature).
    #[cfg(feature = "iejoin")]
    IEJoin(IEJoinOptions),
    /// Options for cross joins (see `CrossJoinOptions`).
    Cross(CrossJoinOptions),
}
236
237impl Display for JoinType {
238 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
239 use JoinType::*;
240 let val = match self {
241 Left => "LEFT",
242 Right => "RIGHT",
243 Inner => "INNER",
244 Full => "FULL",
245 #[cfg(feature = "asof_join")]
246 AsOf(_) => "ASOF",
247 #[cfg(feature = "iejoin")]
248 IEJoin => "IEJOIN",
249 #[cfg(feature = "iejoin")]
250 Range => "RANGE",
251 Cross => "CROSS",
252 #[cfg(feature = "semi_anti_join")]
253 Semi => "SEMI",
254 #[cfg(feature = "semi_anti_join")]
255 Anti => "ANTI",
256 };
257 write!(f, "{val}")
258 }
259}
260
261impl Debug for JoinType {
262 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263 write!(f, "{self}")
264 }
265}
266
267impl JoinType {
268 pub fn is_equi(&self) -> bool {
269 matches!(
270 self,
271 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
272 )
273 }
274
275 pub fn is_semi_anti(&self) -> bool {
276 #[cfg(feature = "semi_anti_join")]
277 {
278 matches!(self, JoinType::Semi | JoinType::Anti)
279 }
280 #[cfg(not(feature = "semi_anti_join"))]
281 {
282 false
283 }
284 }
285
286 pub fn is_semi(&self) -> bool {
287 #[cfg(feature = "semi_anti_join")]
288 {
289 matches!(self, JoinType::Semi)
290 }
291 #[cfg(not(feature = "semi_anti_join"))]
292 {
293 false
294 }
295 }
296
297 pub fn is_anti(&self) -> bool {
298 #[cfg(feature = "semi_anti_join")]
299 {
300 matches!(self, JoinType::Anti)
301 }
302 #[cfg(not(feature = "semi_anti_join"))]
303 {
304 false
305 }
306 }
307
308 pub fn is_asof(&self) -> bool {
309 #[cfg(feature = "asof_join")]
310 {
311 matches!(self, JoinType::AsOf(_))
312 }
313 #[cfg(not(feature = "asof_join"))]
314 {
315 false
316 }
317 }
318
319 pub fn is_cross(&self) -> bool {
320 matches!(self, JoinType::Cross)
321 }
322
323 pub fn is_ie(&self) -> bool {
324 #[cfg(feature = "iejoin")]
325 {
326 matches!(self, JoinType::IEJoin)
327 }
328 #[cfg(not(feature = "iejoin"))]
329 {
330 false
331 }
332 }
333
334 pub fn is_range(&self) -> bool {
335 #[cfg(feature = "iejoin")]
336 {
337 matches!(self, JoinType::Range)
338 }
339 #[cfg(not(feature = "iejoin"))]
340 {
341 false
342 }
343 }
344}
345
/// Uniqueness requirement placed on the join keys.
///
/// `Display` renders the variants as `m:m`, `m:1`, `1:m` and `1:1`.
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinValidation {
    /// Default; requires no checks (`needs_checks` returns false).
    #[default]
    ManyToMany,
    /// Verified on the build side by `validate_build` (when not swapped).
    /// NOTE(review): presumably means right-hand keys must be unique — confirm.
    ManyToOne,
    /// Verified on the probe side by `validate_probe` (when not swapped),
    /// which requires the left keys to be unique.
    OneToMany,
    /// Verified by both `validate_probe` and `validate_build`.
    OneToOne,
}
360
361impl JoinValidation {
362 pub fn needs_checks(&self) -> bool {
363 !matches!(self, JoinValidation::ManyToMany)
364 }
365
366 fn swap(self, swap: bool) -> Self {
367 use JoinValidation::*;
368 if swap {
369 match self {
370 ManyToMany => ManyToMany,
371 ManyToOne => OneToMany,
372 OneToMany => ManyToOne,
373 OneToOne => OneToOne,
374 }
375 } else {
376 self
377 }
378 }
379
380 pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
381 if !self.needs_checks() {
382 return Ok(());
383 }
384 polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
385 ComputeError: "{self} validation on a {join_type} join is not supported");
386 Ok(())
387 }
388
389 pub(super) fn validate_probe(
390 &self,
391 s_left: &Series,
392 s_right: &Series,
393 build_shortest_table: bool,
394 nulls_equal: bool,
395 ) -> PolarsResult<()> {
396 let should_swap = build_shortest_table && s_left.len() <= s_right.len();
404 let probe = if should_swap { s_right } else { s_left };
405
406 use JoinValidation::*;
407 let valid = match self.swap(should_swap) {
408 ManyToMany | ManyToOne => true,
411 OneToMany | OneToOne => {
412 if !nulls_equal && probe.null_count() > 0 {
413 probe.n_unique()? - 1 == probe.len() - probe.null_count()
414 } else {
415 probe.n_unique()? == probe.len()
416 }
417 },
418 };
419 polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
420 Ok(())
421 }
422
423 pub(super) fn validate_build(
424 &self,
425 build_size: usize,
426 expected_size: usize,
427 swapped: bool,
428 ) -> PolarsResult<()> {
429 use JoinValidation::*;
430
431 let valid = match self.swap(swapped) {
433 ManyToMany | OneToMany => true,
436 ManyToOne | OneToOne => build_size == expected_size,
437 };
438 polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
439 Ok(())
440 }
441}
442
443impl Display for JoinValidation {
444 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
445 let s = match self {
446 JoinValidation::ManyToMany => "m:m",
447 JoinValidation::ManyToOne => "m:1",
448 JoinValidation::OneToMany => "1:m",
449 JoinValidation::OneToOne => "1:1",
450 };
451 write!(f, "{s}")
452 }
453}
454
455impl Debug for JoinValidation {
456 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
457 write!(f, "JoinValidation: {self}")
458 }
459}