Skip to main content

polars_core/chunked_array/ndarray.rs

1use ndarray::prelude::*;
2use polars_utils::sync::SyncPtr;
3use rayon::prelude::*;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use crate::prelude::*;
8use crate::runtime::RAYON;
9
/// Memory layout requested for the 2-D array produced by [`DataFrame::to_ndarray`].
///
/// The logical shape of the result is the same for both variants; only the order in
/// which elements are laid out in the backing buffer differs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IndexOrder {
    /// Row-major ("C") order: the elements of each row are contiguous in memory, which
    /// yields a standard-layout array.
    C,
    /// Column-major ("Fortran") order: the elements of each column are contiguous in
    /// memory. This is the default.
    #[default]
    Fortran,
}
17
18impl<T> ChunkedArray<T>
19where
20    T: PolarsNumericType,
21{
22    /// If data is aligned in a single chunk and has no Null values a zero copy view is returned
23    /// as an [ndarray]
24    pub fn to_ndarray(&self) -> PolarsResult<ArrayView1<'_, T::Native>> {
25        let slice = self.cont_slice()?;
26        Ok(aview1(slice))
27    }
28}
29
30impl ListChunked {
31    /// If all nested [`Series`] have the same length, a 2 dimensional [`ndarray::Array`] is returned.
32    pub fn to_ndarray<N>(&self) -> PolarsResult<Array2<N::Native>>
33    where
34        N: PolarsNumericType,
35    {
36        polars_ensure!(
37            self.null_count() == 0,
38            ComputeError: "creation of ndarray with null values is not supported"
39        );
40
41        // first iteration determine the size
42        let mut iter = self.no_null_series_iter();
43        let series = iter
44            .next()
45            .ok_or_else(|| polars_err!(NoData: "unable to create ndarray of empty ListChunked"))?;
46
47        let width = series.len();
48        let mut row_idx = 0;
49        let mut ndarray = ndarray::Array::uninit((self.len(), width));
50
51        let series = series.cast(&N::get_static_dtype())?;
52        let ca = series.unpack::<N>()?;
53        let a = ca.to_ndarray()?;
54        let mut row = ndarray.slice_mut(s![row_idx, ..]);
55        a.assign_to(&mut row);
56        row_idx += 1;
57
58        for series in iter {
59            polars_ensure!(
60                series.len() == width,
61                ShapeMismatch: "unable to create a 2-D array, series have different lengths"
62            );
63            let series = series.cast(&N::get_static_dtype())?;
64            let ca = series.unpack::<N>()?;
65            let a = ca.to_ndarray()?;
66            let mut row = ndarray.slice_mut(s![row_idx, ..]);
67            a.assign_to(&mut row);
68            row_idx += 1;
69        }
70
71        debug_assert_eq!(row_idx, self.len());
72        // SAFETY:
73        // We have assigned to every row and element of the array
74        unsafe { Ok(ndarray.assume_init()) }
75    }
76}
77
78impl DataFrame {
79    /// Create a 2D [`ndarray::Array`] from this [`DataFrame`]. This requires all columns in the
80    /// [`DataFrame`] to be non-null and numeric. They will be cast to the same data type
81    /// (if they aren't already).
82    ///
83    /// For floating point data we implicitly convert `None` to `NaN` without failure.
84    ///
85    /// ```rust
86    /// use polars_core::prelude::*;
87    /// let a = UInt32Chunked::new("a".into(), &[1, 2, 3]).into_column();
88    /// let b = Float64Chunked::new("b".into(), &[10., 8., 6.]).into_column();
89    ///
90    /// let df = DataFrame::new_infer_height(vec![a, b]).unwrap();
91    /// let ndarray = df.to_ndarray::<Float64Type>(IndexOrder::Fortran).unwrap();
92    /// println!("{:?}", ndarray);
93    /// ```
94    /// Outputs:
95    /// ```text
96    /// [[1.0, 10.0],
97    ///  [2.0, 8.0],
98    ///  [3.0, 6.0]], shape=[3, 2], strides=[1, 3], layout=Ff (0xa), const ndim=2
99    /// ```
100    pub fn to_ndarray<N>(&self, ordering: IndexOrder) -> PolarsResult<Array2<N::Native>>
101    where
102        N: PolarsNumericType,
103    {
104        let shape = self.shape();
105        let height = self.height();
106        let columns = self.columns();
107        let num_cols = columns.len();
108
109        let mut membuf: Vec<N::Native> = Vec::with_capacity(shape.0.checked_mul(shape.1).unwrap());
110        // SAFETY: parallel work units below write to disjoint regions of `membuf`.
111        let ptr = unsafe { SyncPtr::new(membuf.as_mut_ptr()) };
112
113        // Cast a column to N's dtype, replace float nulls with NaN, error on remaining
114        // nulls. Shared by both writer paths.
115        let cast_to_target = |s: &Column| -> PolarsResult<Series> {
116            let s = s.as_materialized_series().cast(&N::get_static_dtype())?;
117            let s = match s.dtype() {
118                DataType::Float32 => s.f32().unwrap().none_to_nan().into_series(),
119                DataType::Float64 => s.f64().unwrap().none_to_nan().into_series(),
120                _ => s,
121            };
122            polars_ensure!(
123                s.null_count() == 0,
124                ComputeError: "creation of ndarray with null values is not supported"
125            );
126            Ok(s)
127        };
128
129        match (ordering, num_cols) {
130            // F-order, or C-order with 0/1 columns (same memory layout). Per-column
131            // parallel writer; each column owns a disjoint contiguous stripe of the
132            // output buffer.
133            (IndexOrder::Fortran, _) | (IndexOrder::C, 0 | 1) => {
134                RAYON.install(|| {
135                    columns.par_iter().enumerate().try_for_each(
136                        |(col_idx, s)| -> PolarsResult<()> {
137                            let s = cast_to_target(s)?;
138                            let ca = s.unpack::<N>()?;
139
140                            let mut chunk_offset = 0;
141                            for arr in ca.downcast_iter() {
142                                let vals = arr.values();
143
144                                // SAFETY:
145                                // We get parallel access to the vector by offsetting index access
146                                // accordingly. We only operate on n contiguous elements, offset by
147                                // n * the column index.
148                                unsafe {
149                                    let dst = ptr.get().add(col_idx * height + chunk_offset);
150                                    std::ptr::copy_nonoverlapping(vals.as_ptr(), dst, vals.len());
151                                }
152                                chunk_offset += vals.len();
153                            }
154
155                            Ok(())
156                        },
157                    )
158                })?;
159            },
160            // C-order with > 1 column. Cache-blocked transpose writer: row-block
161            // parallel work units own disjoint output rows; each column's row-block
162            // is gathered into a register-blocked sub-tile then bulk-written
163            // contiguously. Avoids the per-column strided write that would touch
164            // one cache line per element and false-share between threads.
165            (IndexOrder::C, _) => {
166                // Sequential below ~1M cells (~4 MB f32) skips rayon dispatch overhead.
167                const PARALLEL_THRESHOLD: usize = 1_000_000;
168                let parallel = num_cols.saturating_mul(height) >= PARALLEL_THRESHOLD;
169
170                // The cast result must outlive Phase 2; the per-chunk slices below
171                // borrow into it. No-op when source dtype already matches and the
172                // column has no nulls. When a cast is needed it costs an extra
173                // memory pass over the source; fusing it into the sub-tile gather
174                // would halve the cast-required traffic but needs per-source-dtype
175                // specialisation of the gather.
176                let cast_columns: Vec<Series> = if parallel {
177                    RAYON.install(|| {
178                        columns
179                            .par_iter()
180                            .map(cast_to_target)
181                            .collect::<PolarsResult<_>>()
182                    })?
183                } else {
184                    columns
185                        .iter()
186                        .map(cast_to_target)
187                        .collect::<PolarsResult<_>>()?
188                };
189
190                const COL_BLOCK: usize = 64;
191                const TARGET_BLOCK_CELLS: usize = 32_768;
192                const MIN_ROW_BLOCK: usize = 64;
193                // row_block scales inversely with num_cols so each work unit carries
194                // about TARGET_BLOCK_CELLS cells, clamped at MIN_ROW_BLOCK for narrow
195                // frames.
196                let row_block = (TARGET_BLOCK_CELLS / num_cols.max(1)).max(MIN_ROW_BLOCK);
197                let num_blocks = height.div_ceil(row_block);
198
199                let column_chunks: Vec<Vec<&[N::Native]>> = cast_columns
200                    .iter()
201                    .map(|s| {
202                        s.unpack::<N>()
203                            .unwrap()
204                            .downcast_iter()
205                            .map(|arr| arr.values().as_slice())
206                            .collect()
207                    })
208                    .collect();
209
210                // Cursor into one column's chunk list. Advanced only at chunk boundaries.
211                #[derive(Clone, Copy, Default)]
212                struct Cursor {
213                    idx: usize, // chunk index
214                    off: usize, // first row of the chunk in the column's row space
215                    end: usize, // first row past the chunk
216                }
217
218                // Sub-tile dimension for the register-blocked transpose. SUB=32 keeps
219                // the SUB*SUB stack scratch under L1 (4 KB read + 4 KB write working
220                // set per sub-tile) and makes each row write 128 B for f32, which
221                // lowers to ~8 SIMD stores via copy_nonoverlapping.
222                const SUB: usize = 32;
223
224                let writer = |block: usize| {
225                    let row_start = block * row_block;
226                    let row_end = (row_start + row_block).min(height);
227                    let block_rows = row_end - row_start;
228
229                    for col_start in (0..num_cols).step_by(COL_BLOCK) {
230                        let block_cols = (num_cols - col_start).min(COL_BLOCK);
231
232                        // Position each cursor at row_start. The length-accumulator
233                        // walk skips zero-length chunks naturally.
234                        let mut cursors: [Cursor; COL_BLOCK] = std::array::from_fn(|ci_offset| {
235                            if ci_offset >= block_cols {
236                                return Cursor::default();
237                            }
238                            let chunks = &column_chunks[col_start + ci_offset];
239                            let mut acc = 0usize;
240                            let mut idx = 0;
241                            for (i, c) in chunks.iter().enumerate() {
242                                if acc + c.len() > row_start {
243                                    idx = i;
244                                    break;
245                                }
246                                acc += c.len();
247                            }
248                            Cursor {
249                                idx,
250                                off: acc,
251                                end: acc + chunks[idx].len(),
252                            }
253                        });
254
255                        for sr in (0..block_rows).step_by(SUB) {
256                            let tile_rows = (block_rows - sr).min(SUB);
257                            let abs_row_start = row_start + sr;
258
259                            for sc in (0..block_cols).step_by(SUB) {
260                                let tile_cols = (block_cols - sc).min(SUB);
261
262                                // Resolve each column's slice and starting offset
263                                // for this sub-tile. all_simple stays true when
264                                // every column's current chunk fully covers
265                                // tile_rows; the inner gather is then a tight loop
266                                // with no per-cell cursor work. It falls to false
267                                // when a chunk boundary lands inside this sub-tile.
268                                let mut col_slices: [&[N::Native]; SUB] = [&[]; SUB];
269                                let mut col_offs = [0usize; SUB];
270                                let mut all_simple = true;
271                                for ci in 0..tile_cols {
272                                    let chunks = &column_chunks[col_start + sc + ci];
273                                    let cur = &mut cursors[sc + ci];
274                                    // `while` (not `if`) skips zero-length chunks.
275                                    while abs_row_start >= cur.end {
276                                        cur.idx += 1;
277                                        cur.off = cur.end;
278                                        cur.end = cur.off + chunks[cur.idx].len();
279                                    }
280                                    if abs_row_start + tile_rows <= cur.end {
281                                        col_slices[ci] = chunks[cur.idx];
282                                        col_offs[ci] = abs_row_start - cur.off;
283                                    } else {
284                                        all_simple = false;
285                                        break;
286                                    }
287                                }
288
289                                let mut buf = [N::Native::default(); SUB * SUB];
290
291                                if all_simple {
292                                    for ri in 0..tile_rows {
293                                        let buf_row = ri * SUB;
294                                        for ci in 0..tile_cols {
295                                            // SAFETY: col_offs[ci] + ri < col_slices[ci].len()
296                                            // by the resolution check above.
297                                            unsafe {
298                                                buf[buf_row + ci] = *col_slices[ci]
299                                                    .get_unchecked(col_offs[ci] + ri);
300                                            }
301                                        }
302                                    }
303                                } else {
304                                    for ri in 0..tile_rows {
305                                        let abs_row = abs_row_start + ri;
306                                        let buf_row = ri * SUB;
307                                        for ci in 0..tile_cols {
308                                            let chunks = &column_chunks[col_start + sc + ci];
309                                            let cur = &mut cursors[sc + ci];
310                                            while abs_row >= cur.end {
311                                                cur.idx += 1;
312                                                cur.off = cur.end;
313                                                cur.end = cur.off + chunks[cur.idx].len();
314                                            }
315                                            // SAFETY: cursor advanced only while
316                                            // row < total length so cur.idx is in
317                                            // bounds; offset < chunks[cur.idx].len()
318                                            // by the chunk-end invariant.
319                                            unsafe {
320                                                buf[buf_row + ci] = *chunks[cur.idx]
321                                                    .get_unchecked(abs_row - cur.off);
322                                            }
323                                        }
324                                    }
325                                }
326
327                                for ri in 0..tile_rows {
328                                    let abs_row = abs_row_start + ri;
329                                    let dst_off = abs_row * num_cols + col_start + sc;
330                                    let src = &buf[ri * SUB..ri * SUB + tile_cols];
331                                    // SAFETY: dst_off + tile_cols <= height * num_cols;
332                                    // disjoint blocks own disjoint output rows.
333                                    unsafe {
334                                        std::ptr::copy_nonoverlapping(
335                                            src.as_ptr(),
336                                            ptr.get().add(dst_off),
337                                            tile_cols,
338                                        );
339                                    }
340                                }
341                            }
342                        }
343                    }
344                };
345                if parallel {
346                    RAYON.install(|| (0..num_blocks).into_par_iter().for_each(writer));
347                } else {
348                    (0..num_blocks).for_each(writer);
349                }
350            },
351        }
352
353        // SAFETY:
354        // we have written all data, so we can now safely set length
355        unsafe {
356            membuf.set_len(shape.0 * shape.1);
357        }
358        // Depending on the desired order, we can either return the array buffer as-is or reverse
359        // the axes.
360        match ordering {
361            IndexOrder::C => Ok(Array2::from_shape_vec((shape.0, shape.1), membuf).unwrap()),
362            IndexOrder::Fortran => {
363                let ndarr = Array2::from_shape_vec((shape.1, shape.0), membuf).unwrap();
364                Ok(ndarr.reversed_axes())
365            },
366        }
367    }
368}
369
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_ndarray_from_ca() -> PolarsResult<()> {
        let ca = Float64Chunked::new(PlSmallStr::EMPTY, &[1.0, 2.0, 3.0]);
        let ndarr = ca.to_ndarray()?;
        assert_eq!(ndarr, ArrayView1::from(&[1.0, 2.0, 3.0]));

        let mut builder = ListPrimitiveChunkedBuilder::<Float64Type>::new(
            PlSmallStr::EMPTY,
            10,
            10,
            DataType::Float64,
        );
        builder.append_opt_slice(Some(&[1.0, 2.0, 3.0]));
        builder.append_opt_slice(Some(&[2.0, 4.0, 5.0]));
        builder.append_opt_slice(Some(&[6.0, 7.0, 8.0]));
        let list = builder.finish();

        let ndarr = list.to_ndarray::<Float64Type>()?;
        let expected = array![[1.0, 2.0, 3.0], [2.0, 4.0, 5.0], [6.0, 7.0, 8.0]];
        assert_eq!(ndarr, expected);

        // test list array that is not square
        let mut builder = ListPrimitiveChunkedBuilder::<Float64Type>::new(
            PlSmallStr::EMPTY,
            10,
            10,
            DataType::Float64,
        );
        builder.append_opt_slice(Some(&[1.0, 2.0, 3.0]));
        builder.append_opt_slice(Some(&[2.0]));
        builder.append_opt_slice(Some(&[6.0, 7.0, 8.0]));
        let list = builder.finish();
        assert!(list.to_ndarray::<Float64Type>().is_err());
        Ok(())
    }

    #[test]
    fn test_ndarray_from_df_order_fortran() -> PolarsResult<()> {
        let df = df!["a"=> [1.0, 2.0, 3.0],
            "b" => [2.0, 3.0, 4.0]
        ]?;

        let ndarr = df.to_ndarray::<Float64Type>(IndexOrder::Fortran)?;
        let expected = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        assert!(!ndarr.is_standard_layout());
        assert_eq!(ndarr, expected);

        Ok(())
    }

    #[test]
    fn test_ndarray_from_df_order_c() -> PolarsResult<()> {
        let df = df!["a"=> [1.0, 2.0, 3.0],
            "b" => [2.0, 3.0, 4.0]
        ]?;

        let ndarr = df.to_ndarray::<Float64Type>(IndexOrder::C)?;
        let expected = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        assert!(ndarr.is_standard_layout());
        assert_eq!(ndarr, expected);

        Ok(())
    }

    /// Build a frame of `n_cols` Float64 columns over the given row range, where cell
    /// `(r, c)` holds `r * 1000 + c` so every position is distinguishable.
    fn numbered_frame(rows: std::ops::Range<usize>, n_cols: usize) -> PolarsResult<DataFrame> {
        let columns = (0..n_cols)
            .map(|c| {
                let vals: Vec<f64> = rows.clone().map(|r| (r * 1000 + c) as f64).collect();
                Float64Chunked::new(format!("c{c}").into(), vals.as_slice()).into_column()
            })
            .collect::<Vec<_>>();
        DataFrame::new_infer_height(columns)
    }

    #[test]
    fn test_ndarray_from_df_multi_chunk_wide() -> PolarsResult<()> {
        // 70 columns span more than one 64-column block and several 32-wide sub-tiles
        // of the C-order writer. Stacking a 5-row frame on a 70-row frame is expected to
        // leave two chunks per column (vstack appends chunks rather than rechunking),
        // which puts a chunk boundary inside the first 32-row sub-tile and exercises the
        // cursor-advancing gather path.
        let n_cols = 70;
        let n_rows = 75;
        let df = numbered_frame(0..5, n_cols)?.vstack(&numbered_frame(5..n_rows, n_cols)?)?;

        for ordering in [IndexOrder::C, IndexOrder::Fortran] {
            let ndarr = df.to_ndarray::<Float64Type>(ordering)?;
            assert_eq!(ndarr.dim(), (n_rows, n_cols));
            for ((r, c), &v) in ndarr.indexed_iter() {
                assert_eq!(
                    v,
                    (r * 1000 + c) as f64,
                    "mismatch at ({r}, {c}) for {ordering:?}"
                );
            }
        }
        Ok(())
    }

    #[test]
    fn test_ndarray_from_df_nulls() -> PolarsResult<()> {
        // Float targets turn nulls into NaN instead of failing.
        let a = Float64Chunked::new("a".into(), &[Some(1.0), None]).into_column();
        let b = Float64Chunked::new("b".into(), &[Some(3.0), Some(4.0)]).into_column();
        let df = DataFrame::new_infer_height(vec![a, b])?;
        for ordering in [IndexOrder::C, IndexOrder::Fortran] {
            let ndarr = df.to_ndarray::<Float64Type>(ordering)?;
            assert_eq!(ndarr[[0, 0]], 1.0);
            assert!(ndarr[[1, 0]].is_nan());
            assert_eq!(ndarr[[0, 1]], 3.0);
            assert_eq!(ndarr[[1, 1]], 4.0);
        }

        // Non-float targets reject nulls.
        let c = Int32Chunked::new("c".into(), &[Some(1), None]).into_column();
        let d = Int32Chunked::new("d".into(), &[Some(2), Some(3)]).into_column();
        let df = DataFrame::new_infer_height(vec![c, d])?;
        for ordering in [IndexOrder::C, IndexOrder::Fortran] {
            assert!(df.to_ndarray::<Int32Type>(ordering).is_err());
        }
        Ok(())
    }
}