1use super::*;
2
3pub(super) type JoinIds = Vec<IdxSize>;
4pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);
5pub type InnerJoinIds = (JoinIds, JoinIds);
6
7#[cfg(feature = "chunked_ids")]
8pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;
9#[cfg(feature = "chunked_ids")]
10pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;
11
12#[cfg(not(feature = "chunked_ids"))]
13pub type ChunkJoinOptIds = Vec<NullableIdxSize>;
14
15#[cfg(not(feature = "chunked_ids"))]
16pub type ChunkJoinIds = Vec<IdxSize>;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20use strum_macros::IntoStaticStr;
21
/// Options controlling how two frames are joined.
///
/// Build one with `JoinArgs::new` (or `From<JoinType>`) and adjust it with the
/// `with_*` builders. `Default` yields an inner join with every option at its default.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct JoinArgs {
    /// Which kind of join to perform.
    pub how: JoinType,
    /// Cardinality check on the join keys; defaults to `m:m`, which performs no checks.
    pub validation: JoinValidation,
    /// Column-name suffix; when `None`, `JoinArgs::suffix` falls back to `"_right"`.
    // NOTE(review): presumably appended to right-hand columns whose names clash — confirm.
    pub suffix: Option<PlSmallStr>,
    /// Optional slice of the join output.
    // NOTE(review): presumably `(offset, length)` with a signed offset — confirm with callers.
    pub slice: Option<(i64, usize)>,
    /// Whether null keys compare equal. `JoinValidation::validate_probe` also uses it to
    /// decide whether nulls count toward key uniqueness.
    pub nulls_equal: bool,
    /// Key-column coalescing policy, resolved per join type by `JoinArgs::should_coalesce`.
    pub coalesce: JoinCoalesce,
    /// Output row-order requirement; defaults to `MaintainOrderJoin::None`.
    pub maintain_order: MaintainOrderJoin,
}
33
34impl JoinArgs {
35 pub fn should_coalesce(&self) -> bool {
36 self.coalesce.coalesce(&self.how)
37 }
38}
39
/// The kind of join to perform.
///
/// `Inner` is the default. `AsOf`, `Semi`, `Anti` and `IEJoin` exist only when their
/// cargo features are enabled. `Debug` is implemented by hand and prints the same
/// upper-case name as `Display` (e.g. `LEFT`).
#[derive(Clone, PartialEq, Eq, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinType {
    #[default]
    Inner,
    Left,
    Right,
    Full,
    /// As-of join configured by `AsOfOptions` (feature `asof_join`).
    #[cfg(feature = "asof_join")]
    AsOf(AsOfOptions),
    /// Semi join (feature `semi_anti_join`).
    #[cfg(feature = "semi_anti_join")]
    Semi,
    /// Anti join (feature `semi_anti_join`).
    #[cfg(feature = "semi_anti_join")]
    Anti,
    /// IE join (feature `iejoin`); presumably an inequality join — confirm.
    #[cfg(feature = "iejoin")]
    IEJoin,
    /// Cross join.
    Cross,
}
60
61#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]
62#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
63pub enum JoinCoalesce {
64 #[default]
65 JoinSpecific,
66 CoalesceColumns,
67 KeepColumns,
68}
69
70impl JoinCoalesce {
71 pub fn coalesce(&self, join_type: &JoinType) -> bool {
72 use JoinCoalesce::*;
73 use JoinType::*;
74 match join_type {
75 Left | Inner | Right => {
76 matches!(self, JoinSpecific | CoalesceColumns)
77 },
78 Full => {
79 matches!(self, CoalesceColumns)
80 },
81 #[cfg(feature = "asof_join")]
82 AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),
83 #[cfg(feature = "iejoin")]
84 IEJoin => false,
85 Cross => false,
86 #[cfg(feature = "semi_anti_join")]
87 Semi | Anti => false,
88 }
89 }
90}
91
/// Row-order requirement for a join's output. `None` is the default.
///
/// Converts to a `&'static str` via strum's `IntoStaticStr` in snake_case
/// (`"none"`, `"left"`, `"right"`, `"left_right"`, `"right_left"`).
// NOTE(review): the two-sided variants presumably mean "order by the first side, then by the
// second" — confirm against the join implementations.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
pub enum MaintainOrderJoin {
    #[default]
    None,
    Left,
    Right,
    LeftRight,
    RightLeft,
}
103
104impl MaintainOrderJoin {
105 pub(super) fn flip(&self) -> Self {
106 match self {
107 MaintainOrderJoin::None => MaintainOrderJoin::None,
108 MaintainOrderJoin::Left => MaintainOrderJoin::Right,
109 MaintainOrderJoin::Right => MaintainOrderJoin::Left,
110 MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,
111 MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,
112 }
113 }
114}
115
116impl JoinArgs {
117 pub fn new(how: JoinType) -> Self {
118 Self {
119 how,
120 validation: Default::default(),
121 suffix: None,
122 slice: None,
123 nulls_equal: false,
124 coalesce: Default::default(),
125 maintain_order: Default::default(),
126 }
127 }
128
129 pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
130 self.coalesce = coalesce;
131 self
132 }
133
134 pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {
135 self.suffix = suffix;
136 self
137 }
138
139 pub fn suffix(&self) -> &PlSmallStr {
140 const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
141 self.suffix.as_ref().unwrap_or(DEFAULT)
142 }
143}
144
145impl From<JoinType> for JoinArgs {
146 fn from(value: JoinType) -> Self {
147 JoinArgs::new(value)
148 }
149}
150
151pub trait CrossJoinFilter: Send + Sync {
152 fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;
153}
154
155impl<T> CrossJoinFilter for T
156where
157 T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,
158{
159 fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {
160 self(df)
161 }
162}
163
164#[derive(Clone)]
165pub struct CrossJoinOptions {
166 pub predicate: Arc<dyn CrossJoinFilter>,
167}
168
169impl CrossJoinOptions {
170 fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {
171 Arc::as_ptr(&self.predicate)
172 }
173}
174
175impl Eq for CrossJoinOptions {}
176
177impl PartialEq for CrossJoinOptions {
178 fn eq(&self, other: &Self) -> bool {
179 std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())
180 }
181}
182
183impl Hash for CrossJoinOptions {
184 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
185 self.as_ptr_ref().hash(state);
186 }
187}
188
189impl Debug for CrossJoinOptions {
190 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
191 write!(f, "CrossJoinOptions",)
192 }
193}
194
/// Extra configuration for join types that need more than a `JoinType` tag.
///
/// `IEJoin` carries `IEJoinOptions` (feature `iejoin`). `Cross` carries `CrossJoinOptions`.
/// `Cross` is marked `serde(skip)`: its options hold an `Arc<dyn CrossJoinFilter>`, and
/// `CrossJoinOptions` derives no `Serialize`/`Deserialize`.
#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "snake_case")]
pub enum JoinTypeOptions {
    #[cfg(feature = "iejoin")]
    IEJoin(IEJoinOptions),
    #[cfg_attr(feature = "serde", serde(skip))]
    Cross(CrossJoinOptions),
}
204
205impl Display for JoinType {
206 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
207 use JoinType::*;
208 let val = match self {
209 Left => "LEFT",
210 Right => "RIGHT",
211 Inner => "INNER",
212 Full => "FULL",
213 #[cfg(feature = "asof_join")]
214 AsOf(_) => "ASOF",
215 #[cfg(feature = "iejoin")]
216 IEJoin => "IEJOIN",
217 Cross => "CROSS",
218 #[cfg(feature = "semi_anti_join")]
219 Semi => "SEMI",
220 #[cfg(feature = "semi_anti_join")]
221 Anti => "ANTI",
222 };
223 write!(f, "{val}")
224 }
225}
226
227impl Debug for JoinType {
228 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
229 write!(f, "{self}")
230 }
231}
232
233impl JoinType {
234 pub fn is_equi(&self) -> bool {
235 matches!(
236 self,
237 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
238 )
239 }
240
241 pub fn is_semi_anti(&self) -> bool {
242 #[cfg(feature = "semi_anti_join")]
243 {
244 matches!(self, JoinType::Semi | JoinType::Anti)
245 }
246 #[cfg(not(feature = "semi_anti_join"))]
247 {
248 false
249 }
250 }
251
252 pub fn is_asof(&self) -> bool {
253 #[cfg(feature = "asof_join")]
254 {
255 matches!(self, JoinType::AsOf(_))
256 }
257 #[cfg(not(feature = "asof_join"))]
258 {
259 false
260 }
261 }
262
263 pub fn is_cross(&self) -> bool {
264 matches!(self, JoinType::Cross)
265 }
266
267 pub fn is_ie(&self) -> bool {
268 #[cfg(feature = "iejoin")]
269 {
270 matches!(self, JoinType::IEJoin)
271 }
272 #[cfg(not(feature = "iejoin"))]
273 {
274 false
275 }
276 }
277}
278
/// Expected cardinality of the join keys, checked while joining.
///
/// Displayed as `m:m`, `m:1`, `1:m` and `1:1` (left:right). `ManyToMany` is the
/// default and performs no checks (see `needs_checks`).
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinValidation {
    /// `m:m` — no uniqueness requirement.
    #[default]
    ManyToMany,
    /// `m:1` — right-side keys are expected to be unique.
    ManyToOne,
    /// `1:m` — left-side keys are expected to be unique.
    OneToMany,
    /// `1:1` — keys on both sides are expected to be unique.
    OneToOne,
}
292
293impl JoinValidation {
294 pub fn needs_checks(&self) -> bool {
295 !matches!(self, JoinValidation::ManyToMany)
296 }
297
298 fn swap(self, swap: bool) -> Self {
299 use JoinValidation::*;
300 if swap {
301 match self {
302 ManyToMany => ManyToMany,
303 ManyToOne => OneToMany,
304 OneToMany => ManyToOne,
305 OneToOne => OneToOne,
306 }
307 } else {
308 self
309 }
310 }
311
312 pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {
313 if !self.needs_checks() {
314 return Ok(());
315 }
316 polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),
317 ComputeError: "{self} validation on a {join_type} join is not supported");
318 Ok(())
319 }
320
321 pub(super) fn validate_probe(
322 &self,
323 s_left: &Series,
324 s_right: &Series,
325 build_shortest_table: bool,
326 nulls_equal: bool,
327 ) -> PolarsResult<()> {
328 let should_swap = build_shortest_table && s_left.len() <= s_right.len();
336 let probe = if should_swap { s_right } else { s_left };
337
338 use JoinValidation::*;
339 let valid = match self.swap(should_swap) {
340 ManyToMany | ManyToOne => true,
343 OneToMany | OneToOne => {
344 if !nulls_equal && probe.null_count() > 0 {
345 probe.n_unique()? - 1 == probe.len() - probe.null_count()
346 } else {
347 probe.n_unique()? == probe.len()
348 }
349 },
350 };
351 polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
352 Ok(())
353 }
354
355 pub(super) fn validate_build(
356 &self,
357 build_size: usize,
358 expected_size: usize,
359 swapped: bool,
360 ) -> PolarsResult<()> {
361 use JoinValidation::*;
362
363 let valid = match self.swap(swapped) {
365 ManyToMany | OneToMany => true,
368 ManyToOne | OneToOne => build_size == expected_size,
369 };
370 polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
371 Ok(())
372 }
373}
374
375impl Display for JoinValidation {
376 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
377 let s = match self {
378 JoinValidation::ManyToMany => "m:m",
379 JoinValidation::ManyToOne => "m:1",
380 JoinValidation::OneToMany => "1:m",
381 JoinValidation::OneToOne => "1:1",
382 };
383 write!(f, "{s}")
384 }
385}
386
387impl Debug for JoinValidation {
388 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
389 write!(f, "JoinValidation: {self}")
390 }
391}