polars_ops/chunked_array/list/
get.rs

1use polars_core::prelude::{Column, IdxCa, Int64Chunked, ListChunked};
2use polars_core::series::Series;
3use polars_error::{PolarsResult, polars_bail};
4use polars_utils::IdxSize;
5
6use super::ListNameSpaceImpl;
7
8pub fn lst_get(ca: &ListChunked, index: &Int64Chunked, null_on_oob: bool) -> PolarsResult<Column> {
9    match index.len() {
10        1 => {
11            let index = index.get(0);
12            if let Some(index) = index {
13                ca.lst_get(index, null_on_oob).map(Column::from)
14            } else {
15                Ok(Column::full_null(
16                    ca.name().clone(),
17                    ca.len(),
18                    ca.inner_dtype(),
19                ))
20            }
21        },
22        len if len == ca.len() => {
23            let tmp = ca.rechunk();
24            let arr = tmp.downcast_as_array();
25            let offsets = arr.offsets().as_slice();
26            let take_by = if ca.null_count() == 0 {
27                index
28                    .iter()
29                    .enumerate()
30                    .map(|(i, opt_idx)| match opt_idx {
31                        Some(idx) => {
32                            let (start, end) = unsafe {
33                                (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1))
34                            };
35                            let offset = if idx >= 0 { start + idx } else { end + idx };
36                            if offset >= end || offset < start || start == end {
37                                if null_on_oob {
38                                    Ok(None)
39                                } else {
40                                    polars_bail!(ComputeError: "get index is out of bounds");
41                                }
42                            } else {
43                                Ok(Some(offset as IdxSize))
44                            }
45                        },
46                        None => Ok(None),
47                    })
48                    .collect::<Result<IdxCa, _>>()?
49            } else {
50                index
51                    .iter()
52                    .zip(arr.validity().unwrap())
53                    .enumerate()
54                    .map(|(i, (opt_idx, valid))| match (valid, opt_idx) {
55                        (true, Some(idx)) => {
56                            let (start, end) = unsafe {
57                                (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1))
58                            };
59                            let offset = if idx >= 0 { start + idx } else { end + idx };
60                            if offset >= end || offset < start || start == end {
61                                if null_on_oob {
62                                    Ok(None)
63                                } else {
64                                    polars_bail!(ComputeError: "get index is out of bounds");
65                                }
66                            } else {
67                                Ok(Some(offset as IdxSize))
68                            }
69                        },
70                        _ => Ok(None),
71                    })
72                    .collect::<Result<IdxCa, _>>()?
73            };
74            let s = Series::try_from((ca.name().clone(), arr.values().clone())).unwrap();
75            unsafe { s.take_unchecked(&take_by) }
76                .cast(ca.inner_dtype())
77                .map(Column::from)
78        },
79        _ if ca.len() == 1 => {
80            if ca.null_count() > 0 {
81                return Ok(Column::full_null(
82                    ca.name().clone(),
83                    index.len(),
84                    ca.inner_dtype(),
85                ));
86            }
87            let tmp = ca.rechunk();
88            let arr = tmp.downcast_as_array();
89            let offsets = arr.offsets().as_slice();
90            let start = offsets[0];
91            let end = offsets[1];
92            let out_of_bounds = |offset| offset >= end || offset < start || start == end;
93            let take_by: IdxCa = index
94                .iter()
95                .map(|opt_idx| match opt_idx {
96                    Some(idx) => {
97                        let offset = if idx >= 0 { start + idx } else { end + idx };
98                        if out_of_bounds(offset) {
99                            if null_on_oob {
100                                Ok(None)
101                            } else {
102                                polars_bail!(ComputeError: "get index is out of bounds");
103                            }
104                        } else {
105                            let Ok(offset) = IdxSize::try_from(offset) else {
106                                polars_bail!(ComputeError: "get index is out of bounds");
107                            };
108                            Ok(Some(offset))
109                        }
110                    },
111                    None => Ok(None),
112                })
113                .collect::<Result<IdxCa, _>>()?;
114
115            let s = Series::try_from((ca.name().clone(), arr.values().clone())).unwrap();
116            unsafe { s.take_unchecked(&take_by) }
117                .cast(ca.inner_dtype())
118                .map(Column::from)
119        },
120        len => polars_bail!(
121            ComputeError:
122            "`list.get` expression got an index array of length {} while the list has {} elements",
123            len, ca.len()
124        ),
125    }
126}