// polars_core/series/ops/reshape.rs

1use std::borrow::Cow;
2
3use arrow::array::*;
4use arrow::bitmap::Bitmap;
5use arrow::offset::{Offsets, OffsetsBuffer};
6use polars_compute::gather::sublist::list::array_to_unit_list;
7use polars_error::{PolarsResult, polars_bail, polars_ensure};
8use polars_utils::format_tuple;
9
10use crate::chunked_array::builder::get_list_builder;
11use crate::datatypes::{DataType, ListChunked};
12use crate::prelude::{IntoSeries, Series, *};
13
14fn reshape_fast_path(name: PlSmallStr, s: &Series) -> Series {
15    let mut ca = ListChunked::from_chunk_iter(
16        name,
17        s.chunks().iter().map(|arr| array_to_unit_list(arr.clone())),
18    );
19
20    ca.set_inner_dtype(s.dtype().clone());
21    ca.set_fast_explode();
22    ca.into_series()
23}
24
25impl Series {
26    /// Recurse nested types until we are at the leaf array.
27    pub fn get_leaf_array(&self) -> Series {
28        let s = self;
29        match s.dtype() {
30            #[cfg(feature = "dtype-array")]
31            DataType::Array(dtype, _) => {
32                let ca = s.array().unwrap();
33                let chunks = ca
34                    .downcast_iter()
35                    .map(|arr| arr.values().clone())
36                    .collect::<Vec<_>>();
37                // Safety: guarded by the type system
38                unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }
39                    .get_leaf_array()
40            },
41            DataType::List(dtype) => {
42                let ca = s.list().unwrap();
43                let chunks = ca
44                    .downcast_iter()
45                    .map(|arr| arr.values().clone())
46                    .collect::<Vec<_>>();
47                // Safety: guarded by the type system
48                unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }
49                    .get_leaf_array()
50            },
51            _ => s.clone(),
52        }
53    }
54
55    /// TODO: Move this somewhere else?
56    pub fn list_offsets_and_validities_recursive(
57        &self,
58    ) -> (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>) {
59        let mut offsets = vec![];
60        let mut validities = vec![];
61
62        let mut s = self.rechunk();
63
64        while let DataType::List(_) = s.dtype() {
65            let ca = s.list().unwrap();
66            offsets.push(ca.offsets().unwrap());
67            validities.push(ca.rechunk_validity());
68            s = ca.get_inner();
69        }
70
71        (offsets, validities)
72    }
73
74    /// For ListArrays, recursively normalizes the offsets to begin from 0, and
75    /// slices excess length from the values array.
76    pub fn list_rechunk_and_trim_to_normalized_offsets(&self) -> Self {
77        if let Some(ca) = self.try_list() {
78            ca.rechunk_and_trim_to_normalized_offsets().into_series()
79        } else {
80            self.rechunk()
81        }
82    }
83
84    /// Convert the values of this Series to a ListChunked with a length of 1,
85    /// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`.
86    pub fn implode(&self) -> PolarsResult<ListChunked> {
87        let s = self;
88        let s = s.rechunk();
89        let values = s.array_ref(0);
90
91        let offsets = vec![0i64, values.len() as i64];
92        let inner_type = s.dtype();
93
94        let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
95
96        // SAFETY: offsets are correct.
97        let arr = unsafe {
98            ListArray::new(
99                dtype,
100                Offsets::new_unchecked(offsets).into(),
101                values.clone(),
102                None,
103            )
104        };
105
106        let mut ca = ListChunked::with_chunk(s.name().clone(), arr);
107        unsafe { ca.to_logical(inner_type.clone()) };
108        ca.set_fast_explode();
109        Ok(ca)
110    }
111
112    #[cfg(feature = "dtype-array")]
113    pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
114        polars_ensure!(
115            !dimensions.is_empty(),
116            InvalidOperation: "at least one dimension must be specified"
117        );
118
119        let leaf_array = self.get_leaf_array().rechunk();
120        let size = leaf_array.len();
121
122        let mut total_dim_size = 1;
123        let mut num_infers = 0;
124        for &dim in dimensions {
125            match dim {
126                ReshapeDimension::Infer => num_infers += 1,
127                ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize,
128            }
129        }
130
131        polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension");
132
133        if size == 0 {
134            polars_ensure!(
135                num_infers > 0 || total_dim_size == 0,
136                InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}",
137                format_tuple!(dimensions),
138            );
139
140            let mut prev_arrow_dtype = leaf_array
141                .dtype()
142                .to_physical()
143                .to_arrow(CompatLevel::newest());
144            let mut prev_dtype = leaf_array.dtype().clone();
145            let mut prev_array = leaf_array.chunks()[0].clone();
146
147            // @NOTE: We need to collect the iterator here because it is lazily processed.
148            let mut current_length = dimensions[0].get_or_infer(0);
149            let len_iter = dimensions[1..]
150                .iter()
151                .map(|d| {
152                    let length = current_length as usize;
153                    current_length *= d.get_or_infer(0);
154                    length
155                })
156                .collect::<Vec<_>>();
157
158            // We pop the outer dimension as that is the height of the series.
159            for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() {
160                // Infer dimension if needed
161                let dim = dim.get_or_infer(0);
162                prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);
163                prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);
164
165                prev_array =
166                    FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None)
167                        .boxed();
168            }
169
170            return Ok(unsafe {
171                Series::from_chunks_and_dtype_unchecked(
172                    leaf_array.name().clone(),
173                    vec![prev_array],
174                    &prev_dtype,
175                )
176            });
177        }
178
179        polars_ensure!(
180            total_dim_size > 0,
181            InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}",
182            format_tuple!(dimensions)
183        );
184
185        polars_ensure!(
186            size % total_dim_size == 0,
187            InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)
188        );
189
190        let leaf_array = leaf_array.rechunk();
191        let mut prev_arrow_dtype = leaf_array
192            .dtype()
193            .to_physical()
194            .to_arrow(CompatLevel::newest());
195        let mut prev_dtype = leaf_array.dtype().clone();
196        let mut prev_array = leaf_array.chunks()[0].clone();
197
198        // We pop the outer dimension as that is the height of the series.
199        for dim in dimensions[1..].iter().rev() {
200            // Infer dimension if needed
201            let dim = dim.get_or_infer((size / total_dim_size) as u64);
202            prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);
203            prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);
204
205            prev_array = FixedSizeListArray::new(
206                prev_arrow_dtype.clone(),
207                prev_array.len() / dim as usize,
208                prev_array,
209                None,
210            )
211            .boxed();
212        }
213        Ok(unsafe {
214            Series::from_chunks_and_dtype_unchecked(
215                leaf_array.name().clone(),
216                vec![prev_array],
217                &prev_dtype,
218            )
219        })
220    }
221
222    pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
223        polars_ensure!(
224            !dimensions.is_empty(),
225            InvalidOperation: "at least one dimension must be specified"
226        );
227
228        let s = self;
229        let s = if let DataType::List(_) = s.dtype() {
230            Cow::Owned(s.explode()?)
231        } else {
232            Cow::Borrowed(s)
233        };
234
235        let s_ref = s.as_ref();
236
237        // let dimensions = dimensions.to_vec();
238
239        match dimensions.len() {
240            1 => {
241                polars_ensure!(
242                    dimensions[0].get().is_none_or( |dim| dim as usize == s_ref.len()),
243                    InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,
244                );
245                Ok(s_ref.clone())
246            },
247            2 => {
248                let rows = dimensions[0];
249                let cols = dimensions[1];
250
251                if s_ref.is_empty() {
252                    if rows.get_or_infer(0) == 0 && cols.get_or_infer(0) <= 1 {
253                        let s = reshape_fast_path(s.name().clone(), s_ref);
254                        return Ok(s);
255                    } else {
256                        polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {}", format_tuple!(dimensions))
257                    }
258                }
259
260                use ReshapeDimension as RD;
261                // Infer dimension.
262
263                let (rows, cols) = match (rows, cols) {
264                    (RD::Infer, RD::Specified(cols)) if cols.get() >= 1 => {
265                        (s_ref.len() as u64 / cols.get(), cols.get())
266                    },
267                    (RD::Specified(rows), RD::Infer) if rows.get() >= 1 => {
268                        (rows.get(), s_ref.len() as u64 / rows.get())
269                    },
270                    (RD::Infer, RD::Infer) => (s_ref.len() as u64, 1u64),
271                    (RD::Specified(rows), RD::Specified(cols)) => (rows.get(), cols.get()),
272                    _ => polars_bail!(InvalidOperation: "reshape of non-zero list into zero list"),
273                };
274
275                // Fast path, we can create a unit list so we only allocate offsets.
276                if rows as usize == s_ref.len() && cols == 1 {
277                    let s = reshape_fast_path(s.name().clone(), s_ref);
278                    return Ok(s);
279                }
280
281                polars_ensure!(
282                    (rows*cols) as usize == s_ref.len() && rows >= 1 && cols >= 1,
283                    InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,
284                );
285
286                let mut builder =
287                    get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone());
288
289                let mut offset = 0u64;
290                for _ in 0..rows {
291                    let row = s_ref.slice(offset as i64, cols as usize);
292                    builder.append_series(&row).unwrap();
293                    offset += cols;
294                }
295                Ok(builder.finish().into_series())
296            },
297            _ => {
298                polars_bail!(InvalidOperation: "more than two dimensions not supported in reshaping to List.\n\nConsider reshaping to Array type.");
299            },
300        }
301    }
302}
303
#[cfg(test)]
mod test {
    use super::*;
    use crate::prelude::*;

