1use arrow::bitmap::{Bitmap, MutableBitmap};
2use arrow::legacy::kernels::set::{scatter_single_non_null, set_with_mask};
3
4use crate::prelude::*;
5use crate::utils::align_chunks_binary;
6
/// Rebuilds `$self` through `$builder`, replacing the value at every position
/// listed in `$idx` with `$f(old_value)` and copying all other values as-is.
///
/// Expands to a block that evaluates to `Ok($builder.finish())`, and that
/// checks each index against `$self.len()` with `polars_ensure!` before it is
/// processed.
///
/// The source values are walked exactly once, front to back, so `$idx` is
/// expected to be in ascending order: an index that lies at or before a
/// position that was already visited drains the remaining values unchanged
/// and is silently skipped.
macro_rules! impl_scatter_with {
    ($self:ident, $builder:ident, $idx:ident, $f:ident) => {{
        let mut remaining = $self.into_iter().enumerate();

        for target in $idx.into_iter().map(|i| i as usize) {
            polars_ensure!(target < $self.len(), oob = target, $self.len());
            // Copy values through until the target position is reached, then
            // append the transformed value and move on to the next index.
            for (pos, current) in &mut remaining {
                if pos != target {
                    $builder.append_option(current);
                    continue;
                }
                $builder.append_option($f(current));
                break;
            }
        }
        // Everything after the last requested index is copied verbatim.
        for (_, current) in remaining {
            $builder.append_option(current);
        }

        Ok($builder.finish())
    }};
}
31
/// Verifies that `$mask` has exactly as many elements as `$self`.
///
/// The check is delegated to `polars_ensure!` with a `ShapeMismatch` error.
/// The macro is used by the mask-based `set` implementations in this file, so
/// the message names the `set` operation.
macro_rules! check_bounds {
    ($self:ident, $mask:ident) => {{
        polars_ensure!(
            $self.len() == $mask.len(),
            ShapeMismatch: "invalid mask in `set` operation: shape doesn't match array's shape"
        );
    }};
}
40
41impl<'a, T> ChunkSet<'a, T::Native, T::Native> for ChunkedArray<T>
42where
43 T: PolarsNumericType,
44{
45 fn scatter_single<I: IntoIterator<Item = IdxSize>>(
46 &'a self,
47 idx: I,
48 value: Option<T::Native>,
49 ) -> PolarsResult<Self> {
50 if !self.has_nulls() {
51 if let Some(value) = value {
52 if self.chunks.len() == 1 {
54 let arr = scatter_single_non_null(
55 self.downcast_iter().next().unwrap(),
56 idx,
57 value,
58 T::get_dtype().to_arrow(CompatLevel::newest()),
59 )?;
60 return Ok(Self::with_chunk(self.name().clone(), arr));
61 }
62 else {
64 let mut av = Vec::with_capacity(self.len());
65 for chunk in self.downcast_iter() {
66 av.extend_from_slice(chunk.values())
67 }
68 let data = av.as_mut_slice();
69
70 idx.into_iter().try_for_each::<_, PolarsResult<_>>(|idx| {
71 let val = data
72 .get_mut(idx as usize)
73 .ok_or_else(|| polars_err!(oob = idx as usize, self.len()))?;
74 *val = value;
75 Ok(())
76 })?;
77 return Ok(Self::from_vec(self.name().clone(), av));
78 }
79 }
80 }
81 self.scatter_with(idx, |_| value)
82 }
83
84 fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
85 &'a self,
86 idx: I,
87 f: F,
88 ) -> PolarsResult<Self>
89 where
90 F: Fn(Option<T::Native>) -> Option<T::Native>,
91 {
92 let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), self.len());
93 impl_scatter_with!(self, builder, idx, f)
94 }
95
96 fn set(&'a self, mask: &BooleanChunked, value: Option<T::Native>) -> PolarsResult<Self> {
97 check_bounds!(self, mask);
98
99 if let (Some(value), false) = (value, mask.has_nulls()) {
101 let (left, mask) = align_chunks_binary(self, mask);
102
103 let chunks = left
105 .downcast_iter()
106 .zip(mask.downcast_iter())
107 .map(|(arr, mask)| {
108 set_with_mask(
109 arr,
110 mask,
111 value,
112 T::get_dtype().to_arrow(CompatLevel::newest()),
113 )
114 });
115 Ok(ChunkedArray::from_chunk_iter(self.name().clone(), chunks))
116 } else {
117 let mask = mask.rechunk();
118 let mask = mask.downcast_as_array();
119 let mask = mask.true_and_valid();
120 let iter = mask.true_idx_iter();
121 self.scatter_single(iter.map(|v| v as IdxSize), value)
122 }
123 }
124}
125
126impl<'a> ChunkSet<'a, bool, bool> for BooleanChunked {
127 fn scatter_single<I: IntoIterator<Item = IdxSize>>(
128 &'a self,
129 idx: I,
130 value: Option<bool>,
131 ) -> PolarsResult<Self> {
132 self.scatter_with(idx, |_| value)
133 }
134
135 fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
136 &'a self,
137 idx: I,
138 f: F,
139 ) -> PolarsResult<Self>
140 where
141 F: Fn(Option<bool>) -> Option<bool>,
142 {
143 let mut values = MutableBitmap::with_capacity(self.len());
144 let mut validity = MutableBitmap::with_capacity(self.len());
145
146 for a in self.downcast_iter() {
147 values.extend_from_bitmap(a.values());
148 if let Some(v) = a.validity() {
149 validity.extend_from_bitmap(v)
150 } else {
151 validity.extend_constant(a.len(), true);
152 }
153 }
154
155 for i in idx.into_iter().map(|i| i as usize) {
156 let input = validity.get(i).then(|| values.get(i));
157
158 match f(input) {
159 Some(v) => {
160 values.set(i, v);
161 validity.set(i, true);
162 },
163 None => {
164 validity.set(i, false);
165 },
166 }
167 }
168 let validity: Bitmap = validity.into();
169 let validity = if validity.unset_bits() > 0 {
170 Some(validity)
171 } else {
172 None
173 };
174
175 let arr = BooleanArray::from_data_default(values.into(), validity);
176 Ok(BooleanChunked::with_chunk(self.name().clone(), arr))
177 }
178
179 fn set(&'a self, mask: &BooleanChunked, value: Option<bool>) -> PolarsResult<Self> {
180 let mask = mask.rechunk();
181 let mask = mask.downcast_as_array();
182 let mask = mask.true_and_valid();
183 let iter = mask.true_idx_iter();
184 self.scatter_single(iter.map(|v| v as IdxSize), value)
185 }
186}
187
188impl<'a> ChunkSet<'a, &'a str, String> for StringChunked {
189 fn scatter_single<I: IntoIterator<Item = IdxSize>>(
190 &'a self,
191 idx: I,
192 opt_value: Option<&'a str>,
193 ) -> PolarsResult<Self>
194 where
195 Self: Sized,
196 {
197 let idx_iter = idx.into_iter();
198 let mut ca_iter = self.into_iter().enumerate();
199 let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());
200
201 for current_idx in idx_iter.into_iter().map(|i| i as usize) {
202 polars_ensure!(current_idx < self.len(), oob = current_idx, self.len());
203 for (cnt_idx, opt_val_self) in &mut ca_iter {
204 if cnt_idx == current_idx {
205 builder.append_option(opt_value);
206 break;
207 } else {
208 builder.append_option(opt_val_self);
209 }
210 }
211 }
212 for (_, opt_val_self) in ca_iter {
214 builder.append_option(opt_val_self);
215 }
216
217 let ca = builder.finish();
218 Ok(ca)
219 }
220
221 fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
222 &'a self,
223 idx: I,
224 f: F,
225 ) -> PolarsResult<Self>
226 where
227 Self: Sized,
228 F: Fn(Option<&'a str>) -> Option<String>,
229 {
230 let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());
231 impl_scatter_with!(self, builder, idx, f)
232 }
233
234 fn set(&'a self, mask: &BooleanChunked, value: Option<&'a str>) -> PolarsResult<Self>
235 where
236 Self: Sized,
237 {
238 check_bounds!(self, mask);
239 let ca = mask
240 .into_iter()
241 .zip(self)
242 .map(|(mask_val, opt_val)| match mask_val {
243 Some(true) => value,
244 _ => opt_val,
245 })
246 .collect_trusted::<Self>()
247 .with_name(self.name().clone());
248 Ok(ca)
249 }
250}
251
252impl<'a> ChunkSet<'a, &'a [u8], Vec<u8>> for BinaryChunked {
253 fn scatter_single<I: IntoIterator<Item = IdxSize>>(
254 &'a self,
255 idx: I,
256 opt_value: Option<&'a [u8]>,
257 ) -> PolarsResult<Self>
258 where
259 Self: Sized,
260 {
261 let mut ca_iter = self.into_iter().enumerate();
262 let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len());
263
264 for current_idx in idx.into_iter().map(|i| i as usize) {
265 polars_ensure!(current_idx < self.len(), oob = current_idx, self.len());
266 for (cnt_idx, opt_val_self) in &mut ca_iter {
267 if cnt_idx == current_idx {
268 builder.append_option(opt_value);
269 break;
270 } else {
271 builder.append_option(opt_val_self);
272 }
273 }
274 }
275 for (_, opt_val_self) in ca_iter {
277 builder.append_option(opt_val_self);
278 }
279
280 let ca = builder.finish();
281 Ok(ca)
282 }
283
284 fn scatter_with<I: IntoIterator<Item = IdxSize>, F>(
285 &'a self,
286 idx: I,
287 f: F,
288 ) -> PolarsResult<Self>
289 where
290 Self: Sized,
291 F: Fn(Option<&'a [u8]>) -> Option<Vec<u8>>,
292 {
293 let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len());
294 impl_scatter_with!(self, builder, idx, f)
295 }
296
297 fn set(&'a self, mask: &BooleanChunked, value: Option<&'a [u8]>) -> PolarsResult<Self>
298 where
299 Self: Sized,
300 {
301 check_bounds!(self, mask);
302 let ca = mask
303 .into_iter()
304 .zip(self)
305 .map(|(mask_val, opt_val)| match mask_val {
306 Some(true) => value,
307 _ => opt_val,
308 })
309 .collect_trusted::<Self>()
310 .with_name(self.name().clone());
311 Ok(ca)
312 }
313}
314
#[cfg(test)]
mod test {
    use crate::prelude::*;

