polars_core/chunked_array/ops/
gather.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::bitmap::Bitmap;
3use arrow::bitmap::bitmask::BitMask;
4use polars_compute::gather::take_unchecked;
5use polars_error::polars_ensure;
6use polars_utils::index::check_bounds;
7
8use crate::prelude::*;
9use crate::series::IsSorted;
10
11pub fn check_bounds_nulls(idx: &PrimitiveArray<IdxSize>, len: IdxSize) -> PolarsResult<()> {
12    let mask = BitMask::from_bitmap(idx.validity().unwrap());
13
14    // We iterate in chunks to make the inner loop branch-free.
15    for (block_idx, block) in idx.values().chunks(32).enumerate() {
16        let mut in_bounds = 0;
17        for (i, x) in block.iter().enumerate() {
18            in_bounds |= ((*x < len) as u32) << i;
19        }
20        let m = mask.get_u32(32 * block_idx);
21        polars_ensure!(m == m & in_bounds, ComputeError: "gather indices are out of bounds");
22    }
23    Ok(())
24}
25
26pub fn check_bounds_ca(indices: &IdxCa, len: IdxSize) -> PolarsResult<()> {
27    let all_valid = indices.downcast_iter().all(|a| {
28        if a.null_count() == 0 {
29            check_bounds(a.values(), len).is_ok()
30        } else {
31            check_bounds_nulls(a, len).is_ok()
32        }
33    });
34    polars_ensure!(all_valid, OutOfBounds: "gather indices are out of bounds");
35    Ok(())
36}
37
38impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTake<I> for ChunkedArray<T>
39where
40    ChunkedArray<T>: ChunkTakeUnchecked<I>,
41{
42    /// Gather values from ChunkedArray by index.
43    fn take(&self, indices: &I) -> PolarsResult<Self> {
44        check_bounds(indices.as_ref(), self.len() as IdxSize)?;
45
46        // SAFETY: we just checked the indices are valid.
47        Ok(unsafe { self.take_unchecked(indices) })
48    }
49}
50
51impl<T: PolarsDataType> ChunkTake<IdxCa> for ChunkedArray<T>
52where
53    ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
54{
55    /// Gather values from ChunkedArray by index.
56    fn take(&self, indices: &IdxCa) -> PolarsResult<Self> {
57        check_bounds_ca(indices, self.len() as IdxSize)?;
58
59        // SAFETY: we just checked the indices are valid.
60        Ok(unsafe { self.take_unchecked(indices) })
61    }
62}
63
64/// Computes cumulative lengths for efficient branchless binary search
65/// lookup. The first element is always 0, and the last length of arrs
66/// is always ignored (as we already checked that all indices are
67/// in-bounds we don't need to check against the last length).
68fn cumulative_lengths<A: StaticArray>(arrs: &[&A]) -> Vec<IdxSize> {
69    let mut ret = Vec::with_capacity(arrs.len());
70    let mut cumsum: IdxSize = 0;
71    for arr in arrs {
72        ret.push(cumsum);
73        cumsum = cumsum.checked_add(arr.len().try_into().unwrap()).unwrap();
74    }
75    ret
76}
77
78#[rustfmt::skip]
79#[inline]
80fn resolve_chunked_idx(idx: IdxSize, cumlens: &[IdxSize]) -> (usize, usize) {
81    let chunk_idx = cumlens.partition_point(|cl| idx >= *cl) - 1;
82    (chunk_idx, (idx - cumlens[chunk_idx]) as usize)
83}
84
85#[inline]
86unsafe fn target_value_unchecked<'a, A: StaticArray>(
87    targets: &[&'a A],
88    cumlens: &[IdxSize],
89    idx: IdxSize,
90) -> A::ValueT<'a> {
91    let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
92    let arr = targets.get_unchecked(chunk_idx);
93    arr.value_unchecked(arr_idx)
94}
95
96#[inline]
97unsafe fn target_get_unchecked<'a, A: StaticArray>(
98    targets: &[&'a A],
99    cumlens: &[IdxSize],
100    idx: IdxSize,
101) -> Option<A::ValueT<'a>> {
102    let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
103    let arr = targets.get_unchecked(chunk_idx);
104    arr.get_unchecked(arr_idx)
105}
106
107unsafe fn gather_idx_array_unchecked<A: StaticArray>(
108    dtype: ArrowDataType,
109    targets: &[&A],
110    has_nulls: bool,
111    indices: &[IdxSize],
112) -> A {
113    let it = indices.iter().copied();
114    if targets.len() == 1 {
115        let target = targets.first().unwrap();
116        if has_nulls {
117            it.map(|i| target.get_unchecked(i as usize))
118                .collect_arr_trusted_with_dtype(dtype)
119        } else if let Some(sl) = target.as_slice() {
120            // Avoid the Arc overhead from value_unchecked.
121            it.map(|i| sl.get_unchecked(i as usize).clone())
122                .collect_arr_trusted_with_dtype(dtype)
123        } else {
124            it.map(|i| target.value_unchecked(i as usize))
125                .collect_arr_trusted_with_dtype(dtype)
126        }
127    } else {
128        let cumlens = cumulative_lengths(targets);
129        if has_nulls {
130            it.map(|i| target_get_unchecked(targets, &cumlens, i))
131                .collect_arr_trusted_with_dtype(dtype)
132        } else {
133            it.map(|i| target_value_unchecked(targets, &cumlens, i))
134                .collect_arr_trusted_with_dtype(dtype)
135        }
136    }
137}
138
139impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ChunkedArray<T>
140where
141    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
142{
143    /// Gather values from ChunkedArray by index.
144    unsafe fn take_unchecked(&self, indices: &I) -> Self {
145        let ca = self;
146        let targets: Vec<_> = ca.downcast_iter().collect();
147        let arr = gather_idx_array_unchecked(
148            ca.dtype().to_arrow(CompatLevel::newest()),
149            &targets,
150            ca.null_count() > 0,
151            indices.as_ref(),
152        );
153        ChunkedArray::from_chunk_iter_like(ca, [arr])
154    }
155}
156
157pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> IsSorted {
158    use crate::series::IsSorted::*;
159    match (sorted_arr, sorted_idx) {
160        (_, Not) => Not,
161        (Not, _) => Not,
162        (Ascending, Ascending) => Ascending,
163        (Ascending, Descending) => Descending,
164        (Descending, Ascending) => Descending,
165        (Descending, Descending) => Ascending,
166    }
167}
168
169impl<T: PolarsDataType> ChunkTakeUnchecked<IdxCa> for ChunkedArray<T>
170where
171    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
172{
173    /// Gather values from ChunkedArray by index.
174    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
175        let ca = self;
176        let targets_have_nulls = ca.null_count() > 0;
177        let targets: Vec<_> = ca.downcast_iter().collect();
178
179        let chunks = indices.downcast_iter().map(|idx_arr| {
180            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
181            if idx_arr.null_count() == 0 {
182                gather_idx_array_unchecked(dtype, &targets, targets_have_nulls, idx_arr.values())
183            } else if targets.len() == 1 {
184                let target = targets.first().unwrap();
185                if targets_have_nulls {
186                    idx_arr
187                        .iter()
188                        .map(|i| target.get_unchecked(*i? as usize))
189                        .collect_arr_trusted_with_dtype(dtype)
190                } else {
191                    idx_arr
192                        .iter()
193                        .map(|i| Some(target.value_unchecked(*i? as usize)))
194                        .collect_arr_trusted_with_dtype(dtype)
195                }
196            } else {
197                let cumlens = cumulative_lengths(&targets);
198                if targets_have_nulls {
199                    idx_arr
200                        .iter()
201                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
202                        .collect_arr_trusted_with_dtype(dtype)
203                } else {
204                    idx_arr
205                        .iter()
206                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
207                        .collect_arr_trusted_with_dtype(dtype)
208                }
209            }
210        });
211
212        let mut out = ChunkedArray::from_chunk_iter_like(ca, chunks);
213        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
214
215        out.set_sorted_flag(sorted_flag);
216        out
217    }
218}
219
220impl ChunkTakeUnchecked<IdxCa> for BinaryChunked {
221    /// Gather values from ChunkedArray by index.
222    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
223        let ca = self;
224        let targets_have_nulls = ca.null_count() > 0;
225        let targets: Vec<_> = ca.downcast_iter().collect();
226
227        let chunks = indices.downcast_iter().map(|idx_arr| {
228            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
229            if targets.len() == 1 {
230                let target = targets.first().unwrap();
231                take_unchecked(&**target, idx_arr)
232            } else {
233                let cumlens = cumulative_lengths(&targets);
234                if targets_have_nulls {
235                    let arr: BinaryViewArray = idx_arr
236                        .iter()
237                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
238                        .collect_arr_trusted_with_dtype(dtype);
239                    arr.to_boxed()
240                } else {
241                    let arr: BinaryViewArray = idx_arr
242                        .iter()
243                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
244                        .collect_arr_trusted_with_dtype(dtype);
245                    arr.to_boxed()
246                }
247            }
248        });
249
250        let mut out = ChunkedArray::from_chunks(ca.name().clone(), chunks.collect());
251        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
252        out.set_sorted_flag(sorted_flag);
253        out
254    }
255}
256
257impl ChunkTakeUnchecked<IdxCa> for StringChunked {
258    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
259        let ca = self;
260        let targets_have_nulls = ca.null_count() > 0;
261        let targets: Vec<_> = ca.downcast_iter().collect();
262
263        let chunks = indices.downcast_iter().map(|idx_arr| {
264            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
265            if targets.len() == 1 {
266                let target = targets.first().unwrap();
267                take_unchecked(&**target, idx_arr)
268            } else {
269                let cumlens = cumulative_lengths(&targets);
270                if targets_have_nulls {
271                    let arr: Utf8ViewArray = idx_arr
272                        .iter()
273                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
274                        .collect_arr_trusted_with_dtype(dtype);
275                    arr.to_boxed()
276                } else {
277                    let arr: Utf8ViewArray = idx_arr
278                        .iter()
279                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
280                        .collect_arr_trusted_with_dtype(dtype);
281                    arr.to_boxed()
282                }
283            }
284        });
285
286        let mut out = ChunkedArray::from_chunks(ca.name().clone(), chunks.collect());
287        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
288        out.set_sorted_flag(sorted_flag);
289        out
290    }
291}
292
293impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for BinaryChunked {
294    /// Gather values from ChunkedArray by index.
295    unsafe fn take_unchecked(&self, indices: &I) -> Self {
296        let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
297        self.take_unchecked(&indices)
298    }
299}
300
301impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StringChunked {
302    /// Gather values from ChunkedArray by index.
303    unsafe fn take_unchecked(&self, indices: &I) -> Self {
304        let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
305        self.take_unchecked(&indices)
306    }
307}
308
#[cfg(feature = "dtype-struct")]
impl ChunkTakeUnchecked<IdxCa> for StructChunked {
    /// Gather struct rows by a chunked index array without bounds checks.
    ///
    /// Both `self` and `indices` are rechunked first, and the rechunked arrays
    /// are paired up and passed to the array-level `take_unchecked` kernel.
    ///
    /// # Safety
    ///
    /// Every non-null index must be smaller than the length of this array.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let values = self.rechunk();
        let positions = indices.rechunk();

