Skip to main content

polars_core/frame/
horizontal.rs

1use polars_error::{PolarsResult, polars_err};
2
3use super::Column;
4use crate::datatypes::AnyValue;
5use crate::frame::DataFrame;
6use crate::frame::validation::validate_columns_slice;
7
8impl DataFrame {
9    /// Add columns horizontally.
10    ///
11    /// # Safety
12    /// The caller must ensure:
13    /// - the length of all [`Column`] is equal to the height of this [`DataFrame`]
14    /// - the columns names are unique
15    ///
16    /// Note: If `self` is empty, `self.height` will always be overridden by the height of the first
17    /// column in `columns`.
18    ///
19    /// Note that on a debug build this will panic on duplicates / height mismatch.
20    pub unsafe fn hstack_mut_unchecked(&mut self, columns: &[Column]) -> &mut Self {
21        if self.shape() == (0, 0)
22            && let Some(c) = columns.first()
23        {
24            unsafe { self.set_height(c.len()) };
25        }
26
27        unsafe { self.columns_mut() }.extend_from_slice(columns);
28
29        if cfg!(debug_assertions) {
30            if let err @ Err(_) = validate_columns_slice(self.height(), self.columns()) {
31                let initial_width = self.width() - columns.len();
32                unsafe { self.columns_mut() }.truncate(initial_width);
33                err.unwrap();
34            }
35        }
36
37        self
38    }
39
40    /// Add multiple [`Column`] to a [`DataFrame`].
41    /// Errors if the resulting DataFrame columns have duplicate names or unequal heights.
42    ///
43    /// Note: If `self` is empty, `self.height` will always be overridden by the height of the first
44    /// column in `columns`.
45    ///
46    /// # Example
47    ///
48    /// ```rust
49    /// # use polars_core::prelude::*;
50    /// fn stack(df: &mut DataFrame, columns: &[Column]) {
51    ///     df.hstack_mut(columns);
52    /// }
53    /// ```
54    pub fn hstack_mut(&mut self, columns: &[Column]) -> PolarsResult<&mut Self> {
55        if self.shape() == (0, 0)
56            && let Some(c) = columns.first()
57        {
58            unsafe { self.set_height(c.len()) };
59        }
60
61        unsafe { self.columns_mut() }.extend_from_slice(columns);
62
63        if let err @ Err(_) = validate_columns_slice(self.height(), self.columns()) {
64            let initial_width = self.width() - columns.len();
65            unsafe { self.columns_mut() }.truncate(initial_width);
66            err?;
67        }
68
69        Ok(self)
70    }
71}
72
73/// Concat [`DataFrame`]s horizontally.
74///
75/// If the lengths don't match and strict is false we pad with nulls, or return a `ShapeError` if strict is true.
76pub fn concat_df_horizontal(
77    dfs: &[DataFrame],
78    check_duplicates: bool,
79    strict: bool,
80    unit_length_as_scalar: bool,
81) -> PolarsResult<DataFrame> {
82    let output_height = dfs
83        .iter()
84        .map(|df| df.height())
85        .max()
86        .ok_or_else(|| polars_err!(ComputeError: "cannot concat empty dataframes"))?;
87
88    let owned_df;
89
90    let mut out_width = 0;
91
92    let all_equal_height = dfs.iter().filter(|df| df.shape() != (0, 0)).all(|df| {
93        out_width += df.width();
94        df.height() == output_height
95    });
96
97    // if not all equal length, extend the DataFrame with nulls
98    let dfs = if !all_equal_height {
99        if strict {
100            return Err(
101                polars_err!(ShapeMismatch: "cannot concat dataframes with different heights in 'strict' mode"),
102            );
103        }
104        out_width = 0;
105
106        owned_df = dfs
107            .iter()
108            .filter(|df| df.shape() != (0, 0))
109            .cloned()
110            .map(|mut df| {
111                out_width += df.width();
112                let h = df.height();
113
114                if h != output_height {
115                    if unit_length_as_scalar && h == 1 {
116                        // SAFETY: We extend each scalar column length to
117                        // `output_height`. Then, we set the height of the resulting dataframe.
118                        unsafe { df.columns_mut() }.iter_mut().for_each(|c| {
119                            let Column::Scalar(s) = c else {
120                                panic!("only supported for scalars");
121                            };
122
123                            *c = Column::Scalar(s.resize(output_height));
124                        });
125                    } else {
126                        let diff = output_height - h;
127
128                        // SAFETY: We extend each column with nulls to the point of being of length
129                        // `output_height`. Then, we set the height of the resulting dataframe.
130                        unsafe { df.columns_mut() }.iter_mut().for_each(|c| {
131                            *c = c.extend_constant(AnyValue::Null, diff).unwrap();
132                        });
133                    }
134                    unsafe {
135                        df.set_height(output_height);
136                    }
137                }
138
139                df
140            })
141            .collect::<Vec<_>>();
142        owned_df.as_slice()
143    } else {
144        dfs
145    };
146
147    let mut acc_cols = Vec::with_capacity(out_width);
148
149    for df in dfs {
150        acc_cols.extend(df.columns().iter().cloned());
151    }
152
153    let df = if check_duplicates {
154        DataFrame::new(output_height, acc_cols)?
155    } else {
156        unsafe { DataFrame::new_unchecked(output_height, acc_cols) }
157    };
158
159    Ok(df)
160}
161
#[cfg(test)]
mod tests {
    use polars_error::PolarsError;

    use crate::frame::DataFrame;
    use crate::prelude::{Column, DataType};

    /// Stacking columns of mismatching lengths onto an empty frame must fail with
    /// `ShapeMismatch` and must not leave any of the new columns behind.
    #[test]
    fn test_hstack_mut_empty_frame_height_validation() {
        let mismatched = [
            Column::full_null("a".into(), 1, &DataType::Null),
            Column::full_null("b".into(), 3, &DataType::Null),
        ];
        let mut df = DataFrame::empty();

        let outcome = df.hstack_mut(&mismatched);
        let is_shape_mismatch = matches!(outcome, Err(PolarsError::ShapeMismatch(_)));
        assert!(is_shape_mismatch, "expected shape mismatch error");

        // Ensure the DataFrame is not mutated in the error case.
        assert_eq!(df.width(), 0);
    }
}