polars_core/chunked_array/ops/
set.rs

1use arrow::bitmap::{Bitmap, MutableBitmap};
2use arrow::legacy::kernels::set::{scatter_single_non_null, set_with_mask};
3
4use crate::prelude::*;
5use crate::utils::align_chunks_binary;
6
// Copies `$self` into `$builder`, routing the value at every position listed in
// `$idx` through `$f` before it is appended, and evaluates to
// `Ok($builder.finish())`.
//
// Each index is checked against `$self.len()` with `polars_ensure!` (`oob`).
// The source is walked exactly once, front to back, so indices are expected in
// strictly ascending order: an index at or before the previous one can no
// longer be reached, in which case the remaining values are copied unchanged.
macro_rules! impl_scatter_with {
    ($self:ident, $builder:ident, $idx:ident, $f:ident) => {{
        let mut remaining = $self.into_iter().enumerate();

        for target in $idx.into_iter().map(|i| i as usize) {
            polars_ensure!(target < $self.len(), oob = target, $self.len());
            // Pass values through untouched until the requested position shows up.
            for (pos, current) in &mut remaining {
                if pos != target {
                    $builder.append_option(current);
                    continue;
                }
                $builder.append_option($f(current));
                break;
            }
        }
        // Whatever follows the last requested index is copied as-is.
        for (_, current) in remaining {
            $builder.append_option(current);
        }

        Ok($builder.finish())
    }};
}
31
// Verifies that `$mask` has exactly as many elements as `$self`. On a mismatch
// `polars_ensure!` raises a `ShapeMismatch` error from the enclosing function.
//
// The macro is only invoked from the `set` implementations in this file, so the
// message names that operation.
macro_rules! check_bounds {
    ($self:ident, $mask:ident) => {{
        polars_ensure!(
            $self.len() == $mask.len(),
            ShapeMismatch: "invalid mask in `set` operation: shape doesn't match array's shape"
        );
    }};
}
40
41impl<'a, T> ChunkSet<'a, T::Native, T::Native> for ChunkedArray<T>
42where
43    T: PolarsNumericType,
44{
45    fn scatter_single<I: IntoIterator<Item = IdxSize>>(
46        &'a self,
47        idx: I,
48        value: Option<T::Native>,
49    ) -> PolarsResult<Self> {
50        if !self.has_nulls() {
51            if let Some(value) = value {
52                // Fast path uses kernel.
53                if self.chunks.len() == 1 {
54                    let arr = scatter_single_non_null(
55                        self.downcast_iter().next().unwrap(),
56                        idx,
57                        value,
58                        T::get_dtype().to_arrow(CompatLevel::newest()),
59                    )?;
60                    return Ok(Self::with_chunk(self.name().clone(), arr));
61                }
62                // Other fast path. Slightly slower as it does not do a memcpy.
63                else {
64                    let mut av = Vec::with_capacity(self.len());
65                    for chunk in self.downcast_iter() {
66                        av.extend_from_slice(chunk.values())
67                    }
68                    let data = av.as_mut_slice();
69
70                    idx.into_iter().try_for_each::<_, PolarsResult<_>>(|idx| {
71                        let val = data
72                            .get_mut(idx as usize)
73                            .ok_or_else(|| polars_err!(oob = idx as usize, self.len()))?;
74                        *val = value;
75                        Ok(())
76                    })?;
77                    return Ok(Self::from_vec(self.name().clone(), av));
78                }
79            }
80        }
81        self.scatter_with(idx, |_| value)
82    }
83
84    fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
85        &'a self,
86        idx: I,
87        f: F,
88    ) -> PolarsResult<Self>
89    where
90        F: Fn(Option<T::Native>) -> Option<T::Native>,
91    {
92        let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), self.len());
93        impl_scatter_with!(self, builder, idx, f)
94    }
95
96    fn set(&'a self, mask: &BooleanChunked, value: Option<T::Native>) -> PolarsResult<Self> {
97        check_bounds!(self, mask);
98
99        // Fast path uses the kernel in polars-arrow.
100        if let (Some(value), false) = (value, mask.has_nulls()) {
101            let (left, mask) = align_chunks_binary(self, mask);
102
103            // Apply binary kernel.
104            let chunks = left
105                .downcast_iter()
106                .zip(mask.downcast_iter())
107                .map(|(arr, mask)| {
108                    set_with_mask(
109                        arr,
110                        mask,
111                        value,
112                        T::get_dtype().to_arrow(CompatLevel::newest()),
113                    )
114                });
115            Ok(ChunkedArray::from_chunk_iter(self.name().clone(), chunks))
116        } else {
117            let mask = mask.rechunk();
118            let mask = mask.downcast_as_array();
119            let mask = mask.true_and_valid();
120            let iter = mask.true_idx_iter();
121            self.scatter_single(iter.map(|v| v as IdxSize), value)
122        }
123    }
124}
125
126impl<'a> ChunkSet<'a, bool, bool> for BooleanChunked {
127    fn scatter_single<I: IntoIterator<Item = IdxSize>>(
128        &'a self,
129        idx: I,
130        value: Option<bool>,
131    ) -> PolarsResult<Self> {
132        self.scatter_with(idx, |_| value)
133    }
134
135    fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
136        &'a self,
137        idx: I,
138        f: F,
139    ) -> PolarsResult<Self>
140    where
141        F: Fn(Option<bool>) -> Option<bool>,
142    {
143        let mut values = MutableBitmap::with_capacity(self.len());
144        let mut validity = MutableBitmap::with_capacity(self.len());
145
146        for a in self.downcast_iter() {
147            values.extend_from_bitmap(a.values());
148            if let Some(v) = a.validity() {
149                validity.extend_from_bitmap(v)
150            } else {
151                validity.extend_constant(a.len(), true);
152            }
153        }
154
155        for i in idx.into_iter().map(|i| i as usize) {
156            let input = validity.get(i).then(|| values.get(i));
157
158            match f(input) {
159                Some(v) => {
160                    values.set(i, v);
161                    validity.set(i, true);
162                },
163                None => {
164                    validity.set(i, false);
165                },
166            }
167        }
168        let validity: Bitmap = validity.into();
169        let validity = if validity.unset_bits() > 0 {
170            Some(validity)
171        } else {
172            None
173        };
174
175        let arr = BooleanArray::from_data_default(values.into(), validity);
176        Ok(BooleanChunked::with_chunk(self.name().clone(), arr))
177    }
178
179    fn set(&'a self, mask: &BooleanChunked, value: Option<bool>) -> PolarsResult<Self> {
180        let mask = mask.rechunk();
181        let mask = mask.downcast_as_array();
182        let mask = mask.true_and_valid();
183        let iter = mask.true_idx_iter();
184        self.scatter_single(iter.map(|v| v as IdxSize), value)
185    }
186}
187
188impl<'a> ChunkSet<'a, &'a str, String> for StringChunked {
189    fn scatter_single<I: IntoIterator<Item = IdxSize>>(
190        &'a self,
191        idx: I,
192        opt_value: Option<&'a str>,
193    ) -> PolarsResult<Self>
194    where
195        Self: Sized,
196    {
197        let idx_iter = idx.into_iter();
198        let mut ca_iter = self.into_iter().enumerate();
199        let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());
200
201        for current_idx in idx_iter.into_iter().map(|i| i as usize) {
202            polars_ensure!(current_idx < self.len(), oob = current_idx, self.len());
203            for (cnt_idx, opt_val_self) in &mut ca_iter {
204                if cnt_idx == current_idx {
205                    builder.append_option(opt_value);
206                    break;
207                } else {
208                    builder.append_option(opt_val_self);
209                }
210            }
211        }
212        // the last idx is probably not the last value so we finish the iterator
213        for (_, opt_val_self) in ca_iter {
214            builder.append_option(opt_val_self);
215        }
216
217        let ca = builder.finish();
218        Ok(ca)
219    }
220
221    fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
222        &'a self,
223        idx: I,
224        f: F,
225    ) -> PolarsResult<Self>
226    where
227        Self: Sized,
228        F: Fn(Option<&'a str>) -> Option<String>,
229    {
230        let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());
231        impl_scatter_with!(self, builder, idx, f)
232    }
233
234    fn set(&'a self, mask: &BooleanChunked, value: Option<&'a str>) -> PolarsResult<Self>
235    where
236        Self: Sized,
237    {
238        check_bounds!(self, mask);
239        let ca = mask
240            .into_iter()
241            .zip(self)
242            .map(|(mask_val, opt_val)| match mask_val {
243                Some(true) => value,
244                _ => opt_val,
245            })
246            .collect_trusted::<Self>()
247            .with_name(self.name().clone());
248        Ok(ca)
249    }
250}
251
252impl<'a> ChunkSet<'a, &'a [u8], Vec<u8>> for BinaryChunked {
253    fn scatter_single<I: IntoIterator<Item = IdxSize>>(
254        &'a self,
255        idx: I,
256        opt_value: Option<&'a [u8]>,
257    ) -> PolarsResult<Self>
258    where
259        Self: Sized,
260    {
261        let mut ca_iter = self.into_iter().enumerate();
262        let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len());
263
264        for current_idx in idx.into_iter().map(|i| i as usize) {
265            polars_ensure!(current_idx < self.len(), oob = current_idx, self.len());
266            for (cnt_idx, opt_val_self) in &mut ca_iter {
267                if cnt_idx == current_idx {
268                    builder.append_option(opt_value);
269                    break;
270                } else {
271                    builder.append_option(opt_val_self);
272                }
273            }
274        }
275        // the last idx is probably not the last value so we finish the iterator
276        for (_, opt_val_self) in ca_iter {
277            builder.append_option(opt_val_self);
278        }
279
280        let ca = builder.finish();
281        Ok(ca)
282    }
283
284    fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
285        &'a self,
286        idx: I,
287        f: F,
288    ) -> PolarsResult<Self>
289    where
290        Self: Sized,
291        F: Fn(Option<&'a [u8]>) -> Option<Vec<u8>>,
292    {
293        let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len());
294        impl_scatter_with!(self, builder, idx, f)
295    }
296
297    fn set(&'a self, mask: &BooleanChunked, value: Option<&'a [u8]>) -> PolarsResult<Self>
298    where
299        Self: Sized,
300    {
301        check_bounds!(self, mask);
302        let ca = mask
303            .into_iter()
304            .zip(self)
305            .map(|(mask_val, opt_val)| match mask_val {
306                Some(true) => value,
307                _ => opt_val,
308            })
309            .collect_trusted::<Self>()
310            .with_name(self.name().clone());
311        Ok(ca)
312    }
313}
314
#[cfg(test)]
mod test {
    use crate::prelude::*;

