1use super::*;
2
3pub(super) type JoinIds = Vec<IdxSize>;
4pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
5pub type InnerJoinIds = (JoinIds, JoinIds);
6
7#[cfg(feature = "chunked_ids")]
8pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
9#[cfg(feature = "chunked_ids")]
10pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;
11
12#[cfg(not(feature = "chunked_ids"))]
13pub type ChunkJoinOptIds = Vec<NullableIdxSize>;
14
15#[cfg(not(feature = "chunked_ids"))]
16pub type ChunkJoinIds = Vec<IdxSize>;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use strum_macros::IntoStaticStr;
21
/// Options controlling how two frames are joined.
///
/// Build with [`JoinArgs::new`] (or `From<JoinType>`) and the `with_*` setters.
#[derive(Clone, PartialEq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct JoinArgs {
    /// The kind of join to perform.
    pub how: JoinType,
    /// Uniqueness requirement on the join keys; `ManyToMany` (the default) disables checking.
    pub validation: JoinValidation,
    /// Column-name suffix; `None` means [`JoinArgs::suffix`] falls back to `"_right"`.
    pub suffix: Option<PlSmallStr>,
    /// Optional slice of the join output.
    /// NOTE(review): presumably `(offset, length)` with a possibly negative offset; confirm.
    pub slice: Option<(i64, usize)>,
    /// Whether null keys are treated as equal to each other.
    pub nulls_equal: bool,
    /// Key-column coalescing policy, resolved per join type by [`JoinArgs::should_coalesce`].
    pub coalesce: JoinCoalesce,
    /// Which input's row order the output should maintain, if any.
    pub maintain_order: MaintainOrderJoin,
}
34
35impl JoinArgs {
36 pub fn should_coalesce(&self) -> bool {
37 self.coalesce.coalesce(&self.how)
38 }
39}
40
/// The kind of join to perform.
///
/// `Debug` is implemented by hand and prints the same upper-case name as `Display`
/// (e.g. `INNER`). `IntoStaticStr` yields the plain variant name.
#[derive(Clone, PartialEq, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinType {
    /// Inner join; the default.
    #[default]
    Inner,
    /// Left join.
    Left,
    /// Right join.
    Right,
    /// Full join.
    Full,
    /// As-of join; its options are boxed. Requires the `asof_join` feature.
    #[cfg(feature = "asof_join")]
    AsOf(Box<AsOfOptions>),
    /// Semi join. Requires the `semi_anti_join` feature.
    #[cfg(feature = "semi_anti_join")]
    Semi,
    /// Anti join. Requires the `semi_anti_join` feature.
    #[cfg(feature = "semi_anti_join")]
    Anti,
    /// IE join (presumably inequality-based; see `IEJoinOptions`). Requires `iejoin`.
    #[cfg(feature = "iejoin")]
    IEJoin,
    /// Cross join; available without any feature flag.
    Cross,
}
63
/// Policy deciding whether the key columns of the two join inputs are coalesced.
///
/// The final decision per join type is made by [`JoinCoalesce::coalesce`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinCoalesce {
    /// Let the join type decide: coalesce for inner, left, right and as-of joins,
    /// but not for full joins (the default).
    #[default]
    JoinSpecific,
    /// Coalesce for every join type that supports it, including full joins.
    CoalesceColumns,
    /// Never coalesce; `coalesce` returns `false` for every join type.
    KeepColumns,
}
73
74impl JoinCoalesce {
75 pub fn coalesce(&self, join_type: &JoinType) -> bool {
76 use JoinCoalesce::*;
77 use JoinType::*;
78 match join_type {
79 Left | Inner | Right => {
80 matches!(self, JoinSpecific | CoalesceColumns)
81 },
82 Full => {
83 matches!(self, CoalesceColumns)
84 },
85 #[cfg(feature = "asof_join")]
86 AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
87 #[cfg(feature = "iejoin")]
88 IEJoin => false,
89 Cross => false,
90 #[cfg(feature = "semi_anti_join")]
91 Semi | Anti => false,
92 }
93 }
94}
95
/// Which input's row order, if any, the join output is required to maintain.
///
/// `IntoStaticStr` yields snake_case names (e.g. `"left_right"`).
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "snake_case")]
pub enum MaintainOrderJoin {
    /// No ordering requirement (the default).
    #[default]
    None,
    /// Maintain the order of the left input.
    Left,
    /// Maintain the order of the right input.
    Right,
    /// Maintain left order, presumably with right order as the tie-breaker; confirm.
    LeftRight,
    /// Maintain right order, presumably with left order as the tie-breaker; confirm.
    RightLeft,
}
108
109impl MaintainOrderJoin {
110 pub(super) fn flip(&self) -> Self {
111 match self {
112 MaintainOrderJoin::None => MaintainOrderJoin::None,
113 MaintainOrderJoin::Left => MaintainOrderJoin::Right,
114 MaintainOrderJoin::Right => MaintainOrderJoin::Left,
115 MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
116 MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
117 }
118 }
119}
120
121impl JoinArgs {
122 pub fn new(how: JoinType) -> Self {
123 Self {
124 how,
125 validation: Default::default(),
126 suffix: None,
127 slice: None,
128 nulls_equal: false,
129 coalesce: Default::default(),
130 maintain_order: Default::default(),
131 }
132 }
133
134 pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
135 self.coalesce = coalesce;
136 self
137 }
138
139 pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
140 self.suffix = suffix;
141 self
142 }
143
144 pub fn suffix(&self) -> &PlSmallStr {
145 const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
146 self.suffix.as_ref().unwrap_or(DEFAULT)
147 }
148}
149
150impl From<JoinType> for JoinArgs {
151 fn from(value: JoinType) -> Self {
152 JoinArgs::new(value)
153 }
154}
155
/// A user-supplied filter used by cross joins (held in [`CrossJoinOptions::predicate`]).
///
/// Any `Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync` closure implements this
/// trait through a blanket impl.
pub trait CrossJoinFilter: Send + Sync {
    /// Consumes `df` and returns the resulting frame, or an error.
    fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
}
159
160impl<T> CrossJoinFilter for T
161where
162 T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
163{
164 fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
165 self(df)
166 }
167}
168
/// Options for a cross join: a filter implementing [`CrossJoinFilter`].
///
/// The predicate is shared behind an [`Arc`], so `Clone` only bumps the reference count.
/// `PartialEq`, `Eq` and `Hash` are implemented by hand on the identity (address) of
/// that shared allocation, not on the predicate's behavior.
#[derive(Clone)]
pub struct CrossJoinOptions {
    /// The filter to apply.
    /// NOTE(review): where exactly it is applied (per produced frame/chunk?) is decided
    /// by the cross-join implementation; confirm at the call site.
    pub predicate: Arc<dyn CrossJoinFilter>,
}
173
174impl CrossJoinOptions {
175 fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
176 Arc::as_ptr(&self.predicate)
177 }
178}
179
// `PartialEq` compares predicate addresses, which is reflexive, symmetric and transitive,
// so it is a full equivalence relation and `Eq` can be asserted.
impl Eq for CrossJoinOptions {}
181
182impl PartialEq for CrossJoinOptions {
183 fn eq(&self, other: &Self) -> bool {
184 std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
185 }
186}
187
188impl Hash for CrossJoinOptions {
189 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
190 self.as_ptr_ref().hash(state);
191 }
192}
193
194impl Debug for CrossJoinOptions {
195 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
196 write!(f, "CrossJoinOptions",)
197 }
198}
199
/// Extra, join-type specific options that are not carried by [`JoinType`] itself.
///
/// `IntoStaticStr` yields the snake_case variant name.
#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
#[strum(serialize_all = "snake_case")]
pub enum JoinTypeOptions {
    /// Options for an IE join. Requires the `iejoin` feature.
    #[cfg(feature = "iejoin")]
    IEJoin(IEJoinOptions),
    /// Options for a cross join (its filter predicate).
    Cross(CrossJoinOptions),
}
207
208impl Display for JoinType {
209 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
210 use JoinType::*;
211 let val = match self {
212 Left => "LEFT",
213 Right => "RIGHT",
214 Inner => "INNER",
215 Full => "FULL",
216 #[cfg(feature = "asof_join")]
217 AsOf(_) => "ASOF",
218 #[cfg(feature = "iejoin")]
219 IEJoin => "IEJOIN",
220 Cross => "CROSS",
221 #[cfg(feature = "semi_anti_join")]
222 Semi => "SEMI",
223 #[cfg(feature = "semi_anti_join")]
224 Anti => "ANTI",
225 };
226 write!(f, "{val}")
227 }
228}
229
230impl Debug for JoinType {
231 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
232 write!(f, "{self}")
233 }
234}
235
236impl JoinType {
237 pub fn is_equi(&self) -> bool {
238 matches!(
239 self,
240 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
241 )
242 }
243
244 pub fn is_semi_anti(&self) -> bool {
245 #[cfg(feature = "semi_anti_join")]
246 {
247 matches!(self, JoinType::Semi | JoinType::Anti)
248 }
249 #[cfg(not(feature = "semi_anti_join"))]
250 {
251 false
252 }
253 }
254
255 pub fn is_semi(&self) -> bool {
256 #[cfg(feature = "semi_anti_join")]
257 {
258 matches!(self, JoinType::Semi)
259 }
260 #[cfg(not(feature = "semi_anti_join"))]
261 {
262 false
263 }
264 }
265
266 pub fn is_anti(&self) -> bool {
267 #[cfg(feature = "semi_anti_join")]
268 {
269 matches!(self, JoinType::Anti)
270 }
271 #[cfg(not(feature = "semi_anti_join"))]
272 {
273 false
274 }
275 }
276
277 pub fn is_asof(&self) -> bool {
278 #[cfg(feature = "asof_join")]
279 {
280 matches!(self, JoinType::AsOf(_))
281 }
282 #[cfg(not(feature = "asof_join"))]
283 {
284 false
285 }
286 }
287
288 pub fn is_cross(&self) -> bool {
289 matches!(self, JoinType::Cross)
290 }
291
292 pub fn is_ie(&self) -> bool {
293 #[cfg(feature = "iejoin")]
294 {
295 matches!(self, JoinType::IEJoin)
296 }
297 #[cfg(not(feature = "iejoin"))]
298 {
299 false
300 }
301 }
302}
303
/// Uniqueness requirement checked on the join keys.
///
/// Displayed as `m:m`, `m:1`, `1:m` and `1:1` respectively.
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum JoinValidation {
    /// No uniqueness checks (the default).
    #[default]
    ManyToMany,
    /// Requires the join keys of the right input to be unique.
    ManyToOne,
    /// Requires the join keys of the left input to be unique.
    OneToMany,
    /// Requires the join keys of both inputs to be unique.
    OneToOne,
}
318
impl JoinValidation {
    /// Returns `true` when a uniqueness check is required, i.e. for everything except
    /// `ManyToMany`.
    pub fn needs_checks(&self) -> bool {
        !matches!(self, JoinValidation::ManyToMany)
    }

