polars_core/chunked_array/ops/gather.rs
1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::bitmap::Bitmap;
3use arrow::bitmap::bitmask::BitMask;
4use polars_compute::gather::take_unchecked;
5use polars_error::polars_ensure;
6use polars_utils::index::check_bounds;
7
8use crate::prelude::*;
9use crate::series::IsSorted;
10
11pub fn check_bounds_nulls(idx: &PrimitiveArray<IdxSize>, len: IdxSize) -> PolarsResult<()> {
12 let mask = BitMask::from_bitmap(idx.validity().unwrap());
13
14 for (block_idx, block) in idx.values().chunks(32).enumerate() {
16 let mut in_bounds = 0;
17 for (i, x) in block.iter().enumerate() {
18 in_bounds |= ((*x < len) as u32) << i;
19 }
20 let m = mask.get_u32(32 * block_idx);
21 polars_ensure!(m == m & in_bounds, ComputeError: "gather indices are out of bounds");
22 }
23 Ok(())
24}
25
26pub fn check_bounds_ca(indices: &IdxCa, len: IdxSize) -> PolarsResult<()> {
27 let all_valid = indices.downcast_iter().all(|a| {
28 if a.null_count() == 0 {
29 check_bounds(a.values(), len).is_ok()
30 } else {
31 check_bounds_nulls(a, len).is_ok()
32 }
33 });
34 polars_ensure!(all_valid, OutOfBounds: "gather indices are out of bounds");
35 Ok(())
36}
37
38impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTake<I> for ChunkedArray<T>
39where
40 ChunkedArray<T>: ChunkTakeUnchecked<I>,
41{
42 fn take(&self, indices: &I) -> PolarsResult<Self> {
44 check_bounds(indices.as_ref(), self.len() as IdxSize)?;
45
46 Ok(unsafe { self.take_unchecked(indices) })
48 }
49}
50
51impl<T: PolarsDataType> ChunkTake<IdxCa> for ChunkedArray<T>
52where
53 ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
54{
55 fn take(&self, indices: &IdxCa) -> PolarsResult<Self> {
57 check_bounds_ca(indices, self.len() as IdxSize)?;
58
59 Ok(unsafe { self.take_unchecked(indices) })
61 }
62}
63
64fn cumulative_lengths<A: StaticArray>(arrs: &[&A]) -> Vec<IdxSize> {
69 let mut ret = Vec::with_capacity(arrs.len());
70 let mut cumsum: IdxSize = 0;
71 for arr in arrs {
72 ret.push(cumsum);
73 cumsum = cumsum.checked_add(arr.len().try_into().unwrap()).unwrap();
74 }
75 ret
76}
77
78#[rustfmt::skip]
79#[inline]
80fn resolve_chunked_idx(idx: IdxSize, cumlens: &[IdxSize]) -> (usize, usize) {
81 let chunk_idx = cumlens.partition_point(|cl| idx >= *cl) - 1;
82 (chunk_idx, (idx - cumlens[chunk_idx]) as usize)
83}
84
85#[inline]
86unsafe fn target_value_unchecked<'a, A: StaticArray>(
87 targets: &[&'a A],
88 cumlens: &[IdxSize],
89 idx: IdxSize,
90) -> A::ValueT<'a> {
91 let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
92 let arr = targets.get_unchecked(chunk_idx);
93 arr.value_unchecked(arr_idx)
94}
95
96#[inline]
97unsafe fn target_get_unchecked<'a, A: StaticArray>(
98 targets: &[&'a A],
99 cumlens: &[IdxSize],
100 idx: IdxSize,
101) -> Option<A::ValueT<'a>> {
102 let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
103 let arr = targets.get_unchecked(chunk_idx);
104 arr.get_unchecked(arr_idx)
105}
106
107unsafe fn gather_idx_array_unchecked<A: StaticArray>(
108 dtype: ArrowDataType,
109 targets: &[&A],
110 has_nulls: bool,
111 indices: &[IdxSize],
112) -> A {
113 let it = indices.iter().copied();
114 if targets.len() == 1 {
115 let target = targets.first().unwrap();
116 if has_nulls {
117 it.map(|i| target.get_unchecked(i as usize))
118 .collect_arr_trusted_with_dtype(dtype)
119 } else if let Some(sl) = target.as_slice() {
120 it.map(|i| sl.get_unchecked(i as usize).clone())
122 .collect_arr_trusted_with_dtype(dtype)
123 } else {
124 it.map(|i| target.value_unchecked(i as usize))
125 .collect_arr_trusted_with_dtype(dtype)
126 }
127 } else {
128 let cumlens = cumulative_lengths(targets);
129 if has_nulls {
130 it.map(|i| target_get_unchecked(targets, &cumlens, i))
131 .collect_arr_trusted_with_dtype(dtype)
132 } else {
133 it.map(|i| target_value_unchecked(targets, &cumlens, i))
134 .collect_arr_trusted_with_dtype(dtype)
135 }
136 }
137}
138
139impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ChunkedArray<T>
140where
141 T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
142{
143 unsafe fn take_unchecked(&self, indices: &I) -> Self {
145 let ca = self;
146 let targets: Vec<_> = ca.downcast_iter().collect();
147 let arr = gather_idx_array_unchecked(
148 ca.dtype().to_arrow(CompatLevel::newest()),
149 &targets,
150 ca.null_count() > 0,
151 indices.as_ref(),
152 );
153 ChunkedArray::from_chunk_iter_like(ca, [arr])
154 }
155}
156
157pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> IsSorted {
158 use crate::series::IsSorted::*;
159 match (sorted_arr, sorted_idx) {
160 (_, Not) => Not,
161 (Not, _) => Not,
162 (Ascending, Ascending) => Ascending,
163 (Ascending, Descending) => Descending,
164 (Descending, Ascending) => Descending,
165 (Descending, Descending) => Ascending,
166 }
167}
168
169impl<T: PolarsDataType> ChunkTakeUnchecked<IdxCa> for ChunkedArray<T>
170where
171 T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
172{
173 unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
175 let ca = self;
176 let targets_have_nulls = ca.null_count() > 0;
177 let targets: Vec<_> = ca.downcast_iter().collect();
178
179 let chunks = indices.downcast_iter().map(|idx_arr| {
180 let dtype = ca.dtype().to_arrow(CompatLevel::newest());
181 if idx_arr.null_count() == 0 {
182 gather_idx_array_unchecked(dtype, &targets, targets_have_nulls, idx_arr.values())
183 } else if targets.len() == 1 {
184 let target = targets.first().unwrap();
185 if targets_have_nulls {
186 idx_arr
187 .iter()
188 .map(|i| target.get_unchecked(*i? as usize))
189 .collect_arr_trusted_with_dtype(dtype)
190 } else {
191 idx_arr
192 .iter()
193 .map(|i| Some(target.value_unchecked(*i? as usize)))
194 .collect_arr_trusted_with_dtype(dtype)
195 }
196 } else {
197 let cumlens = cumulative_lengths(&targets);
198 if targets_have_nulls {
199 idx_arr
200 .iter()
201 .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
202 .collect_arr_trusted_with_dtype(dtype)
203 } else {
204 idx_arr
205 .iter()
206 .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
207 .collect_arr_trusted_with_dtype(dtype)
208 }
209 }
210 });
211
212 let mut out = ChunkedArray::from_chunk_iter_like(ca, chunks);
213 let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
214
215 out.set_sorted_flag(sorted_flag);
216 out
217 }
218}
219
220impl ChunkTakeUnchecked<IdxCa> for BinaryChunked {
221 unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
223 let ca = self;
224 let targets_have_nulls = ca.null_count() > 0;
225 let targets: Vec<_> = ca.downcast_iter().collect();
226
227 let chunks = indices.downcast_iter().map(|idx_arr| {
228 let dtype = ca.dtype().to_arrow(CompatLevel::newest());
229 if targets.len() == 1 {
230 let target = targets.first().unwrap();
231 take_unchecked(&**target, idx_arr)
232 } else {
233 let cumlens = cumulative_lengths(&targets);
234 if targets_have_nulls {
235 let arr: BinaryViewArray = idx_arr
236 .iter()
237 .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
238 .collect_arr_trusted_with_dtype(dtype);
239 arr.to_boxed()
240 } else {
241 let arr: BinaryViewArray = idx_arr
242 .iter()
243 .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
244 .collect_arr_trusted_with_dtype(dtype);
245 arr.to_boxed()
246 }
247 }
248 });
249
250 let mut out = ChunkedArray::from_chunks(ca.name().clone(), chunks.collect());
251 let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
252 out.set_sorted_flag(sorted_flag);
253 out
254 }
255}
256
257impl ChunkTakeUnchecked<IdxCa> for StringChunked {
258 unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
259 let ca = self;
260 let targets_have_nulls = ca.null_count() > 0;
261 let targets: Vec<_> = ca.downcast_iter().collect();
262
263 let chunks = indices.downcast_iter().map(|idx_arr| {
264 let dtype = ca.dtype().to_arrow(CompatLevel::newest());
265 if targets.len() == 1 {
266 let target = targets.first().unwrap();
267 take_unchecked(&**target, idx_arr)
268 } else {
269 let cumlens = cumulative_lengths(&targets);
270 if targets_have_nulls {
271 let arr: Utf8ViewArray = idx_arr
272 .iter()
273 .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
274 .collect_arr_trusted_with_dtype(dtype);
275 arr.to_boxed()
276 } else {
277 let arr: Utf8ViewArray = idx_arr
278 .iter()
279 .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
280 .collect_arr_trusted_with_dtype(dtype);
281 arr.to_boxed()
282 }
283 }
284 });
285
286 let mut out = ChunkedArray::from_chunks(ca.name().clone(), chunks.collect());
287 let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
288 out.set_sorted_flag(sorted_flag);
289 out
290 }
291}
292
293impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for BinaryChunked {
294 unsafe fn take_unchecked(&self, indices: &I) -> Self {
296 let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
297 self.take_unchecked(&indices)
298 }
299}
300
301impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StringChunked {
302 unsafe fn take_unchecked(&self, indices: &I) -> Self {
304 let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
305 self.take_unchecked(&indices)
306 }
307}
308
#[cfg(feature = "dtype-struct")]
impl ChunkTakeUnchecked<IdxCa> for StructChunked {
    /// Gather the rows at `indices` without bounds checking.
    ///
    /// `self` and `indices` are rechunked first, and each paired (values, indices)
    /// chunk is gathered with the arrow-level `take_unchecked` kernel.
    ///
    /// # Safety
    /// Every index must be `< self.len()`.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let values = self.rechunk();
        let idx = indices.rechunk();