    /// `implode` must match a single-row list built by hand with the list builder.
    #[test]
    fn test_to_list() -> PolarsResult<()> {
        let series = Series::new("a".into(), &[1, 2, 3]);

        let expected = {
            let mut builder =
                get_list_builder(series.dtype(), series.len(), 1, series.name().clone());
            builder.append_series(&series).unwrap();
            builder.finish().into_series()
        };

        let imploded = series.implode()?.into_series();
        assert!(expected.equals(&imploded));

        Ok(())
    }

    /// Two-dimensional `reshape_list` yields the expected number of rows and
    /// preserves every element.
    #[test]
    fn test_reshape() -> PolarsResult<()> {
        let series = Series::new("a".into(), &[1, 2, 3, 4]);

        let cases = [
            (&[-1, 1], 4),
            (&[4, 1], 4),
            (&[2, 2], 2),
            (&[-1, 2], 2),
            (&[2, -1], 2),
        ];

        for (raw_dims, expected_rows) in cases {
            let dims: Vec<_> = raw_dims
                .iter()
                .map(|&d| ReshapeDimension::new(d))
                .collect();

            let reshaped = series.reshape_list(&dims)?;
            assert_eq!(reshaped.len(), expected_rows);
            assert!(matches!(reshaped.dtype(), DataType::List(_)));
            assert_eq!(reshaped.explode()?.len(), 4);
        }

        Ok(())
    }
}