polars_ops/frame/join/
args.rs

1use super::*;
2
// Index buffers exchanged between the join implementations.
// NOTE(review): what each index addresses (row offset vs. chunk-local position) is
// decided by the code that produces these buffers — confirm there.

/// Flat buffer of `IdxSize` join indices.
pub(super) type JoinIds = Vec<IdxSize>;
/// Index pair for a left join: a `ChunkJoinIds` buffer followed by a `ChunkJoinOptIds`
/// buffer (presumably left-side indices and nullable right-side indices — confirm
/// against the callers).
pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
/// Index pair for an inner join: two `JoinIds` buffers.
pub type InnerJoinIds = (JoinIds, JoinIds);

/// With the `chunked_ids` feature: either plain `IdxSize` indices or `ChunkId`s.
#[cfg(feature = "chunked_ids")]
pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
/// With the `chunked_ids` feature: either nullable indices or `ChunkId`s.
#[cfg(feature = "chunked_ids")]
pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;

/// Without the `chunked_ids` feature: always a flat buffer of nullable indices.
#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinOptIds = Vec<NullableIdxSize>;

/// Without the `chunked_ids` feature: always a flat buffer of `IdxSize` indices.
#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinIds = Vec<IdxSize>;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use strum_macros::IntoStaticStr;
21
/// Options describing how a join is performed.
///
/// Every field has a default, so `JoinArgs::default()` is an inner join with no
/// validation, no custom suffix and no slice.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct JoinArgs {
    /// Kind of join to perform.
    pub how: JoinType,
    /// Key-uniqueness validation requested for the join.
    pub validation: JoinValidation,
    /// Custom suffix; when `None`, [`JoinArgs::suffix`] falls back to `"_right"`.
    pub suffix: Option<PlSmallStr>,
    /// Optional slice of the join output.
    /// NOTE(review): presumably `(offset, length)` — confirm against the join kernels.
    pub slice: Option<(i64, usize)>,
    /// NOTE(review): presumably whether null join keys match each other — confirm
    /// against the join kernels.
    pub nulls_equal: bool,
    /// Column-coalescing policy; evaluated by [`JoinArgs::should_coalesce`].
    pub coalesce: JoinCoalesce,
    /// Requested order-maintenance mode for the join.
    pub maintain_order: MaintainOrderJoin,
}
33
34impl JoinArgs {
35    pub fn should_coalesce(&self) -> bool {
36        self.coalesce.coalesce(&self.how)
37    }
38}
39
/// The kind of join to perform.
///
/// Variants guarded by a `cfg` attribute only exist when the matching crate feature
/// is enabled. `Debug` is implemented by hand further down in this file (it is not
/// derived here) and prints the same upper-case name as `Display`.
#[derive(Clone, PartialEq, Eq, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinType {
    /// Inner join; this is the default join type.
    #[default]
    Inner,
    /// Left join.
    Left,
    /// Right join.
    Right,
    /// Full join.
    Full,
    /// As-of join, configured by the carried `AsOfOptions`.
    #[cfg(feature = "asof_join")]
    AsOf(AsOfOptions),
    /// Semi join.
    #[cfg(feature = "semi_anti_join")]
    Semi,
    /// Anti join.
    #[cfg(feature = "semi_anti_join")]
    Anti,
    #[cfg(feature = "iejoin")]
    // Options are set by optimizer/planner in Options
    IEJoin,
    // Options are set by optimizer/planner in Options
    Cross,
}
60
/// Policy deciding whether a join coalesces columns.
///
/// The policy is resolved against a concrete [`JoinType`] by [`JoinCoalesce::coalesce`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinCoalesce {
    /// Default: let the join type decide. Inner, left, right and as-of joins
    /// coalesce; every other join type does not.
    #[default]
    JoinSpecific,
    /// Coalesce for inner, left, right, full and as-of joins. Cross, IE, semi and
    /// anti joins still never coalesce.
    CoalesceColumns,
    /// Never coalesce.
    KeepColumns,
}
69
70impl JoinCoalesce {
71    pub fn coalesce(&self, join_type: &JoinType) -> bool {
72        use JoinCoalesce::*;
73        use JoinType::*;
74        match join_type {
75            Left | Inner | Right => {
76                matches!(self, JoinSpecific | CoalesceColumns)
77            },
78            Full => {
79                matches!(self, CoalesceColumns)
80            },
81            #[cfg(feature = "asof_join")]
82            AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
83            #[cfg(feature = "iejoin")]
84            IEJoin => false,
85            Cross => false,
86            #[cfg(feature = "semi_anti_join")]
87            Semi | Anti => false,
88        }
89    }
90}
91
/// Order-maintenance mode of a join.
///
/// Converts to a static string in `snake_case` through `strum`'s `IntoStaticStr`.
/// NOTE(review): the variant names suggest which input's row order the output must
/// follow (and, for the two-sided variants, which side takes precedence) — confirm
/// against the join implementations.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
pub enum MaintainOrderJoin {
    /// Default mode.
    #[default]
    None,
    /// Mirror image of `Right` under `flip`.
    Left,
    /// Mirror image of `Left` under `flip`.
    Right,
    /// Mirror image of `RightLeft` under `flip`.
    LeftRight,
    /// Mirror image of `LeftRight` under `flip`.
    RightLeft,
}
103
104impl MaintainOrderJoin {
105    pub(super) fn flip(&self) -> Self {
106        match self {
107            MaintainOrderJoin::None => MaintainOrderJoin::None,
108            MaintainOrderJoin::Left => MaintainOrderJoin::Right,
109            MaintainOrderJoin::Right => MaintainOrderJoin::Left,
110            MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
111            MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
112        }
113    }
114}
115
116impl JoinArgs {
117    pub fn new(how: JoinType) -> Self {
118        Self {
119            how,
120            validation: Default::default(),
121            suffix: None,
122            slice: None,
123            nulls_equal: false,
124            coalesce: Default::default(),
125            maintain_order: Default::default(),
126        }
127    }
128
129    pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
130        self.coalesce = coalesce;
131        self
132    }
133
134    pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
135        self.suffix = suffix;
136        self
137    }
138
139    pub fn suffix(&self) -> &PlSmallStr {
140        const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
141        self.suffix.as_ref().unwrap_or(DEFAULT)
142    }
143}
144
145impl From<JoinType> for JoinArgs {
146    fn from(value: JoinType) -> Self {
147        JoinArgs::new(value)
148    }
149}
150
151pub trait CrossJoinFilter: Send + Sync {
152    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
153}
154
155impl<T> CrossJoinFilter for T
156where
157    T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
158{
159    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
160        self(df)
161    }
162}
163
164#[derive(Clone)]
165pub struct CrossJoinOptions {
166    pub predicate: Arc<dyn CrossJoinFilter>,
167}
168
169impl CrossJoinOptions {
170    fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
171        Arc::as_ptr(&self.predicate)
172    }
173}
174
175impl Eq for CrossJoinOptions {}
176
177impl PartialEq for CrossJoinOptions {
178    fn eq(&self, other: &Self) -> bool {
179        std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
180    }
181}
182
183impl Hash for CrossJoinOptions {
184    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
185        self.as_ptr_ref().hash(state);
186    }
187}
188
189impl Debug for CrossJoinOptions {
190    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
191        write!(f, "CrossJoinOptions",)
192    }
193}
194
/// Extra, join-type-specific options.
///
/// Converts to a static string in `snake_case` through `strum`'s `IntoStaticStr`.
#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
pub enum JoinTypeOptions {
    /// Options for an IE join; only available with the `iejoin` feature.
    #[cfg(feature = "iejoin")]
    IEJoin(IEJoinOptions),
    /// Options for a cross join. Skipped by serde: the predicate is a type-erased
    /// trait object behind an `Arc`.
    #[cfg_attr(feature = "serde", serde(skip))]
    Cross(CrossJoinOptions),
}
204
205impl Display for JoinType {
206    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
207        use JoinType::*;
208        let val = match self {
209            Left => "LEFT",
210            Right => "RIGHT",
211            Inner => "INNER",
212            Full => "FULL",
213            #[cfg(feature = "asof_join")]
214            AsOf(_) => "ASOF",
215            #[cfg(feature = "iejoin")]
216            IEJoin => "IEJOIN",
217            Cross => "CROSS",
218            #[cfg(feature = "semi_anti_join")]
219            Semi => "SEMI",
220            #[cfg(feature = "semi_anti_join")]
221            Anti => "ANTI",
222        };
223        write!(f, "{val}")
224    }
225}
226
227impl Debug for JoinType {
228    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
229        write!(f, "{self}")
230    }
231}
232
233impl JoinType {
234    pub fn is_equi(&self) -> bool {
235        matches!(
236            self,
237            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
238        )
239    }
240
241    pub fn is_semi_anti(&self) -> bool {
242        #[cfg(feature = "semi_anti_join")]
243        {
244            matches!(self, JoinType::Semi | JoinType::Anti)
245        }
246        #[cfg(not(feature = "semi_anti_join"))]
247        {
248            false
249        }
250    }
251
252    pub fn is_asof(&self) -> bool {
253        #[cfg(feature = "asof_join")]
254        {
255            matches!(self, JoinType::AsOf(_))
256        }
257        #[cfg(not(feature = "asof_join"))]
258        {
259            false
260        }
261    }
262
263    pub fn is_cross(&self) -> bool {
264        matches!(self, JoinType::Cross)
265    }
266
267    pub fn is_ie(&self) -> bool {
268        #[cfg(feature = "iejoin")]
269        {
270            matches!(self, JoinType::IEJoin)
271        }
272        #[cfg(not(feature = "iejoin"))]
273        {
274            false
275        }
276    }
277}
278
/// Uniqueness requirement on the join keys.
///
/// `Display` renders the variants as `m:m`, `m:1`, `1:m` and `1:1`. `Debug` is
/// implemented by hand further down in this file rather than derived.
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinValidation {
    /// No unique checks.
    #[default]
    ManyToMany,
    /// Check if join keys are unique in right dataset.
    ManyToOne,
    /// Check if join keys are unique in left dataset.
    OneToMany,
    /// Check if join keys are unique in both left and right datasets.
    OneToOne,
}
292
293impl JoinValidation {
294    pub fn needs_checks(&self) -> bool {
295        !matches!(self, JoinValidation::ManyToMany)
296    }
297
298    fn swap(self, swap: bool) -> Self {
299        use JoinValidation::*;
300        if swap {
301            match self {
302                ManyToMany => ManyToMany,
303                ManyToOne => OneToMany,
304                OneToMany => ManyToOne,
305                OneToOne => OneToOne,
306            }
307        } else {
308            self
309        }
310    }
311
312    pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
313        if !self.needs_checks() {
314            return Ok(());
315        }
316        polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
317                      ComputeError: "{self} validation on a {join_type} join is not supported");
318        Ok(())
319    }
320
321    pub(super) fn validate_probe(
322        &self,
323        s_left: &Series,
324        s_right: &Series,
325        build_shortest_table: bool,
326        nulls_equal: bool,
327    ) -> PolarsResult<()> {
328        // In default, probe is the left series.
329        //
330        // In inner join and outer join, the shortest relation will be used to create a hash table.
331        // In left join, always use the right side to create.
332        //
333        // If `build_shortest_table` and left is shorter, swap. Then rhs will be the probe.
334        // If left == right, swap too. (apply the same logic as `det_hash_prone_order`)
335        let should_swap = build_shortest_table && s_left.len() <= s_right.len();
336        let probe = if should_swap { s_right } else { s_left };
337
338        use JoinValidation::*;
339        let valid = match self.swap(should_swap) {
340            // Only check the `build` side.
341            // The other side use `validate_build` to check
342            ManyToMany | ManyToOne => true,
343            OneToMany | OneToOne => {
344                if !nulls_equal && probe.null_count() > 0 {
345                    probe.n_unique()? - 1 == probe.len() - probe.null_count()
346                } else {
347                    probe.n_unique()? == probe.len()
348                }
349            },
350        };
351        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
352        Ok(())
353    }
354
355    pub(super) fn validate_build(
356        &self,
357        build_size: usize,
358        expected_size: usize,
359        swapped: bool,
360    ) -> PolarsResult<()> {
361        use JoinValidation::*;
362
363        // In default, build is in rhs.
364        let valid = match self.swap(swapped) {
365            // Only check the `build` side.
366            // The other side use `validate_prone` to check
367            ManyToMany | OneToMany => true,
368            ManyToOne | OneToOne => build_size == expected_size,
369        };
370        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
371        Ok(())
372    }
373}
374
375impl Display for JoinValidation {
376    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
377        let s = match self {
378            JoinValidation::ManyToMany => "m:m",
379            JoinValidation::ManyToOne => "m:1",
380            JoinValidation::OneToMany => "1:m",
381            JoinValidation::OneToOne => "1:1",
382        };
383        write!(f, "{s}")
384    }
385}
386
387impl Debug for JoinValidation {
388    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
389        write!(f, "JoinValidation: {self}")
390    }
391}