    /// Re-expresses the validation after the two join sides have been swapped.
    ///
    /// A "one" constraint on the left moves to the right and vice versa (`ManyToOne` <->
    /// `OneToMany`). The symmetric variants are unchanged. With `swap == false` the value
    /// is returned as-is.
    fn swap(self, swap: bool) -> Self {
        use JoinValidation::*;
        if swap {
            match self {
                ManyToMany => ManyToMany,
                ManyToOne => OneToMany,
                OneToMany => ManyToOne,
                OneToOne => OneToOne,
            }
        } else {
            self
        }
    }

    /// Checks that this validation may be combined with `join_type`.
    ///
    /// # Errors
    /// Returns a `ComputeError` when a uniqueness check is requested (anything other than
    /// `ManyToMany`) for a join type other than inner, full or left.
    pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
        if !self.needs_checks() {
            return Ok(());
        }
        polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
                      ComputeError: "{self} validation on a {join_type} join is not supported");
        Ok(())
    }

    /// Verifies the uniqueness requirement on the *probe* side of the join.
    ///
    /// By default `s_left` is the probe. When `build_shortest_table` is set and the left
    /// input is not longer than the right one (ties swap as well), the sides are swapped.
    /// In that case `s_right` becomes the probe and the validation is mirrored via
    /// `swap`. The other (build) side is checked separately by `validate_build`.
    ///
    /// # Errors
    /// Returns a `ComputeError` when the (possibly mirrored) validation requires unique
    /// probe keys and they are not unique. Also propagates any error from `n_unique`.
    pub(super) fn validate_probe(
        &self,
        s_left: &Series,
        s_right: &Series,
        build_shortest_table: bool,
        nulls_equal: bool,
    ) -> PolarsResult<()> {
        // Swap on `<=`, i.e. equal lengths swap too.
        let should_swap = build_shortest_table && s_left.len() <= s_right.len();
        let probe = if should_swap { s_right } else { s_left };

        use JoinValidation::*;
        let valid = match self.swap(should_swap) {
            // These constrain only the build side (or nothing); that side is checked in
            // `validate_build`.
            ManyToMany | ManyToOne => true,
            OneToMany | OneToOne => {
                // NOTE(review): the `- 1` assumes `n_unique` counts all nulls as a single
                // distinct value. When nulls do not compare equal they are excluded from
                // the uniqueness check (and `null_count() > 0` keeps the subtraction from
                // underflowing).
                if !nulls_equal && probe.null_count() > 0 {
                    probe.n_unique()? - 1 == probe.len() - probe.null_count()
                } else {
                    probe.n_unique()? == probe.len()
                }
            },
        };
        // The message reports the validation as requested, not the mirrored one.
        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
        Ok(())
    }

    /// Verifies the uniqueness requirement on the *build* side of the join.
    ///
    /// The build side is the right input unless `swapped` is set, in which case the
    /// validation is mirrored via `swap`.
    ///
    /// NOTE(review): presumably `build_size` is the number of distinct keys in the build
    /// table and `expected_size` the number of build rows, so equality means the keys are
    /// unique. Confirm against the caller.
    ///
    /// # Errors
    /// Returns a `ComputeError` when the (possibly mirrored) validation requires a unique
    /// build side and `build_size != expected_size`.
    pub(super) fn validate_build(
        &self,
        build_size: usize,
        expected_size: usize,
        swapped: bool,
    ) -> PolarsResult<()> {
        use JoinValidation::*;

        let valid = match self.swap(swapped) {
            // These constrain only the probe side (or nothing); that side is checked in
            // `validate_probe`.
            ManyToMany | OneToMany => true,
            ManyToOne | OneToOne => build_size == expected_size,
        };
        polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
        Ok(())
    }
}
400
401impl Display for JoinValidation {
402 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
403 let s = match self {
404 JoinValidation::ManyToMany => "m:m",
405 JoinValidation::ManyToOne => "m:1",
406 JoinValidation::OneToMany => "1:m",
407 JoinValidation::OneToOne => "1:1",
408 };
409 write!(f, "{s}")
410 }
411}
412
413impl Debug for JoinValidation {
414 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
415 write!(f, "JoinValidation: {self}")
416 }
417}