polars_core/series/ops/
reshape.rs

1use std::borrow::Cow;
2
3use arrow::array::*;
4use arrow::bitmap::Bitmap;
5use arrow::offset::{Offsets, OffsetsBuffer};
6use polars_compute::gather::sublist::list::array_to_unit_list;
7use polars_error::{PolarsResult, polars_bail, polars_ensure};
8use polars_utils::format_tuple;
9
10use crate::chunked_array::builder::get_list_builder;
11use crate::datatypes::{DataType, ListChunked};
12use crate::prelude::{IntoSeries, Series, *};
13
14fn reshape_fast_path(name: PlSmallStr, s: &Series) -> Series {
15    let mut ca = ListChunked::from_chunk_iter(
16        name,
17        s.chunks().iter().map(|arr| array_to_unit_list(arr.clone())),
18    );
19
20    ca.set_inner_dtype(s.dtype().clone());
21    ca.set_fast_explode();
22    ca.into_series()
23}
24
25impl Series {
26    /// Recurse nested types until we are at the leaf array.
27    pub fn get_leaf_array(&self) -> Series {
28        let s = self;
29        match s.dtype() {
30            #[cfg(feature = "dtype-array")]
31            DataType::Array(dtype, _) => {
32                let ca = s.array().unwrap();
33                let chunks = ca
34                    .downcast_iter()
35                    .map(|arr| arr.values().clone())
36                    .collect::<Vec<_>>();
37                // Safety: guarded by the type system
38                unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }
39                    .get_leaf_array()
40            },
41            DataType::List(dtype) => {
42                let ca = s.list().unwrap();
43                let chunks = ca
44                    .downcast_iter()
45                    .map(|arr| arr.values().clone())
46                    .collect::<Vec<_>>();
47                // Safety: guarded by the type system
48                unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) }
49                    .get_leaf_array()
50            },
51            _ => s.clone(),
52        }
53    }
54
55    /// TODO: Move this somewhere else?
56    pub fn list_offsets_and_validities_recursive(
57        &self,
58    ) -> (Vec<OffsetsBuffer<i64>>, Vec<Option<Bitmap>>) {
59        let mut offsets = vec![];
60        let mut validities = vec![];
61
62        let mut s = self.rechunk();
63
64        while let DataType::List(_) = s.dtype() {
65            let ca = s.list().unwrap();
66            offsets.push(ca.offsets().unwrap());
67            validities.push(ca.rechunk_validity());
68            s = ca.get_inner();
69        }
70
71        (offsets, validities)
72    }
73
74    /// Convert the values of this Series to a ListChunked with a length of 1,
75    /// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`.
76    pub fn implode(&self) -> PolarsResult<ListChunked> {
77        let s = self;
78        let s = s.rechunk();
79        let values = s.array_ref(0);
80
81        let offsets = vec![0i64, values.len() as i64];
82        let inner_type = s.dtype();
83
84        let dtype = ListArray::<i64>::default_datatype(values.dtype().clone());
85
86        // SAFETY: offsets are correct.
87        let arr = unsafe {
88            ListArray::new(
89                dtype,
90                Offsets::new_unchecked(offsets).into(),
91                values.clone(),
92                None,
93            )
94        };
95
96        let mut ca = ListChunked::with_chunk(s.name().clone(), arr);
97        unsafe { ca.to_logical(inner_type.clone()) };
98        ca.set_fast_explode();
99        Ok(ca)
100    }
101
102    #[cfg(feature = "dtype-array")]
103    pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
104        polars_ensure!(
105            !dimensions.is_empty(),
106            InvalidOperation: "at least one dimension must be specified"
107        );
108
109        let leaf_array = self.get_leaf_array().rechunk();
110        let size = leaf_array.len();
111
112        let mut total_dim_size = 1;
113        let mut num_infers = 0;
114        for &dim in dimensions {
115            match dim {
116                ReshapeDimension::Infer => num_infers += 1,
117                ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize,
118            }
119        }
120
121        polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension");
122
123        if size == 0 {
124            polars_ensure!(
125                num_infers > 0 || total_dim_size == 0,
126                InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}",
127                format_tuple!(dimensions),
128            );
129
130            let mut prev_arrow_dtype = leaf_array
131                .dtype()
132                .to_physical()
133                .to_arrow(CompatLevel::newest());
134            let mut prev_dtype = leaf_array.dtype().clone();
135            let mut prev_array = leaf_array.chunks()[0].clone();
136
137            // @NOTE: We need to collect the iterator here because it is lazily processed.
138            let mut current_length = dimensions[0].get_or_infer(0);
139            let len_iter = dimensions[1..]
140                .iter()
141                .map(|d| {
142                    let length = current_length as usize;
143                    current_length *= d.get_or_infer(0);
144                    length
145                })
146                .collect::<Vec<_>>();
147
148            // We pop the outer dimension as that is the height of the series.
149            for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() {
150                // Infer dimension if needed
151                let dim = dim.get_or_infer(0);
152                prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);
153                prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);
154
155                prev_array =
156                    FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None)
157                        .boxed();
158            }
159
160            return Ok(unsafe {
161                Series::from_chunks_and_dtype_unchecked(
162                    leaf_array.name().clone(),
163                    vec![prev_array],
164                    &prev_dtype,
165                )
166            });
167        }
168
169        polars_ensure!(
170            total_dim_size > 0,
171            InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}",
172            format_tuple!(dimensions)
173        );
174
175        polars_ensure!(
176            size.is_multiple_of(total_dim_size),
177            InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)
178        );
179
180        let leaf_array = leaf_array.rechunk();
181        let mut prev_arrow_dtype = leaf_array
182            .dtype()
183            .to_physical()
184            .to_arrow(CompatLevel::newest());
185        let mut prev_dtype = leaf_array.dtype().clone();
186        let mut prev_array = leaf_array.chunks()[0].clone();
187
188        // We pop the outer dimension as that is the height of the series.
189        for dim in dimensions[1..].iter().rev() {
190            // Infer dimension if needed
191            let dim = dim.get_or_infer((size / total_dim_size) as u64);
192            prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);
193            prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);
194
195            prev_array = FixedSizeListArray::new(
196                prev_arrow_dtype.clone(),
197                prev_array.len() / dim as usize,
198                prev_array,
199                None,
200            )
201            .boxed();
202        }
203        Ok(unsafe {
204            Series::from_chunks_and_dtype_unchecked(
205                leaf_array.name().clone(),
206                vec![prev_array],
207                &prev_dtype,
208            )
209        })
210    }
211
212    pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult<Series> {
213        polars_ensure!(
214            !dimensions.is_empty(),
215            InvalidOperation: "at least one dimension must be specified"
216        );
217
218        let s = self;
219        let s = if let DataType::List(_) = s.dtype() {
220            Cow::Owned(s.explode(true)?)
221        } else {
222            Cow::Borrowed(s)
223        };
224
225        let s_ref = s.as_ref();
226
227        // let dimensions = dimensions.to_vec();
228
229        match dimensions.len() {
230            1 => {
231                polars_ensure!(
232                    dimensions[0].get().is_none_or( |dim| dim as usize == s_ref.len()),
233                    InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,
234                );
235                Ok(s_ref.clone())
236            },
237            2 => {
238                let rows = dimensions[0];
239                let cols = dimensions[1];
240
241                if s_ref.is_empty() {
242                    if rows.get_or_infer(0) == 0 && cols.get_or_infer(0) <= 1 {
243                        let s = reshape_fast_path(s.name().clone(), s_ref);
244                        return Ok(s);
245                    } else {
246                        polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {}", format_tuple!(dimensions))
247                    }
248                }
249
250                use ReshapeDimension as RD;
251                // Infer dimension.
252
253                let (rows, cols) = match (rows, cols) {
254                    (RD::Infer, RD::Specified(cols)) if cols.get() >= 1 => {
255                        (s_ref.len() as u64 / cols.get(), cols.get())
256                    },
257                    (RD::Specified(rows), RD::Infer) if rows.get() >= 1 => {
258                        (rows.get(), s_ref.len() as u64 / rows.get())
259                    },
260                    (RD::Infer, RD::Infer) => (s_ref.len() as u64, 1u64),
261                    (RD::Specified(rows), RD::Specified(cols)) => (rows.get(), cols.get()),
262                    _ => polars_bail!(InvalidOperation: "reshape of non-zero list into zero list"),
263                };
264
265                // Fast path, we can create a unit list so we only allocate offsets.
266                if rows as usize == s_ref.len() && cols == 1 {
267                    let s = reshape_fast_path(s.name().clone(), s_ref);
268                    return Ok(s);
269                }
270
271                polars_ensure!(
272                    (rows*cols) as usize == s_ref.len() && rows >= 1 && cols >= 1,
273                    InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions,
274                );
275
276                let mut builder =
277                    get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone());
278
279                let mut offset = 0u64;
280                for _ in 0..rows {
281                    let row = s_ref.slice(offset as i64, cols as usize);
282                    builder.append_series(&row).unwrap();
283                    offset += cols;
284                }
285                Ok(builder.finish().into_series())
286            },
287            _ => {
288                polars_bail!(InvalidOperation: "more than two dimensions not supported in reshaping to List.\n\nConsider reshaping to Array type.");
289            },
290        }
291    }
292}
293
#[cfg(test)]
mod test {
    use super::*;
    use crate::prelude::*;

