1use polars_core::utils::{
2 _set_partition_size, CustomIterTools, NoNull, accumulate_dataframes_vertical_unchecked,
3 concat_df_unchecked, split,
4};
5use polars_utils::pl_str::PlSmallStr;
6
7use super::*;
8
9fn slice_take(
10 total_rows: IdxSize,
11 n_rows_right: IdxSize,
12 slice: Option<(i64, usize)>,
13 inner: fn(IdxSize, IdxSize, IdxSize) -> IdxCa,
14) -> IdxCa {
15 match slice {
16 None => inner(0, total_rows, n_rows_right),
17 Some((offset, len)) => {
18 let (offset, len) = slice_offsets(offset, len, total_rows as usize);
19 inner(offset as IdxSize, (len + offset) as IdxSize, n_rows_right)
20 },
21 }
22}
23
24fn take_left(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {
25 fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {
26 let mut take: NoNull<IdxCa> = (offset..total_rows)
27 .map(|i| i / n_rows_right)
28 .collect_trusted();
29 take.set_sorted_flag(IsSorted::Ascending);
30 take.into_inner()
31 }
32 slice_take(total_rows, n_rows_right, slice, inner)
33}
34
35fn take_right(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {
36 fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {
37 let take: NoNull<IdxCa> = (offset..total_rows)
38 .map(|i| i % n_rows_right)
39 .collect_trusted();
40 take.into_inner()
41 }
42 slice_take(total_rows, n_rows_right, slice, inner)
43}
44
45pub trait CrossJoin: IntoDf {
46 #[doc(hidden)]
47 fn _cross_join_with_names(
49 &self,
50 other: &DataFrame,
51 names: &[PlSmallStr],
52 ) -> PolarsResult<DataFrame> {
53 let (mut l_df, r_df) = cross_join_dfs(self.to_df(), other, None, false)?;
54 l_df.clear_schema();
55
56 unsafe {
57 l_df.get_columns_mut().extend_from_slice(r_df.get_columns());
58
59 l_df.get_columns_mut()
60 .iter_mut()
61 .zip(names)
62 .for_each(|(s, name)| {
63 if s.name() != name {
64 s.rename(name.clone());
65 }
66 });
67 }
68 Ok(l_df)
69 }
70
71 fn cross_join(
73 &self,
74 other: &DataFrame,
75 suffix: Option<PlSmallStr>,
76 slice: Option<(i64, usize)>,
77 ) -> PolarsResult<DataFrame> {
78 let (l_df, r_df) = cross_join_dfs(self.to_df(), other, slice, true)?;
79
80 _finish_join(l_df, r_df, suffix)
81 }
82}
83
84impl CrossJoin for DataFrame {}
85
86fn cross_join_dfs(
87 df_self: &DataFrame,
88 other: &DataFrame,
89 slice: Option<(i64, usize)>,
90 parallel: bool,
91) -> PolarsResult<(DataFrame, DataFrame)> {
92 let n_rows_left = df_self.height() as IdxSize;
93 let n_rows_right = other.height() as IdxSize;
94 let Some(total_rows) = n_rows_left.checked_mul(n_rows_right) else {
95 polars_bail!(
96 ComputeError: "cross joins would produce more rows than fits into 2^32; \
97 consider compiling with polars-big-idx feature, or set 'streaming'"
98 );
99 };
100 if n_rows_left == 0 || n_rows_right == 0 {
101 return Ok((df_self.clear(), other.clear()));
102 }
103
104 let create_left_df = || {
113 unsafe {
116 df_self.take_unchecked_impl(&take_left(total_rows, n_rows_right, slice), parallel)
117 }
118 };
119
120 let create_right_df = || {
121 if n_rows_left > 100 || slice.is_some() {
125 unsafe {
128 other.take_unchecked_impl(&take_right(total_rows, n_rows_right, slice), parallel)
129 }
130 } else {
131 let iter = (0..n_rows_left).map(|_| other);
132 concat_df_unchecked(iter)
133 }
134 };
135 let (l_df, r_df) = if parallel {
136 try_raise_keyboard_interrupt();
137 POOL.install(|| rayon::join(create_left_df, create_right_df))
138 } else {
139 (create_left_df(), create_right_df())
140 };
141 Ok((l_df, r_df))
142}
143
144pub(super) fn fused_cross_filter(
145 left: &DataFrame,
146 right: &DataFrame,
147 suffix: Option<PlSmallStr>,
148 cross_join_options: &CrossJoinOptions,
149) -> PolarsResult<DataFrame> {
150 let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2;
154 let splitted_a = split(left, n_partitions);
155 let splitted_b = split(right, n_partitions);
156
157 let cartesian_prod = splitted_a
158 .iter()
159 .flat_map(|l| splitted_b.iter().map(move |r| (l, r)))
160 .collect::<Vec<_>>();
161
162 let names = _finish_join(left.clear(), right.clear(), suffix)?;
163 let rename_names = names.get_column_names();
164 let rename_names = &rename_names[left.width()..];
165
166 let dfs = POOL
167 .install(|| {
168 cartesian_prod.par_iter().map(|(left, right)| {
169 let (mut left, right) = cross_join_dfs(left, right, None, false)?;
170 let mut right_columns = right.take_columns();
171
172 for (c, name) in right_columns.iter_mut().zip(rename_names) {
173 c.rename((*name).clone());
174 }
175
176 unsafe { left.hstack_mut_unchecked(&right_columns) };
177
178 cross_join_options.predicate.apply(left)
179 })
180 })
181 .collect::<PolarsResult<Vec<_>>>()?;
182
183 Ok(accumulate_dataframes_vertical_unchecked(dfs))
184}