1use polars_core::utils::{
2 _set_partition_size, CustomIterTools, NoNull, accumulate_dataframes_vertical_unchecked,
3 concat_df_unchecked, split,
4};
5use polars_utils::pl_str::PlSmallStr;
6
7use super::*;
8
9fn slice_take(
10 total_rows: IdxSize,
11 n_rows_right: IdxSize,
12 slice: Option<(i64, usize)>,
13 inner: fn(IdxSize, IdxSize, IdxSize) -> IdxCa,
14) -> IdxCa {
15 match slice {
16 None => inner(0, total_rows, n_rows_right),
17 Some((offset, len)) => {
18 let (offset, len) = slice_offsets(offset, len, total_rows as usize);
19 inner(offset as IdxSize, (len + offset) as IdxSize, n_rows_right)
20 },
21 }
22}
23
24fn take_left(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {
25 fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {
26 let mut take: NoNull<IdxCa> = (offset..total_rows)
27 .map(|i| i / n_rows_right)
28 .collect_trusted();
29 take.set_sorted_flag(IsSorted::Ascending);
30 take.into_inner()
31 }
32 slice_take(total_rows, n_rows_right, slice, inner)
33}
34
35fn take_right(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {
36 fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {
37 let take: NoNull<IdxCa> = (offset..total_rows)
38 .map(|i| i % n_rows_right)
39 .collect_trusted();
40 take.into_inner()
41 }
42 slice_take(total_rows, n_rows_right, slice, inner)
43}
44
45pub trait CrossJoin: IntoDf {
46 fn cross_join(
48 &self,
49 other: &DataFrame,
50 suffix: Option<PlSmallStr>,
51 slice: Option<(i64, usize)>,
52 maintain_order: MaintainOrderJoin,
53 ) -> PolarsResult<DataFrame> {
54 let (l_df, r_df) = cross_join_dfs(self.to_df(), other, slice, true, maintain_order)?;
55
56 _finish_join(l_df, r_df, suffix)
57 }
58}
59
60impl CrossJoin for DataFrame {}
61
62fn cross_join_dfs<'a>(
63 mut df_self: &'a DataFrame,
64 mut other: &'a DataFrame,
65 slice: Option<(i64, usize)>,
66 parallel: bool,
67 maintain_order: MaintainOrderJoin,
68) -> PolarsResult<(DataFrame, DataFrame)> {
69 if df_self.height() == 0 || other.height() == 0 {
70 return Ok((df_self.clear(), other.clear()));
71 }
72
73 let left_is_primary = match maintain_order {
74 MaintainOrderJoin::None => true,
75 MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => true,
76 MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => false,
77 };
78
79 if !left_is_primary {
80 core::mem::swap(&mut df_self, &mut other);
81 }
82
83 let n_rows_left = df_self.height() as IdxSize;
84 let n_rows_right = other.height() as IdxSize;
85 let Some(total_rows) = n_rows_left.checked_mul(n_rows_right) else {
86 polars_bail!(
87 ComputeError: "cross joins would produce more rows than fits into 2^32; \
88 consider compiling with polars-big-idx feature, or set 'streaming'"
89 );
90 };
91
92 let create_left_df = || {
101 unsafe {
104 df_self.take_unchecked_impl(&take_left(total_rows, n_rows_right, slice), parallel)
105 }
106 };
107
108 let create_right_df = || {
109 if n_rows_left > 100 || slice.is_some() {
113 unsafe {
116 other.take_unchecked_impl(&take_right(total_rows, n_rows_right, slice), parallel)
117 }
118 } else {
119 let iter = (0..n_rows_left).map(|_| other);
120 concat_df_unchecked(iter)
121 }
122 };
123 let (l_df, r_df) = if parallel {
124 try_raise_keyboard_interrupt();
125 POOL.install(|| rayon::join(create_left_df, create_right_df))
126 } else {
127 (create_left_df(), create_right_df())
128 };
129 if left_is_primary {
130 Ok((l_df, r_df))
131 } else {
132 Ok((r_df, l_df))
133 }
134}
135
136pub(super) fn fused_cross_filter(
137 left: &DataFrame,
138 right: &DataFrame,
139 suffix: Option<PlSmallStr>,
140 cross_join_options: &CrossJoinOptions,
141 maintain_order: MaintainOrderJoin,
142) -> PolarsResult<DataFrame> {
143 let unfiltered_size = (left.height() as u64).saturating_mul(right.height() as u64);
144 let chunk_size = (unfiltered_size / _set_partition_size() as u64).clamp(1, 100_000);
145 let num_chunks = (unfiltered_size / chunk_size).max(1) as usize;
146
147 let left_is_primary = match maintain_order {
148 MaintainOrderJoin::None => true,
149 MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight => true,
150 MaintainOrderJoin::Right | MaintainOrderJoin::RightLeft => false,
151 };
152
153 let split_chunks;
154 let cartesian_prod = if left_is_primary {
155 split_chunks = split(left, num_chunks);
156 split_chunks.iter().map(|l| (l, right)).collect::<Vec<_>>()
157 } else {
158 split_chunks = split(right, num_chunks);
159 split_chunks.iter().map(|r| (left, r)).collect::<Vec<_>>()
160 };
161
162 let names = _finish_join(left.clear(), right.clear(), suffix)?;
163 let rename_names = names.get_column_names();
164 let rename_names = &rename_names[left.width()..];
165
166 let dfs = POOL
167 .install(|| {
168 cartesian_prod.par_iter().map(|(left, right)| {
169 let (mut left, right) = cross_join_dfs(left, right, None, false, maintain_order)?;
170 let mut right_columns = right.take_columns();
171
172 for (c, name) in right_columns.iter_mut().zip(rename_names) {
173 c.rename((*name).clone());
174 }
175
176 unsafe { left.hstack_mut_unchecked(&right_columns) };
177
178 cross_join_options.predicate.apply(left)
179 })
180 })
181 .collect::<PolarsResult<Vec<_>>>()?;
182
183 Ok(accumulate_dataframes_vertical_unchecked(dfs))
184}