polars_ops/frame/join/
args.rs

1use super::*;
2
3pub(super) type JoinIds = Vec<IdxSize>;
4pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
5pub type InnerJoinIds = (JoinIds, JoinIds);
6
7#[cfg(feature = "chunked_ids")]
8pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
9#[cfg(feature = "chunked_ids")]
10pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;
11
12#[cfg(not(feature = "chunked_ids"))]
13pub type ChunkJoinOptIds = Vec<NullableIdxSize>;
14
15#[cfg(not(feature = "chunked_ids"))]
16pub type ChunkJoinIds = Vec<IdxSize>;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use strum_macros::IntoStaticStr;
21
22#[derive(Clone, PartialEq, Debug, Hash, Default)]
23#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
24#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
25pub struct JoinArgs {
26    pub how: JoinType,
27    pub validation: JoinValidation,
28    pub suffix: Option<PlSmallStr>,
29    pub slice: Option<(i64, usize)>,
30    pub nulls_equal: bool,
31    pub coalesce: JoinCoalesce,
32    pub maintain_order: MaintainOrderJoin,
33}
34
35impl JoinArgs {
36    pub fn should_coalesce(&self) -> bool {
37        self.coalesce.coalesce(&self.how)
38    }
39}
40
41#[derive(Clone, PartialEq, Hash, Default, IntoStaticStr)]
42#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
43#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
44pub enum JoinType {
45    #[default]
46    Inner,
47    Left,
48    Right,
49    Full,
50    // Box is okay because this is inside a `Arc<JoinOptionsIR>`
51    #[cfg(feature = "asof_join")]
52    AsOf(Box<AsOfOptions>),
53    #[cfg(feature = "semi_anti_join")]
54    Semi,
55    #[cfg(feature = "semi_anti_join")]
56    Anti,
57    #[cfg(feature = "iejoin")]
58    // Options are set by optimizer/planner in Options
59    IEJoin,
60    // Options are set by optimizer/planner in Options
61    Cross,
62}
63
// Policy for coalescing columns after a join. `JoinSpecific` (the default)
// defers the decision to the join type; the other two variants force it.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinCoalesce {
    #[default]
    JoinSpecific,
    CoalesceColumns,
    KeepColumns,
}
73
74impl JoinCoalesce {
75    pub fn coalesce(&self, join_type: &JoinType) -> bool {
76        use JoinCoalesce::*;
77        use JoinType::*;
78        match join_type {
79            Left | Inner | Right => {
80                matches!(self, JoinSpecific | CoalesceColumns)
81            },
82            Full => {
83                matches!(self, CoalesceColumns)
84            },
85            #[cfg(feature = "asof_join")]
86            AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
87            #[cfg(feature = "iejoin")]
88            IEJoin => false,
89            Cross => false,
90            #[cfg(feature = "semi_anti_join")]
91            Semi | Anti => false,
92        }
93    }
94}
95
96#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
97#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
98#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
99#[strum(serialize_all = "snake_case")]
100pub enum MaintainOrderJoin {
101    #[default]
102    None,
103    Left,
104    Right,
105    LeftRight,
106    RightLeft,
107}
108
109impl MaintainOrderJoin {
110    pub(super) fn flip(&self) -> Self {
111        match self {
112            MaintainOrderJoin::None => MaintainOrderJoin::None,
113            MaintainOrderJoin::Left => MaintainOrderJoin::Right,
114            MaintainOrderJoin::Right => MaintainOrderJoin::Left,
115            MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
116            MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
117        }
118    }
119}
120
121impl JoinArgs {
122    pub fn new(how: JoinType) -> Self {
123        Self {
124            how,
125            validation: Default::default(),
126            suffix: None,
127            slice: None,
128            nulls_equal: false,
129            coalesce: Default::default(),
130            maintain_order: Default::default(),
131        }
132    }
133
134    pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
135        self.coalesce = coalesce;
136        self
137    }
138
139    pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
140        self.suffix = suffix;
141        self
142    }
143
144    pub fn suffix(&self) -> &PlSmallStr {
145        const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
146        self.suffix.as_ref().unwrap_or(DEFAULT)
147    }
148}
149
150impl From<JoinType> for JoinArgs {
151    fn from(value: JoinType) -> Self {
152        JoinArgs::new(value)
153    }
154}
155
156pub trait CrossJoinFilter: Send + Sync {
157    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
158}
159
160impl<T> CrossJoinFilter for T
161where
162    T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
163{
164    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
165        self(df)
166    }
167}
168
169#[derive(Clone)]
170pub struct CrossJoinOptions {
171    pub predicate: Arc<dyn CrossJoinFilter>,
172}
173
174impl CrossJoinOptions {
175    fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
176        Arc::as_ptr(&self.predicate)
177    }
178}
179
180impl Eq for CrossJoinOptions {}
181
182impl PartialEq for CrossJoinOptions {
183    fn eq(&self, other: &Self) -> bool {
184        std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
185    }
186}
187
188impl Hash for CrossJoinOptions {
189    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
190        self.as_ptr_ref().hash(state);
191    }
192}
193
194impl Debug for CrossJoinOptions {
195    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
196        write!(f, "CrossJoinOptions",)
197    }
198}
199
200#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
201#[strum(serialize_all = "snake_case")]
202pub enum JoinTypeOptions {
203    #[cfg(feature = "iejoin")]
204    IEJoin(IEJoinOptions),
205    Cross(CrossJoinOptions),
206}
207
208impl Display for JoinType {
209    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
210        use JoinType::*;
211        let val = match self {
212            Left => "LEFT",
213            Right => "RIGHT",
214            Inner => "INNER",
215            Full => "FULL",
216            #[cfg(feature = "asof_join")]
217            AsOf(_) => "ASOF",
218            #[cfg(feature = "iejoin")]
219            IEJoin => "IEJOIN",
220            Cross => "CROSS",
221            #[cfg(feature = "semi_anti_join")]
222            Semi => "SEMI",
223            #[cfg(feature = "semi_anti_join")]
224            Anti => "ANTI",
225        };
226        write!(f, "{val}")
227    }
228}
229
230impl Debug for JoinType {
231    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
232        write!(f, "{self}")
233    }
234}
235
236impl JoinType {
237    pub fn is_equi(&self) -> bool {
238        matches!(
239            self,
240            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
241        )
242    }
243
244    pub fn is_semi_anti(&self) -> bool {
245        #[cfg(feature = "semi_anti_join")]
246        {
247            matches!(self, JoinType::Semi | JoinType::Anti)
248        }
249        #[cfg(not(feature = "semi_anti_join"))]
250        {
251            false
252        }
253    }
254
255    pub fn is_semi(&self) -> bool {
256        #[cfg(feature = "semi_anti_join")]
257        {
258            matches!(self, JoinType::Semi)
259        }
260        #[cfg(not(feature = "semi_anti_join"))]
261        {
262            false
263        }
264    }
265
266    pub fn is_anti(&self) -> bool {
267        #[cfg(feature = "semi_anti_join")]
268        {
269            matches!(self, JoinType::Anti)
270        }
271        #[cfg(not(feature = "semi_anti_join"))]
272        {
273            false
274        }
275    }
276
277    pub fn is_asof(&self) -> bool {
278        #[cfg(feature = "asof_join")]
279        {
280            matches!(self, JoinType::AsOf(_))
281        }
282        #[cfg(not(feature = "asof_join"))]
283        {
284            false
285        }
286    }
287
288    pub fn is_cross(&self) -> bool {
289        matches!(self, JoinType::Cross)
290    }
291
292    pub fn is_ie(&self) -> bool {
293        #[cfg(feature = "iejoin")]
294        {
295            matches!(self, JoinType::IEJoin)
296        }
297        #[cfg(not(feature = "iejoin"))]
298        {
299            false
300        }
301    }
302}
303
// Uniqueness requirement placed on the join keys of each side. `Debug` and
// `Display` are implemented by hand elsewhere in this file.
#[derive(Clone, Copy, Default, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinValidation {
    /// No unique checks
    #[default]
    ManyToMany,
    /// Check if join keys are unique in right dataset.
    ManyToOne,
    /// Check if join keys are unique in left dataset.
    OneToMany,
    /// Check if join keys are unique in both left and right datasets
    OneToOne,
}
318
319impl JoinValidation {
320    pub fn needs_checks(&self) -> bool {
321        !matches!(self, JoinValidation::ManyToMany)
322    }
323
324    fn swap(self, swap: bool) -> Self {
325        use JoinValidation::*;
326        if swap {
327            match self {
328                ManyToMany => ManyToMany,
329                ManyToOne => OneToMany,
330                OneToMany => ManyToOne,
331                OneToOne => OneToOne,
332            }
333        } else {
334            self
335        }
336    }
337
338    pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
339        if !self.needs_checks() {
340            return Ok(());
341        }
342        polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
343                      ComputeError: "{self} validation on a {join_type} join is not supported");
344        Ok(())
345    }
346
347    pub(super) fn validate_probe(
348        &self,
349        s_left: &Series,
350        s_right: &Series,
351        build_shortest_table: bool,
352        nulls_equal: bool,
353    ) -> PolarsResult<()> {
354        // In default, probe is the left series.
355        //
356        // In inner join and outer join, the shortest relation will be used to create a hash table.
357        // In left join, always use the right side to create.
358        //
359        // If `build_shortest_table` and left is shorter, swap. Then rhs will be the probe.
360        // If left == right, swap too. (apply the same logic as `det_hash_prone_order`)
361        let should_swap = build_shortest_table && s_left.len() <= s_right.len();
362        let probe = if should_swap { s_right } else { s_left };
363
364        use JoinValidation::*;
365        let valid = match self.swap(should_swap) {
366            // Only check the `build` side.
367            // The other side use `validate_build` to check
368            ManyToMany | ManyToOne => true,
369            OneToMany | OneToOne => {
370                if !nulls_equal && probe.null_count() > 0 {
371                    probe.n_unique()? - 1 == probe.len() - probe.null_count()
372                } else {
373                    probe.n_unique()? == probe.len()
374                }
375            },
376        };
377        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
378        Ok(())
379    }
380
381    pub(super) fn validate_build(
382        &self,
383        build_size: usize,
384        expected_size: usize,
385        swapped: bool,
386    ) -> PolarsResult<()> {
387        use JoinValidation::*;
388
389        // In default, build is in rhs.
390        let valid = match self.swap(swapped) {
391            // Only check the `build` side.
392            // The other side use `validate_prone` to check
393            ManyToMany | OneToMany => true,
394            ManyToOne | OneToOne => build_size == expected_size,
395        };
396        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
397        Ok(())
398    }
399}
400
401impl Display for JoinValidation {
402    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
403        let s = match self {
404            JoinValidation::ManyToMany => "m:m",
405            JoinValidation::ManyToOne => "m:1",
406            JoinValidation::OneToMany => "1:m",
407            JoinValidation::OneToOne => "1:1",
408        };
409        write!(f, "{s}")
410    }
411}
412
413impl Debug for JoinValidation {
414    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
415        write!(f, "JoinValidation: {self}")
416    }
417}