polars_ops/chunked_array/list/get.rs

1use polars_core::prelude::{Column, IdxCa, Int64Chunked, ListChunked};
2use polars_core::series::Series;
3use polars_error::{PolarsResult, polars_bail};
4use polars_utils::IdxSize;
5
6use super::ListNameSpaceImpl;
7use crate::series::convert_and_bound_idx_ca;
8
9pub fn lst_get(ca: &ListChunked, index: &Int64Chunked, null_on_oob: bool) -> PolarsResult<Column> {
10    match index.len() {
11        1 => {
12            let index = index.get(0);
13            if let Some(index) = index {
14                ca.lst_get(index, null_on_oob).map(Column::from)
15            } else {
16                Ok(Column::full_null(
17                    ca.name().clone(),
18                    ca.len(),
19                    ca.inner_dtype(),
20                ))
21            }
22        },
23        len if len == ca.len() => {
24            let tmp = ca.rechunk();
25            let arr = tmp.downcast_as_array();
26            let offsets = arr.offsets().as_slice();
27            let take_by = if ca.null_count() == 0 {
28                index
29                    .iter()
30                    .enumerate()
31                    .map(|(i, opt_idx)| match opt_idx {
32                        Some(idx) => {
33                            let (start, end) = unsafe {
34                                (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1))
35                            };
36                            let offset = if idx >= 0 { start + idx } else { end + idx };
37                            if offset >= end || offset < start || start == end {
38                                if null_on_oob {
39                                    Ok(None)
40                                } else {
41                                    polars_bail!(ComputeError: "get index is out of bounds");
42                                }
43                            } else {
44                                Ok(Some(offset as IdxSize))
45                            }
46                        },
47                        None => Ok(None),
48                    })
49                    .collect::<Result<IdxCa, _>>()?
50            } else {
51                index
52                    .iter()
53                    .zip(arr.validity().unwrap())
54                    .enumerate()
55                    .map(|(i, (opt_idx, valid))| match (valid, opt_idx) {
56                        (true, Some(idx)) => {
57                            let (start, end) = unsafe {
58                                (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1))
59                            };
60                            let offset = if idx >= 0 { start + idx } else { end + idx };
61                            if offset >= end || offset < start || start == end {
62                                if null_on_oob {
63                                    Ok(None)
64                                } else {
65                                    polars_bail!(ComputeError: "get index is out of bounds");
66                                }
67                            } else {
68                                Ok(Some(offset as IdxSize))
69                            }
70                        },
71                        _ => Ok(None),
72                    })
73                    .collect::<Result<IdxCa, _>>()?
74            };
75            let s = Series::try_from((ca.name().clone(), arr.values().clone())).unwrap();
76            unsafe {
77                s.take_unchecked(&take_by)
78                    .from_physical_unchecked(ca.inner_dtype())
79                    .map(Column::from)
80            }
81        },
82        _ if ca.len() == 1 => {
83            if let Some(list) = ca.get(0) {
84                let idx = convert_and_bound_idx_ca(index, list.len(), null_on_oob)?;
85                let s = Series::try_from((ca.name().clone(), vec![list])).unwrap();
86                unsafe {
87                    s.take_unchecked(&idx)
88                        .from_physical_unchecked(ca.inner_dtype())
89                        .map(Column::from)
90                }
91            } else {
92                Ok(Column::full_null(
93                    ca.name().clone(),
94                    ca.len(),
95                    ca.inner_dtype(),
96                ))
97            }
98        },
99        len => polars_bail!(
100            ComputeError:
101            "`list.get` expression got an index array of length {} while the list has {} elements",
102            len, ca.len()
103        ),
104    }
105}