        let mut chunks = Vec::new();
        for (arr, idx) in values.downcast_iter().zip(positions.downcast_iter()) {
            chunks.push(take_unchecked(arr, idx));
        }
        self.copy_with_chunks(chunks)
    }
}
323
#[cfg(feature = "dtype-struct")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StructChunked {
    /// Gather struct rows by a slice of indices without bounds checks.
    ///
    /// The slice is wrapped into an unnamed `IdxCa` through
    /// `IdxCa::mmap_slice` and the work is delegated to the `IdxCa` gather.
    ///
    /// # Safety
    ///
    /// Every index must be smaller than the length of this array.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx_ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        <Self as ChunkTakeUnchecked<IdxCa>>::take_unchecked(self, &idx_ca)
    }
}
331
332impl IdxCa {
333    pub fn with_nullable_idx<T, F: FnOnce(&IdxCa) -> T>(idx: &[NullableIdxSize], f: F) -> T {
334        let validity: Bitmap = idx.iter().map(|idx| !idx.is_null_idx()).collect_trusted();
335        let idx = bytemuck::cast_slice::<_, IdxSize>(idx);
336        let arr = unsafe { arrow::ffi::mmap::slice(idx) };
337        let arr = arr.with_validity_typed(Some(validity));
338        let ca = IdxCa::with_chunk(PlSmallStr::EMPTY, arr);
339
340        f(&ca)
341    }
342}
343
#[cfg(feature = "dtype-array")]
impl ChunkTakeUnchecked<IdxCa> for ArrayChunked {
    /// Gather fixed-size-array rows by a chunked index array without bounds
    /// checks.
    ///
    /// Both `self` and `indices` are rechunked, and the array-level
    /// `take_unchecked` kernel produces the single output chunk.
    ///
    /// # Safety
    ///
    /// Every non-null index must be smaller than the length of this array.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let values = self.rechunk();
        let positions = indices.rechunk();
        let gathered = take_unchecked(values.downcast_as_array(), positions.downcast_as_array());
        self.copy_with_chunks(vec![gathered])
    }
}
354
#[cfg(feature = "dtype-array")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ArrayChunked {
    /// Gather fixed-size-array rows by a slice of indices without bounds
    /// checks.
    ///
    /// The slice is wrapped into an unnamed `IdxCa` through
    /// `IdxCa::mmap_slice` and the work is delegated to the `IdxCa` gather.
    ///
    /// # Safety
    ///
    /// Every index must be smaller than the length of this array.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx_ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        <Self as ChunkTakeUnchecked<IdxCa>>::take_unchecked(self, &idx_ca)
    }
}
362
363impl ChunkTakeUnchecked<IdxCa> for ListChunked {
364    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
365        let chunks = vec![take_unchecked(
366            self.rechunk().downcast_as_array(),
367            indices.rechunk().downcast_as_array(),
368        )];
369        self.copy_with_chunks(chunks)
370    }
371}
372
373impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ListChunked {
374    unsafe fn take_unchecked(&self, indices: &I) -> Self {
375        let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
376        self.take_unchecked(&idx)
377    }
378}