polars_core/chunked_array/ops/gather.rs

#![allow(unsafe_op_in_unsafe_fn)]
use arrow::bitmap::Bitmap;
use arrow::bitmap::bitmask::BitMask;
use polars_compute::gather::take_unchecked;
use polars_error::polars_ensure;
use polars_utils::index::check_bounds;

use crate::prelude::*;
use crate::series::IsSorted;

const BINARY_SEARCH_LIMIT: usize = 8;

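/// Checks that every non-null index in `idx` is within `0..len`; null slots
/// are skipped via the validity mask.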
pub fn check_bounds_nulls(idx: &PrimitiveArray<IdxSize>, len: IdxSize) -> PolarsResult<()> {
    let mask = BitMask::from_bitmap(idx.validity().unwrap());

    // We iterate in chunks to make the inner loop branch-free.
    for (block_idx, block) in idx.values().chunks(32).enumerate() {
        let mut in_bounds = 0;
        for (i, x) in block.iter().enumerate() {
            in_bounds |= ((*x < len) as u32) << i;
        }
        let m = mask.get_u32(32 * block_idx);
        polars_ensure!(m == m & in_bounds, ComputeError: "gather indices are out of bounds");
    }
    Ok(())
}

pub fn check_bounds_ca(indices: &IdxCa, len: IdxSize) -> PolarsResult<()> {
    let all_valid = indices.downcast_iter().all(|a| {
        if a.null_count() == 0 {
            check_bounds(a.values(), len).is_ok()
        } else {
            check_bounds_nulls(a, len).is_ok()
        }
    });
    polars_ensure!(all_valid, OutOfBounds: "gather indices are out of bounds");
    Ok(())
}

impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTake<I> for ChunkedArray<T>
where
    ChunkedArray<T>: ChunkTakeUnchecked<I>,
{
    /// Gather values from ChunkedArray by index.
    fn take(&self, indices: &I) -> PolarsResult<Self> {
        check_bounds(indices.as_ref(), self.len() as IdxSize)?;

        // SAFETY: we just checked the indices are valid.
        Ok(unsafe { self.take_unchecked(indices) })
    }
}

impl<T: PolarsDataType> ChunkTake<IdxCa> for ChunkedArray<T>
where
    ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
{
    /// Gather values from ChunkedArray by index.
    fn take(&self, indices: &IdxCa) -> PolarsResult<Self> {
        check_bounds_ca(indices, self.len() as IdxSize)?;

        // SAFETY: we just checked the indices are valid.
        Ok(unsafe { self.take_unchecked(indices) })
    }
}

/// Computes cumulative lengths for efficient branchless binary search
/// lookup. The first element is always 0, and the last length of arrs
/// is always ignored (as we already checked that all indices are
/// in-bounds we don't need to check against the last length).
fn cumulative_lengths<A: StaticArray>(arrs: &[&A]) -> [IdxSize; BINARY_SEARCH_LIMIT] {
    assert!(arrs.len() <= BINARY_SEARCH_LIMIT);
    let mut ret = [IdxSize::MAX; BINARY_SEARCH_LIMIT];
    ret[0] = 0;
    for i in 1..arrs.len() {
        ret[i] = ret[i - 1] + arrs[i - 1].len() as IdxSize;
    }
    ret
}

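// For example, with chunk lengths [4, 3, 5] the cumulative lengths are
// [0, 4, 7, MAX, MAX, ...]; a flat index of 8 then resolves to chunk 2 with
// local offset 1 (8 - 7).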
#[rustfmt::skip]
#[inline]
fn resolve_chunked_idx(idx: IdxSize, cumlens: &[IdxSize; BINARY_SEARCH_LIMIT]) -> (usize, usize) {
    // Branchless bitwise binary search.
    let mut chunk_idx = 0;
    chunk_idx += if idx >= cumlens[chunk_idx + 0b100] { 0b0100 } else { 0 };
    chunk_idx += if idx >= cumlens[chunk_idx + 0b010] { 0b0010 } else { 0 };
    chunk_idx += if idx >= cumlens[chunk_idx + 0b001] { 0b0001 } else { 0 };
    (chunk_idx, (idx - cumlens[chunk_idx]) as usize)
}

#[inline]
unsafe fn target_value_unchecked<'a, A: StaticArray>(
    targets: &[&'a A],
    cumlens: &[IdxSize; BINARY_SEARCH_LIMIT],
    idx: IdxSize,
) -> A::ValueT<'a> {
    let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
    let arr = targets.get_unchecked(chunk_idx);
    arr.value_unchecked(arr_idx)
}

#[inline]
unsafe fn target_get_unchecked<'a, A: StaticArray>(
    targets: &[&'a A],
    cumlens: &[IdxSize; BINARY_SEARCH_LIMIT],
    idx: IdxSize,
) -> Option<A::ValueT<'a>> {
    let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
    let arr = targets.get_unchecked(chunk_idx);
    arr.get_unchecked(arr_idx)
}

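/// Gathers `indices` from `targets` into a single new array. With a single
/// target chunk the lookup is direct (using the values slice when available);
/// with multiple chunks it goes through the cumulative-length lookup above.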
unsafe fn gather_idx_array_unchecked<A: StaticArray>(
    dtype: ArrowDataType,
    targets: &[&A],
    has_nulls: bool,
    indices: &[IdxSize],
) -> A {
    let it = indices.iter().copied();
    if targets.len() == 1 {
        let target = targets.first().unwrap();
        if has_nulls {
            it.map(|i| target.get_unchecked(i as usize))
                .collect_arr_trusted_with_dtype(dtype)
        } else if let Some(sl) = target.as_slice() {
            // Avoid the Arc overhead from value_unchecked.
            it.map(|i| sl.get_unchecked(i as usize).clone())
                .collect_arr_trusted_with_dtype(dtype)
        } else {
            it.map(|i| target.value_unchecked(i as usize))
                .collect_arr_trusted_with_dtype(dtype)
        }
    } else {
        let cumlens = cumulative_lengths(targets);
        if has_nulls {
            it.map(|i| target_get_unchecked(targets, &cumlens, i))
                .collect_arr_trusted_with_dtype(dtype)
        } else {
            it.map(|i| target_value_unchecked(targets, &cumlens, i))
                .collect_arr_trusted_with_dtype(dtype)
        }
    }
}

impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ChunkedArray<T>
where
    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
{
    /// Gather values from ChunkedArray by index.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let rechunked;
        let mut ca = self;
        // The cumulative-length table holds at most BINARY_SEARCH_LIMIT entries,
        // so collapse to a single chunk when there are more chunks than that.
        if self.chunks().len() > BINARY_SEARCH_LIMIT {
            rechunked = self.rechunk();
            ca = &rechunked;
        }
        let targets: Vec<_> = ca.downcast_iter().collect();
        let arr = gather_idx_array_unchecked(
            ca.dtype().to_arrow(CompatLevel::newest()),
            &targets,
            ca.null_count() > 0,
            indices.as_ref(),
        );
        ChunkedArray::from_chunk_iter_like(ca, [arr])
    }
}

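/// Combines the sortedness of the gathered-from values with the sortedness of
/// the indices: e.g. ascending values gathered with descending indices come
/// out descending, and any unsorted input makes the result unsorted.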
pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> IsSorted {
    use crate::series::IsSorted::*;
    match (sorted_arr, sorted_idx) {
        (_, Not) => Not,
        (Not, _) => Not,
        (Ascending, Ascending) => Ascending,
        (Ascending, Descending) => Descending,
        (Descending, Ascending) => Descending,
        (Descending, Descending) => Ascending,
    }
}

impl<T: PolarsDataType> ChunkTakeUnchecked<IdxCa> for ChunkedArray<T>
where
    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
{
    /// Gather values from ChunkedArray by index.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let rechunked;
        let mut ca = self;
        if self.chunks().len() > BINARY_SEARCH_LIMIT {
            rechunked = self.rechunk();
            ca = &rechunked;
        }
        let targets_have_nulls = ca.null_count() > 0;
        let targets: Vec<_> = ca.downcast_iter().collect();

        let chunks = indices.downcast_iter().map(|idx_arr| {
            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
            if idx_arr.null_count() == 0 {
                gather_idx_array_unchecked(dtype, &targets, targets_have_nulls, idx_arr.values())
            } else if targets.len() == 1 {
                let target = targets.first().unwrap();
                if targets_have_nulls {
                    idx_arr
                        .iter()
                        .map(|i| target.get_unchecked(*i? as usize))
                        .collect_arr_trusted_with_dtype(dtype)
                } else {
                    idx_arr
                        .iter()
                        .map(|i| Some(target.value_unchecked(*i? as usize)))
                        .collect_arr_trusted_with_dtype(dtype)
                }
            } else {
                let cumlens = cumulative_lengths(&targets);
                if targets_have_nulls {
                    idx_arr
                        .iter()
                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
                        .collect_arr_trusted_with_dtype(dtype)
                } else {
                    idx_arr
                        .iter()
                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
                        .collect_arr_trusted_with_dtype(dtype)
                }
            }
        });

        let mut out = ChunkedArray::from_chunk_iter_like(ca, chunks);
        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());

        out.set_sorted_flag(sorted_flag);
        out
    }
}

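// The binary, string, struct, array and list impls below rechunk both sides
// and dispatch to the `take_unchecked` gather kernel from polars-compute.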
impl ChunkTakeUnchecked<IdxCa> for BinaryChunked {
    /// Gather values from ChunkedArray by index.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let rechunked = self.rechunk();
        let indices = indices.rechunk();
        let indices_arr = indices.downcast_iter().next().unwrap();
        let chunks = rechunked
            .chunks()
            .iter()
            .map(|arr| take_unchecked(arr.as_ref(), indices_arr))
            .collect::<Vec<_>>();

        let mut out = ChunkedArray::from_chunks(self.name().clone(), chunks);

        let sorted_flag =
            _update_gather_sorted_flag(self.is_sorted_flag(), indices.is_sorted_flag());
        out.set_sorted_flag(sorted_flag);
        out
    }
}

impl ChunkTakeUnchecked<IdxCa> for StringChunked {
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let rechunked = self.rechunk();
        let indices = indices.rechunk();
        let indices_arr = indices.downcast_iter().next().unwrap();
        let chunks = rechunked
            .chunks()
            .iter()
            .map(|arr| take_unchecked(arr.as_ref(), indices_arr))
            .collect::<Vec<_>>();

        let mut out = ChunkedArray::from_chunks(self.name().clone(), chunks);
        let sorted_flag =
            _update_gather_sorted_flag(self.is_sorted_flag(), indices.is_sorted_flag());
        out.set_sorted_flag(sorted_flag);
        out
    }
}

impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for BinaryChunked {
    /// Gather values from ChunkedArray by index.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&indices)
    }
}

impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StringChunked {
    /// Gather values from ChunkedArray by index.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&indices)
    }
}

#[cfg(feature = "dtype-struct")]
impl ChunkTakeUnchecked<IdxCa> for StructChunked {
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let a = self.rechunk();
        let index = indices.rechunk();

        let chunks = a
            .downcast_iter()
            .zip(index.downcast_iter())
            .map(|(arr, idx)| take_unchecked(arr, idx))
            .collect::<Vec<_>>();
        self.copy_with_chunks(chunks)
    }
}

#[cfg(feature = "dtype-struct")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StructChunked {
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&idx)
    }
}

impl IdxCa {
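    /// Builds a temporary `IdxCa` over `idx` without copying: the validity
    /// mask is derived from the null sentinel and the slice is reinterpreted
    /// as plain `IdxSize` values before being handed to `f`.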
    pub fn with_nullable_idx<T, F: FnOnce(&IdxCa) -> T>(idx: &[NullableIdxSize], f: F) -> T {
        let validity: Bitmap = idx.iter().map(|idx| !idx.is_null_idx()).collect_trusted();
        let idx = bytemuck::cast_slice::<_, IdxSize>(idx);
        let arr = unsafe { arrow::ffi::mmap::slice(idx) };
        let arr = arr.with_validity_typed(Some(validity));
        let ca = IdxCa::with_chunk(PlSmallStr::EMPTY, arr);

        f(&ca)
    }
}

#[cfg(feature = "dtype-array")]
impl ChunkTakeUnchecked<IdxCa> for ArrayChunked {
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let chunks = vec![take_unchecked(
            self.rechunk().downcast_as_array(),
            indices.rechunk().downcast_as_array(),
        )];
        self.copy_with_chunks(chunks)
    }
}

#[cfg(feature = "dtype-array")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ArrayChunked {
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&idx)
    }
}

impl ChunkTakeUnchecked<IdxCa> for ListChunked {
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let chunks = vec![take_unchecked(
            self.rechunk().downcast_as_array(),
            indices.rechunk().downcast_as_array(),
        )];
        self.copy_with_chunks(chunks)
    }
}

impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ListChunked {
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&idx)
    }
}