polars_ops/chunked_array/
scatter.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::array::{Array, PrimitiveArray};
3use polars_core::prelude::*;
4use polars_core::series::IsSorted;
5use polars_core::utils::arrow::bitmap::MutableBitmap;
6use polars_core::utils::arrow::types::NativeType;
7use polars_utils::index::check_bounds;
8
9pub trait ChunkedSet<T: Copy> {
10 fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
13 where
14 V: IntoIterator<Item = Option<T>>;
15}
16fn check_sorted(idx: &[IdxSize]) -> PolarsResult<()> {
17 if idx.is_empty() {
18 return Ok(());
19 }
20 let mut sorted = true;
21 let mut previous = idx[0];
22 for &i in &idx[1..] {
23 if i < previous {
24 sorted = false;
26 }
27 previous = i;
28 }
29 polars_ensure!(sorted, ComputeError: "set indices must be sorted");
30 Ok(())
31}
32
33trait PolarsOpsNumericType: PolarsNumericType {}
34
35impl PolarsOpsNumericType for UInt8Type {}
36impl PolarsOpsNumericType for UInt16Type {}
37impl PolarsOpsNumericType for UInt32Type {}
38impl PolarsOpsNumericType for UInt64Type {}
39#[cfg(feature = "dtype-u128")]
40impl PolarsOpsNumericType for UInt128Type {}
41impl PolarsOpsNumericType for Int8Type {}
42impl PolarsOpsNumericType for Int16Type {}
43impl PolarsOpsNumericType for Int32Type {}
44impl PolarsOpsNumericType for Int64Type {}
45#[cfg(feature = "dtype-i128")]
46impl PolarsOpsNumericType for Int128Type {}
47impl PolarsOpsNumericType for Float32Type {}
48impl PolarsOpsNumericType for Float64Type {}
49
50unsafe fn scatter_impl<V, T: NativeType>(
51 new_values_slice: &mut [T],
52 set_values: V,
53 arr: &mut PrimitiveArray<T>,
54 idx: &[IdxSize],
55 len: usize,
56) where
57 V: IntoIterator<Item = Option<T>>,
58{
59 let mut values_iter = set_values.into_iter();
60
61 if arr.null_count() > 0 {
62 arr.apply_validity(|v| {
63 let mut mut_validity = v.make_mut();
64
65 for (idx, val) in idx.iter().zip(&mut values_iter) {
66 match val {
67 Some(value) => {
68 mut_validity.set_unchecked(*idx as usize, true);
69 *new_values_slice.get_unchecked_mut(*idx as usize) = value
70 },
71 None => mut_validity.set_unchecked(*idx as usize, false),
72 }
73 }
74 mut_validity.into()
75 })
76 } else {
77 let mut null_idx = vec![];
78 for (idx, val) in idx.iter().zip(values_iter) {
79 match val {
80 Some(value) => *new_values_slice.get_unchecked_mut(*idx as usize) = value,
81 None => {
82 null_idx.push(*idx);
83 },
84 }
85 }
86 if !null_idx.is_empty() {
88 let mut validity = MutableBitmap::with_capacity(len);
89 validity.extend_constant(len, true);
90 for idx in null_idx {
91 validity.set_unchecked(idx as usize, false)
92 }
93 arr.set_validity(Some(validity.into()))
94 }
95 }
96}
97
98impl<T: PolarsOpsNumericType> ChunkedSet<T::Native> for &mut ChunkedArray<T> {
99 fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
100 where
101 V: IntoIterator<Item = Option<T::Native>>,
102 {
103 check_bounds(idx, self.len() as IdxSize)?;
104 let mut ca = std::mem::take(self);
105 ca.rechunk_mut();
106
107 ca.set_sorted_flag(IsSorted::Not);
111 let arr = unsafe { ca.downcast_iter_mut() }.next().unwrap();
112 let len = arr.len();
113
114 match arr.get_mut_values() {
115 Some(current_values) => {
116 let ptr = current_values.as_mut_ptr();
117
118 let current_values = unsafe { &mut *std::slice::from_raw_parts_mut(ptr, len) };
120 unsafe { scatter_impl(current_values, values, arr, idx, len) };
123 },
124 None => {
125 let mut new_values = arr.values().as_slice().to_vec();
126 unsafe { scatter_impl(&mut new_values, values, arr, idx, len) };
129 arr.set_values(new_values.into());
130 },
131 };
132
133 let new_null_count = arr.null_count();
135 unsafe { ca.set_null_count(new_null_count) };
136
137 Ok(ca.into_series())
138 }
139}
140
141impl<'a> ChunkedSet<&'a str> for &'a StringChunked {
142 fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
143 where
144 V: IntoIterator<Item = Option<&'a str>>,
145 {
146 check_bounds(idx, self.len() as IdxSize)?;
147 check_sorted(idx)?;
148 let mut ca_iter = self.into_iter().enumerate();
149 let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());
150
151 for (current_idx, current_value) in idx.iter().zip(values) {
152 for (cnt_idx, opt_val_self) in &mut ca_iter {
153 if cnt_idx == *current_idx as usize {
154 builder.append_option(current_value);
155 break;
156 } else {
157 builder.append_option(opt_val_self);
158 }
159 }
160 }
161 for (_, opt_val_self) in ca_iter {
163 builder.append_option(opt_val_self);
164 }
165
166 let ca = builder.finish();
167 Ok(ca.into_series())
168 }
169}
170impl ChunkedSet<bool> for &BooleanChunked {
171 fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
172 where
173 V: IntoIterator<Item = Option<bool>>,
174 {
175 check_bounds(idx, self.len() as IdxSize)?;
176 check_sorted(idx)?;
177 let mut ca_iter = self.into_iter().enumerate();
178 let mut builder = BooleanChunkedBuilder::new(self.name().clone(), self.len());
179
180 for (current_idx, current_value) in idx.iter().zip(values) {
181 for (cnt_idx, opt_val_self) in &mut ca_iter {
182 if cnt_idx == *current_idx as usize {
183 builder.append_option(current_value);
184 break;
185 } else {
186 builder.append_option(opt_val_self);
187 }
188 }
189 }
190 for (_, opt_val_self) in ca_iter {
192 builder.append_option(opt_val_self);
193 }
194
195 let ca = builder.finish();
196 Ok(ca.into_series())
197 }
198}