    /// Shorthand for building a column name from a static string.
    fn name(s: &'static str) -> PlSmallStr {
        PlSmallStr::from_static(s)
    }

    #[test]
    fn test_set() {
        // Null-free mask: only the `true` slot is overwritten.
        let ints = Int32Chunked::new(name("a"), &[1, 2, 3]);
        let plain_mask = BooleanChunked::new(name("mask"), &[false, true, false]);
        let out = ints.set(&plain_mask, Some(5)).unwrap();
        assert_eq!(Vec::from(&out), &[Some(1), Some(5), Some(3)]);

        // Null mask entries select nothing.
        let ints = Int32Chunked::new(name("a"), &[1, 2, 3]);
        let sparse_mask = BooleanChunked::new(name("mask"), &[None, Some(true), None]);
        let out = ints.set(&sparse_mask, Some(5)).unwrap();
        assert_eq!(Vec::from(&out), &[Some(1), Some(5), Some(3)]);

        // An all-null mask leaves the array untouched.
        let ints = Int32Chunked::new(name("a"), &[1, 2, 3]);
        let null_mask = BooleanChunked::new(name("mask"), &[None, None, None]);
        let out = ints.set(&null_mask, Some(5)).unwrap();
        assert_eq!(Vec::from(&out), &[Some(1), Some(2), Some(3)]);

        let ints = Int32Chunked::new(name("a"), &[1, 2, 3]);
        let mixed_mask = BooleanChunked::new(name("mask"), &[Some(true), Some(false), None]);
        let masked = ints.set(&mixed_mask, Some(5)).unwrap();
        assert_eq!(Vec::from(&masked), &[Some(5), Some(2), Some(3)]);

        // Index-based scatter on the result of the previous `set`.
        let scattered = masked.scatter_single(vec![0, 1], Some(10)).unwrap();
        assert_eq!(Vec::from(&scattered), &[Some(10), Some(10), Some(3)]);

        // An out-of-bounds index must be reported as an error.
        assert!(scattered.scatter_single(vec![0, 10], Some(0)).is_err());

        // Setting `None` on a boolean array nulls the selected slot.
        let bools = BooleanChunked::new(name("a"), &[true, true, true]);
        let plain_mask = BooleanChunked::new(name("mask"), &[false, true, false]);
        let out = bools.set(&plain_mask, None).unwrap();
        assert_eq!(Vec::from(&out), &[Some(true), None, Some(true)]);

        let strings = StringChunked::new(name("a"), &["foo", "foo", "foo"]);
        let plain_mask = BooleanChunked::new(name("mask"), &[false, true, false]);
        let out = strings.set(&plain_mask, Some("bar")).unwrap();
        assert_eq!(Vec::from(&out), &[Some("foo"), Some("bar"), Some("foo")]);
    }

    #[test]
    fn test_set_null_values() {
        // The same mask is applied to arrays that already contain a null.
        let mask = BooleanChunked::new(name("mask"), &[Some(false), Some(true), None]);

        let ints = Int32Chunked::new(name("a"), &[Some(1), None, Some(3)]);
        let out = ints.set(&mask, Some(2)).unwrap();
        assert_eq!(Vec::from(&out), &[Some(1), Some(2), Some(3)]);

        let strings = StringChunked::new(name("a"), &[Some("foo"), None, Some("bar")]);
        let out = strings.set(&mask, Some("foo")).unwrap();
        assert_eq!(Vec::from(&out), &[Some("foo"), Some("foo"), Some("bar")]);

        let bools = BooleanChunked::new(name("a"), &[Some(false), None, Some(true)]);
        let out = bools.set(&mask, Some(true)).unwrap();
        assert_eq!(Vec::from(&out), &[Some(false), Some(true), Some(true)]);
    }
}