    /// `implode` must give the same result as appending the Series as the only row of a
    /// list builder.
    #[test]
    fn test_to_list() -> PolarsResult<()> {
        let s = Series::new("a".into(), &[1, 2, 3]);

        let expected = {
            let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone());
            builder.append_series(&s).unwrap();
            builder.finish().into_series()
        };

        let imploded = s.implode()?.into_series();
        assert!(expected.equals(&imploded));

        Ok(())
    }

    /// Two-dimensional `reshape_list` shapes (a negative entry is inferred) must give the
    /// expected number of rows, a `List` dtype, and keep all four values.
    #[test]
    fn test_reshape() -> PolarsResult<()> {
        let s = Series::new("a".into(), &[1, 2, 3, 4]);

        // (requested shape, expected number of rows)
        let cases = [
            ([-1, 1], 4),
            ([4, 1], 4),
            ([2, 2], 2),
            ([-1, 2], 2),
            ([2, -1], 2),
        ];

        for (shape, expected_rows) in cases {
            let dims: Vec<_> = shape
                .iter()
                .map(|&extent| ReshapeDimension::new(extent))
                .collect();

            let reshaped = s.reshape_list(&dims)?;
            assert_eq!(reshaped.len(), expected_rows);
            assert!(matches!(reshaped.dtype(), DataType::List(_)));
            assert_eq!(reshaped.explode(false)?.len(), 4);
        }

        Ok(())
    }
}
336}