polars_ops/frame/join/
cross_join.rs

1use polars_core::utils::{
2    _set_partition_size, CustomIterTools, NoNull, accumulate_dataframes_vertical_unchecked,
3    concat_df_unchecked, split,
4};
5use polars_utils::pl_str::PlSmallStr;
6
7use super::*;
8
9fn slice_take(
10    total_rows: IdxSize,
11    n_rows_right: IdxSize,
12    slice: Option<(i64, usize)>,
13    inner: fn(IdxSize, IdxSize, IdxSize) -> IdxCa,
14) -> IdxCa {
15    match slice {
16        None => inner(0, total_rows, n_rows_right),
17        Some((offset, len)) => {
18            let (offset, len) = slice_offsets(offset, len, total_rows as usize);
19            inner(offset as IdxSize, (len + offset) as IdxSize, n_rows_right)
20        },
21    }
22}
23
24fn take_left(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {
25    fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {
26        let mut take: NoNull<IdxCa> = (offset..total_rows)
27            .map(|i| i / n_rows_right)
28            .collect_trusted();
29        take.set_sorted_flag(IsSorted::Ascending);
30        take.into_inner()
31    }
32    slice_take(total_rows, n_rows_right, slice, inner)
33}
34
35fn take_right(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, usize)>) -> IdxCa {
36    fn inner(offset: IdxSize, total_rows: IdxSize, n_rows_right: IdxSize) -> IdxCa {
37        let take: NoNull<IdxCa> = (offset..total_rows)
38            .map(|i| i % n_rows_right)
39            .collect_trusted();
40        take.into_inner()
41    }
42    slice_take(total_rows, n_rows_right, slice, inner)
43}
44
45pub trait CrossJoin: IntoDf {
46    #[doc(hidden)]
47    /// used by streaming
48    fn _cross_join_with_names(
49        &self,
50        other: &DataFrame,
51        names: &[PlSmallStr],
52    ) -> PolarsResult<DataFrame> {
53        let (mut l_df, r_df) = cross_join_dfs(self.to_df(), other, None, false)?;
54        l_df.clear_schema();
55
56        unsafe {
57            l_df.get_columns_mut().extend_from_slice(r_df.get_columns());
58
59            l_df.get_columns_mut()
60                .iter_mut()
61                .zip(names)
62                .for_each(|(s, name)| {
63                    if s.name() != name {
64                        s.rename(name.clone());
65                    }
66                });
67        }
68        Ok(l_df)
69    }
70
71    /// Creates the Cartesian product from both frames, preserves the order of the left keys.
72    fn cross_join(
73        &self,
74        other: &DataFrame,
75        suffix: Option<PlSmallStr>,
76        slice: Option<(i64, usize)>,
77    ) -> PolarsResult<DataFrame> {
78        let (l_df, r_df) = cross_join_dfs(self.to_df(), other, slice, true)?;
79
80        _finish_join(l_df, r_df, suffix)
81    }
82}
83
84impl CrossJoin for DataFrame {}
85
86fn cross_join_dfs(
87    df_self: &DataFrame,
88    other: &DataFrame,
89    slice: Option<(i64, usize)>,
90    parallel: bool,
91) -> PolarsResult<(DataFrame, DataFrame)> {
92    let n_rows_left = df_self.height() as IdxSize;
93    let n_rows_right = other.height() as IdxSize;
94    let Some(total_rows) = n_rows_left.checked_mul(n_rows_right) else {
95        polars_bail!(
96            ComputeError: "cross joins would produce more rows than fits into 2^32; \
97            consider compiling with polars-big-idx feature, or set 'streaming'"
98        );
99    };
100    if n_rows_left == 0 || n_rows_right == 0 {
101        return Ok((df_self.clear(), other.clear()));
102    }
103
104    // the left side has the Nth row combined with every row from right.
105    // So let's say we have the following no. of rows
106    // left: 3
107    // right: 4
108    //
109    // left take idx:   000011112222
110    // right take idx:  012301230123
111
112    let create_left_df = || {
113        // SAFETY:
114        // take left is in bounds
115        unsafe {
116            df_self.take_unchecked_impl(&take_left(total_rows, n_rows_right, slice), parallel)
117        }
118    };
119
120    let create_right_df = || {
121        // concatenation of dataframes is very expensive if we need to make the series mutable
122        // many times, these are atomic operations
123        // so we choose a different strategy at > 100 rows (arbitrarily small number)
124        if n_rows_left > 100 || slice.is_some() {
125            // SAFETY:
126            // take right is in bounds
127            unsafe {
128                other.take_unchecked_impl(&take_right(total_rows, n_rows_right, slice), parallel)
129            }
130        } else {
131            let iter = (0..n_rows_left).map(|_| other);
132            concat_df_unchecked(iter)
133        }
134    };
135    let (l_df, r_df) = if parallel {
136        try_raise_keyboard_interrupt();
137        POOL.install(|| rayon::join(create_left_df, create_right_df))
138    } else {
139        (create_left_df(), create_right_df())
140    };
141    Ok((l_df, r_df))
142}
143
144pub(super) fn fused_cross_filter(
145    left: &DataFrame,
146    right: &DataFrame,
147    suffix: Option<PlSmallStr>,
148    cross_join_options: &CrossJoinOptions,
149) -> PolarsResult<DataFrame> {
150    // Because we do a cartesian product, the number of partitions is squared.
151    // We take the sqrt, but we don't expect every partition to produce results and work can be
152    // imbalanced, so we multiply the number of partitions by 2;
153    let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2;
154    let splitted_a = split(left, n_partitions);
155    let splitted_b = split(right, n_partitions);
156
157    let cartesian_prod = splitted_a
158        .iter()
159        .flat_map(|l| splitted_b.iter().map(move |r| (l, r)))
160        .collect::<Vec<_>>();
161
162    let names = _finish_join(left.clear(), right.clear(), suffix)?;
163    let rename_names = names.get_column_names();
164    let rename_names = &rename_names[left.width()..];
165
166    let dfs = POOL
167        .install(|| {
168            cartesian_prod.par_iter().map(|(left, right)| {
169                let (mut left, right) = cross_join_dfs(left, right, None, false)?;
170                let mut right_columns = right.take_columns();
171
172                for (c, name) in right_columns.iter_mut().zip(rename_names) {
173                    c.rename((*name).clone());
174                }
175
176                unsafe { left.hstack_mut_unchecked(&right_columns) };
177
178                cross_join_options.predicate.apply(left)
179            })
180        })
181        .collect::<PolarsResult<Vec<_>>>()?;
182
183    Ok(accumulate_dataframes_vertical_unchecked(dfs))
184}