    /// `set` / `scatter_single` on inputs without nulls, for several mask shapes.
    #[test]
    fn test_set() {
        // Plain boolean mask: only the middle slot is selected.
        let input = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]);
        let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[false, true, false]);
        let out = input.set(&mask, Some(5)).unwrap();
        assert_eq!(Vec::from(&out), &[Some(1), Some(5), Some(3)]);

        // Null mask entries do not select anything.
        let input = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]);
        let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[None, Some(true), None]);
        let out = input.set(&mask, Some(5)).unwrap();
        assert_eq!(Vec::from(&out), &[Some(1), Some(5), Some(3)]);

        // An all-null mask leaves the input untouched.
        let input = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]);
        let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[None, None, None]);
        let out = input.set(&mask, Some(5)).unwrap();
        assert_eq!(Vec::from(&out), &[Some(1), Some(2), Some(3)]);

        // Mixed true / false / null mask.
        let input = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]);
        let mask = BooleanChunked::new(
            PlSmallStr::from_static("mask"),
            &[Some(true), Some(false), None],
        );
        let out = input.set(&mask, Some(5)).unwrap();
        assert_eq!(Vec::from(&out), &[Some(5), Some(2), Some(3)]);

        // Scatter by index on top of the previous result.
        let scattered = out.scatter_single(vec![0, 1], Some(10)).unwrap();
        assert_eq!(Vec::from(&scattered), &[Some(10), Some(10), Some(3)]);

        // An out-of-bounds index is reported as an error.
        assert!(scattered.scatter_single(vec![0, 10], Some(0)).is_err());

        // test booleans
        let bools = BooleanChunked::new(PlSmallStr::from_static("a"), &[true, true, true]);
        let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[false, true, false]);
        let bools = bools.set(&mask, None).unwrap();
        assert_eq!(Vec::from(&bools), &[Some(true), None, Some(true)]);

        // test string
        let strings = StringChunked::new(PlSmallStr::from_static("a"), &["foo", "foo", "foo"]);
        let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[false, true, false]);
        let strings = strings.set(&mask, Some("bar")).unwrap();
        assert_eq!(
            Vec::from(&strings),
            &[Some("foo"), Some("bar"), Some("foo")]
        );
    }

    /// `set` on inputs that contain a null, using one shared mask that selects
    /// exactly the null slot.
    #[test]
    fn test_set_null_values() {
        let mask = BooleanChunked::new(
            PlSmallStr::from_static("mask"),
            &[Some(false), Some(true), None],
        );

        // Numeric input.
        let ints = Int32Chunked::new(PlSmallStr::from_static("a"), &[Some(1), None, Some(3)]);
        let ints = ints.set(&mask, Some(2)).unwrap();
        assert_eq!(Vec::from(&ints), &[Some(1), Some(2), Some(3)]);

        // String input.
        let strings = StringChunked::new(
            PlSmallStr::from_static("a"),
            &[Some("foo"), None, Some("bar")],
        );
        let strings = strings.set(&mask, Some("foo")).unwrap();
        assert_eq!(
            Vec::from(&strings),
            &[Some("foo"), Some("foo"), Some("bar")]
        );

        // Boolean input.
        let bools = BooleanChunked::new(
            PlSmallStr::from_static("a"),
            &[Some(false), None, Some(true)],
        );
        let bools = bools.set(&mask, Some(true)).unwrap();
        assert_eq!(Vec::from(&bools), &[Some(false), Some(true), Some(true)]);
    }
}