Skip to main content

polars_ops/frame/join/
args.rs

1use super::*;
2
/// Flat vector of `IdxSize` row indices produced by a join.
pub(super) type JoinIds = Vec<IdxSize>;
/// Index pair produced by a left join: a [`ChunkJoinIds`] and a [`ChunkJoinOptIds`].
// NOTE(review): presumably ordered (left, right), with the optional side being the right one — confirm against callers.
pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
/// Pair of index vectors produced by an inner join.
pub type InnerJoinIds = (JoinIds, JoinIds);

/// Join indices; with `chunked_ids` enabled they are either plain `IdxSize`
/// values or `ChunkId`s.
#[cfg(feature = "chunked_ids")]
pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
/// Nullable join indices; with `chunked_ids` enabled they are either
/// `NullableIdxSize` values or `ChunkId`s.
#[cfg(feature = "chunked_ids")]
pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;

/// Nullable join indices; without `chunked_ids` this is a plain vector.
#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinOptIds = Vec<NullableIdxSize>;

/// Join indices; without `chunked_ids` this is a plain vector.
#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinIds = Vec<IdxSize>;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use strum_macros::IntoStaticStr;
21
/// Parameters for which side to use as the build side in a join. Currently only
/// respected by the streaming engine.
#[derive(Clone, PartialEq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinBuildSide {
    /// Unless there's a very good reason to believe that the right side is
    /// smaller, use the left side.
    PreferLeft,
    /// Regardless of other heuristics, use the left side as build side.
    ForceLeft,

    /// Mirror image of `PreferLeft`: use the right side unless there's a very
    /// good reason to believe that the left side is smaller.
    PreferRight,
    /// Mirror image of `ForceLeft`: regardless of other heuristics, use the
    /// right side as build side.
    ForceRight,
}
38
/// Options describing how a join should be performed.
///
/// The derived `Default` yields an inner join with every other option at its
/// own default.
#[derive(Clone, PartialEq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct JoinArgs {
    /// The kind of join to perform.
    pub how: JoinType,
    /// Key-uniqueness validation to apply to the join keys.
    pub validation: JoinValidation,
    /// Optional suffix override; `suffix()` falls back to `"_right"` when unset.
    // NOTE(review): presumably appended to clashing right-hand column names — confirm.
    pub suffix: Option<PlSmallStr>,
    /// Optional slice of the join result.
    // NOTE(review): presumably `(offset, length)` — confirm against callers.
    pub slice: Option<(i64, usize)>,
    /// Forwarded to the validation checks, where it controls whether null keys
    /// count towards the uniqueness test.
    // NOTE(review): presumably also makes null keys match each other in the join itself — confirm.
    pub nulls_equal: bool,
    /// Column-coalescing policy; resolved against `how` by `should_coalesce`.
    pub coalesce: JoinCoalesce,
    /// Requested row-order guarantee for the join output.
    pub maintain_order: MaintainOrderJoin,
    /// Optional build-side preference; see [`JoinBuildSide`].
    pub build_side: Option<JoinBuildSide>,
}
52
53impl JoinArgs {
54    pub fn should_coalesce(&self) -> bool {
55        self.coalesce.coalesce(&self.how)
56    }
57}
58
/// The kind of join to perform.
///
/// `Inner` is the default. Several variants only exist when the matching cargo
/// feature (`asof_join`, `semi_anti_join`, `iejoin`) is enabled.
#[derive(Clone, PartialEq, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinType {
    /// Inner join (the default).
    #[default]
    Inner,
    /// Left join.
    Left,
    /// Right join.
    Right,
    /// Full join.
    Full,
    // Box is okay because this is inside a `Arc<JoinOptionsIR>`
    /// As-of join, carrying its boxed options.
    #[cfg(feature = "asof_join")]
    AsOf(Box<AsOfOptions>),
    /// Semi join.
    #[cfg(feature = "semi_anti_join")]
    Semi,
    /// Anti join.
    #[cfg(feature = "semi_anti_join")]
    Anti,
    #[cfg(feature = "iejoin")]
    /// Inequality join with two arbitrary predicates
    // Options are set by optimizer/planner in Options
    IEJoin,
    #[cfg(feature = "iejoin")]
    /// Inequality join with col ∈ [lo, hi] predicate
    // Options are set by optimizer/planner in Options
    Range,
    /// Cross join.
    // Options are set by optimizer/planner in Options
    Cross,
}
86
/// Policy deciding whether a join coalesces columns.
///
/// The policy is resolved against a concrete [`JoinType`] by
/// `JoinCoalesce::coalesce`.
// NOTE(review): presumably the columns being coalesced are the join key columns — confirm.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinCoalesce {
    /// Let the join type decide (the default): resolves to coalescing for
    /// inner, left, right and as-of joins, and to not coalescing otherwise.
    #[default]
    JoinSpecific,
    /// Request coalescing: additionally coalesces full joins. Cross,
    /// inequality/range and semi/anti joins still never coalesce.
    CoalesceColumns,
    /// Never coalesce.
    KeepColumns,
}
96
97impl JoinCoalesce {
98    pub fn coalesce(&self, join_type: &JoinType) -> bool {
99        use JoinCoalesce::*;
100        use JoinType::*;
101        match join_type {
102            Left | Inner | Right => {
103                matches!(self, JoinSpecific | CoalesceColumns)
104            },
105            Full => {
106                matches!(self, CoalesceColumns)
107            },
108            #[cfg(feature = "asof_join")]
109            AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
110            #[cfg(feature = "iejoin")]
111            IEJoin | Range => false,
112            Cross => false,
113            #[cfg(feature = "semi_anti_join")]
114            Semi | Anti => false,
115        }
116    }
117}
118
/// Row-order guarantee requested for the output of a join.
///
/// `IntoStaticStr` is derived with `serialize_all = "snake_case"`, so the
/// static names are the snake_case forms of the variant names.
// NOTE(review): variant meanings below are inferred from the names — confirm against the join implementation.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "snake_case")]
pub enum MaintainOrderJoin {
    /// No ordering requested (the default).
    #[default]
    None,
    /// Presumably: preserve the order of the left input.
    Left,
    /// Presumably: preserve the order of the right input.
    Right,
    /// Presumably: order by the left input first, then the right.
    LeftRight,
    /// Presumably: order by the right input first, then the left.
    RightLeft,
}
131
132impl MaintainOrderJoin {
133    pub(super) fn flip(&self) -> Self {
134        match self {
135            MaintainOrderJoin::None => MaintainOrderJoin::None,
136            MaintainOrderJoin::Left => MaintainOrderJoin::Right,
137            MaintainOrderJoin::Right => MaintainOrderJoin::Left,
138            MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
139            MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
140        }
141    }
142}
143
144impl JoinArgs {
145    pub fn new(how: JoinType) -> Self {
146        Self {
147            how,
148            validation: Default::default(),
149            suffix: None,
150            slice: None,
151            nulls_equal: false,
152            coalesce: Default::default(),
153            maintain_order: Default::default(),
154            build_side: None,
155        }
156    }
157
158    pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
159        self.coalesce = coalesce;
160        self
161    }
162
163    pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
164        self.suffix = suffix;
165        self
166    }
167
168    pub fn with_build_side(mut self, build_side: Option<JoinBuildSide>) -> Self {
169        self.build_side = build_side;
170        self
171    }
172
173    pub fn suffix(&self) -> &PlSmallStr {
174        const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
175        self.suffix.as_ref().unwrap_or(DEFAULT)
176    }
177}
178
179impl From<JoinType> for JoinArgs {
180    fn from(value: JoinType) -> Self {
181        JoinArgs::new(value)
182    }
183}
184
/// A thread-safe (`Send + Sync`) filter applied to a `DataFrame`, used as the
/// predicate of [`CrossJoinOptions`].
pub trait CrossJoinFilter: Send + Sync {
    /// Consumes `df` and returns the filtered frame, or an error.
    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
}
188
189impl<T> CrossJoinFilter for T
190where
191    T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
192{
193    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
194        self(df)
195    }
196}
197
/// Options for a cross join: a shared, dynamically dispatched filter.
///
/// Cloning only clones the `Arc`, so clones share the same predicate object.
#[derive(Clone)]
pub struct CrossJoinOptions {
    /// The filter held by these options.
    // NOTE(review): presumably applied to the cross-join output — confirm against the executor.
    pub predicate: Arc<dyn CrossJoinFilter>,
}
202
203impl CrossJoinOptions {
204    fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
205        Arc::as_ptr(&self.predicate)
206    }
207}
208
209impl Eq for CrossJoinOptions {}
210
211impl PartialEq for CrossJoinOptions {
212    fn eq(&self, other: &Self) -> bool {
213        std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
214    }
215}
216
217impl Hash for CrossJoinOptions {
218    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
219        self.as_ptr_ref().hash(state);
220    }
221}
222
223impl Debug for CrossJoinOptions {
224    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
225        write!(f, "CrossJoinOptions",)
226    }
227}
228
/// Extra options attached to specific join types.
///
/// `IntoStaticStr` is derived with `serialize_all = "snake_case"`, so the
/// static names are the snake_case forms of the variant names.
#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
#[strum(serialize_all = "snake_case")]
pub enum JoinTypeOptions {
    /// Options for inequality joins (only with the `iejoin` feature).
    #[cfg(feature = "iejoin")]
    IEJoin(IEJoinOptions),
    /// Options for cross joins.
    Cross(CrossJoinOptions),
}
236
237impl Display for JoinType {
238    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
239        use JoinType::*;
240        let val = match self {
241            Left => "LEFT",
242            Right => "RIGHT",
243            Inner => "INNER",
244            Full => "FULL",
245            #[cfg(feature = "asof_join")]
246            AsOf(_) => "ASOF",
247            #[cfg(feature = "iejoin")]
248            IEJoin => "IEJOIN",
249            #[cfg(feature = "iejoin")]
250            Range => "RANGE",
251            Cross => "CROSS",
252            #[cfg(feature = "semi_anti_join")]
253            Semi => "SEMI",
254            #[cfg(feature = "semi_anti_join")]
255            Anti => "ANTI",
256        };
257        write!(f, "{val}")
258    }
259}
260
261impl Debug for JoinType {
262    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263        write!(f, "{self}")
264    }
265}
266
267impl JoinType {
268    pub fn is_equi(&self) -> bool {
269        matches!(
270            self,
271            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
272        )
273    }
274
275    pub fn is_semi_anti(&self) -> bool {
276        #[cfg(feature = "semi_anti_join")]
277        {
278            matches!(self, JoinType::Semi | JoinType::Anti)
279        }
280        #[cfg(not(feature = "semi_anti_join"))]
281        {
282            false
283        }
284    }
285
286    pub fn is_semi(&self) -> bool {
287        #[cfg(feature = "semi_anti_join")]
288        {
289            matches!(self, JoinType::Semi)
290        }
291        #[cfg(not(feature = "semi_anti_join"))]
292        {
293            false
294        }
295    }
296
297    pub fn is_anti(&self) -> bool {
298        #[cfg(feature = "semi_anti_join")]
299        {
300            matches!(self, JoinType::Anti)
301        }
302        #[cfg(not(feature = "semi_anti_join"))]
303        {
304            false
305        }
306    }
307
308    pub fn is_asof(&self) -> bool {
309        #[cfg(feature = "asof_join")]
310        {
311            matches!(self, JoinType::AsOf(_))
312        }
313        #[cfg(not(feature = "asof_join"))]
314        {
315            false
316        }
317    }
318
319    pub fn is_cross(&self) -> bool {
320        matches!(self, JoinType::Cross)
321    }
322
323    pub fn is_ie(&self) -> bool {
324        #[cfg(feature = "iejoin")]
325        {
326            matches!(self, JoinType::IEJoin)
327        }
328        #[cfg(not(feature = "iejoin"))]
329        {
330            false
331        }
332    }
333
334    pub fn is_range(&self) -> bool {
335        #[cfg(feature = "iejoin")]
336        {
337            matches!(self, JoinType::Range)
338        }
339        #[cfg(not(feature = "iejoin"))]
340        {
341            false
342        }
343    }
344}
345
/// Uniqueness requirement on the join keys, written as `left:right`
/// cardinality.
///
/// `Display` renders the variants as `m:m`, `m:1`, `1:m` and `1:1`.
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinValidation {
    /// No unique checks
    #[default]
    ManyToMany,
    /// Check if join keys are unique in right dataset.
    ManyToOne,
    /// Check if join keys are unique in left dataset.
    OneToMany,
    /// Check if join keys are unique in both left and right datasets
    OneToOne,
}
360
361impl JoinValidation {
362    pub fn needs_checks(&self) -> bool {
363        !matches!(self, JoinValidation::ManyToMany)
364    }
365
366    fn swap(self, swap: bool) -> Self {
367        use JoinValidation::*;
368        if swap {
369            match self {
370                ManyToMany => ManyToMany,
371                ManyToOne => OneToMany,
372                OneToMany => ManyToOne,
373                OneToOne => OneToOne,
374            }
375        } else {
376            self
377        }
378    }
379
380    pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
381        if !self.needs_checks() {
382            return Ok(());
383        }
384        polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
385                      ComputeError: "{self} validation on a {join_type} join is not supported");
386        Ok(())
387    }
388
389    pub(super) fn validate_probe(
390        &self,
391        s_left: &Series,
392        s_right: &Series,
393        build_shortest_table: bool,
394        nulls_equal: bool,
395    ) -> PolarsResult<()> {
396        // In default, probe is the left series.
397        //
398        // In inner join and outer join, the shortest relation will be used to create a hash table.
399        // In left join, always use the right side to create.
400        //
401        // If `build_shortest_table` and left is shorter, swap. Then rhs will be the probe.
402        // If left == right, swap too. (apply the same logic as `det_hash_prone_order`)
403        let should_swap = build_shortest_table && s_left.len() <= s_right.len();
404        let probe = if should_swap { s_right } else { s_left };
405
406        use JoinValidation::*;
407        let valid = match self.swap(should_swap) {
408            // Only check the `build` side.
409            // The other side use `validate_build` to check
410            ManyToMany | ManyToOne => true,
411            OneToMany | OneToOne => {
412                if !nulls_equal && probe.null_count() > 0 {
413                    probe.n_unique()? - 1 == probe.len() - probe.null_count()
414                } else {
415                    probe.n_unique()? == probe.len()
416                }
417            },
418        };
419        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
420        Ok(())
421    }
422
423    pub(super) fn validate_build(
424        &self,
425        build_size: usize,
426        expected_size: usize,
427        swapped: bool,
428    ) -> PolarsResult<()> {
429        use JoinValidation::*;
430
431        // In default, build is in rhs.
432        let valid = match self.swap(swapped) {
433            // Only check the `build` side.
434            // The other side use `validate_prone` to check
435            ManyToMany | OneToMany => true,
436            ManyToOne | OneToOne => build_size == expected_size,
437        };
438        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
439        Ok(())
440    }
441}
442
443impl Display for JoinValidation {
444    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
445        let s = match self {
446            JoinValidation::ManyToMany => "m:m",
447            JoinValidation::ManyToOne => "m:1",
448            JoinValidation::OneToMany => "1:m",
449            JoinValidation::OneToOne => "1:1",
450        };
451        write!(f, "{s}")
452    }
453}
454
455impl Debug for JoinValidation {
456    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
457        write!(f, "JoinValidation: {self}")
458    }
459}