        let mut chunks = Vec::new();
        for (arr, idx_arr) in values.downcast_iter().zip(idx.downcast_iter()) {
            chunks.push(take_unchecked(arr, idx_arr));
        }
        self.copy_with_chunks(chunks)
    }
}
322}
323
#[cfg(feature = "dtype-struct")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StructChunked {
    /// Gather by a plain index slice: wrap it as an `IdxCa` and delegate to the
    /// `IdxCa` implementation.
    ///
    /// # Safety
    /// Every index must be `< self.len()`.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx_ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        <Self as ChunkTakeUnchecked<IdxCa>>::take_unchecked(self, &idx_ca)
    }
}
331
332impl IdxCa {
333 pub fn with_nullable_idx<T, F: FnOnce(&IdxCa) -> T>(idx: &[NullableIdxSize], f: F) -> T {
334 let validity: Bitmap = idx.iter().map(|idx| !idx.is_null_idx()).collect_trusted();
335 let idx = bytemuck::cast_slice::<_, IdxSize>(idx);
336 let arr = unsafe { arrow::ffi::mmap::slice(idx) };
337 let arr = arr.with_validity_typed(Some(validity));
338 let ca = IdxCa::with_chunk(PlSmallStr::EMPTY, arr);
339
340 f(&ca)
341 }
342}
343
#[cfg(feature = "dtype-array")]
impl ChunkTakeUnchecked<IdxCa> for ArrayChunked {
    /// Gather the rows at `indices` without bounds checking.
    ///
    /// Both sides are rechunked to a single chunk; the result is the single chunk
    /// produced by the arrow-level `take_unchecked` kernel.
    ///
    /// # Safety
    /// Every index must be `< self.len()`.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let values = self.rechunk();
        let idx = indices.rechunk();
        let gathered = take_unchecked(values.downcast_as_array(), idx.downcast_as_array());
        self.copy_with_chunks(vec![gathered])
    }
}
354
#[cfg(feature = "dtype-array")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ArrayChunked {
    /// Gather by a plain index slice: wrap it as an `IdxCa` and delegate to the
    /// `IdxCa` implementation.
    ///
    /// # Safety
    /// Every index must be `< self.len()`.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx_ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        <Self as ChunkTakeUnchecked<IdxCa>>::take_unchecked(self, &idx_ca)
    }
}
362
363impl ChunkTakeUnchecked<IdxCa> for ListChunked {
364 unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
365 let chunks = vec![take_unchecked(
366 self.rechunk().downcast_as_array(),
367 indices.rechunk().downcast_as_array(),
368 )];
369 self.copy_with_chunks(chunks)
370 }
371}
372
373impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ListChunked {
374 unsafe fn take_unchecked(&self, indices: &I) -> Self {
375 let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
376 self.take_unchecked(&idx)
377 }
378}