polars_ops/chunked_array/
scatter.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::array::{Array, BinaryViewArrayGeneric, BooleanArray, PrimitiveArray, View, ViewType};
3use arrow::buffer::Buffer;
4use polars_core::prelude::*;
5use polars_core::series::IsSorted;
6use polars_core::utils::arrow::bitmap::MutableBitmap;
7use polars_core::utils::arrow::types::NativeType;
8use polars_utils::index::check_bounds;
9
10pub trait ChunkedSet<T: Copy> {
11    /// Invariant for implementations: if the scatter() fails, typically because
12    /// of bad indexes, then self should remain unmodified.
13    fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
14    where
15        V: IntoIterator<Item = Option<T>>;
16}
17
18trait PolarsOpsNumericType: PolarsNumericType {}
19
20impl PolarsOpsNumericType for UInt8Type {}
21impl PolarsOpsNumericType for UInt16Type {}
22impl PolarsOpsNumericType for UInt32Type {}
23impl PolarsOpsNumericType for UInt64Type {}
24#[cfg(feature = "dtype-u128")]
25impl PolarsOpsNumericType for UInt128Type {}
26impl PolarsOpsNumericType for Int8Type {}
27impl PolarsOpsNumericType for Int16Type {}
28impl PolarsOpsNumericType for Int32Type {}
29impl PolarsOpsNumericType for Int64Type {}
30#[cfg(feature = "dtype-i128")]
31impl PolarsOpsNumericType for Int128Type {}
32#[cfg(feature = "dtype-f16")]
33impl PolarsOpsNumericType for Float16Type {}
34impl PolarsOpsNumericType for Float32Type {}
35impl PolarsOpsNumericType for Float64Type {}
36
37unsafe fn scatter_primitive_impl<V, T: NativeType>(
38    set_values: V,
39    arr: &mut PrimitiveArray<T>,
40    idx: &[IdxSize],
41) where
42    V: IntoIterator<Item = Option<T>>,
43{
44    let mut values_iter = set_values.into_iter();
45
46    if let Some(validity) = arr.take_validity() {
47        let mut mut_validity = validity.make_mut();
48        arr.with_values_mut(|cur_values| {
49            for (idx, val) in idx.iter().zip(&mut values_iter) {
50                match val {
51                    Some(value) => {
52                        mut_validity.set_unchecked(*idx as usize, true);
53                        *cur_values.get_unchecked_mut(*idx as usize) = value
54                    },
55                    None => mut_validity.set_unchecked(*idx as usize, false),
56                }
57            }
58        });
59        arr.set_validity(mut_validity.into())
60    } else {
61        let mut null_idx = vec![];
62        arr.with_values_mut(|cur_values| {
63            for (idx, val) in idx.iter().zip(values_iter) {
64                match val {
65                    Some(value) => *cur_values.get_unchecked_mut(*idx as usize) = value,
66                    None => {
67                        null_idx.push(*idx);
68                    },
69                }
70            }
71        });
72
73        // Only make a validity bitmap when null values are set.
74        if !null_idx.is_empty() {
75            let mut validity = MutableBitmap::with_capacity(arr.len());
76            validity.extend_constant(arr.len(), true);
77            for idx in null_idx {
78                validity.set_unchecked(idx as usize, false)
79            }
80            arr.set_validity(Some(validity.into()))
81        }
82    }
83}
84
85unsafe fn scatter_bool_impl<V>(set_values: V, arr: &mut BooleanArray, idx: &[IdxSize])
86where
87    V: IntoIterator<Item = Option<bool>>,
88{
89    let mut values_iter = set_values.into_iter();
90
91    if let Some(validity) = arr.take_validity() {
92        let mut mut_validity = validity.make_mut();
93        arr.apply_values_mut(|cur_values| {
94            for (idx, val) in idx.iter().zip(&mut values_iter) {
95                match val {
96                    Some(value) => {
97                        mut_validity.set_unchecked(*idx as usize, true);
98                        cur_values.set_unchecked(*idx as usize, value);
99                    },
100                    None => mut_validity.set_unchecked(*idx as usize, false),
101                }
102            }
103        });
104        arr.set_validity(mut_validity.into())
105    } else {
106        let mut null_idx = vec![];
107        arr.apply_values_mut(|cur_values| {
108            for (idx, val) in idx.iter().zip(values_iter) {
109                match val {
110                    Some(value) => cur_values.set_unchecked(*idx as usize, value),
111                    None => {
112                        null_idx.push(*idx);
113                    },
114                }
115            }
116        });
117
118        // Only make a validity bitmap when null values are set.
119        if !null_idx.is_empty() {
120            let mut validity = MutableBitmap::with_capacity(arr.len());
121            validity.extend_constant(arr.len(), true);
122            for idx in null_idx {
123                validity.set_unchecked(idx as usize, false)
124            }
125            arr.set_validity(Some(validity.into()))
126        }
127    }
128}
129
130unsafe fn scatter_binview_impl<'a, V, T: ViewType + ?Sized>(
131    set_values: V,
132    arr: &mut BinaryViewArrayGeneric<T>,
133    idx: &[IdxSize],
134) where
135    V: IntoIterator<Item = Option<&'a T>>,
136{
137    let mut values_iter = set_values.into_iter();
138    let buffer_offset = arr.data_buffers().len() as u32;
139    let mut new_buffers = Vec::new();
140
141    if let Some(validity) = arr.take_validity() {
142        let mut mut_validity = validity.make_mut();
143        arr.with_views_mut(|views| {
144            for (idx, val) in idx.iter().zip(&mut values_iter) {
145                if let Some(v) = val {
146                    let view =
147                        View::new_with_buffers(v.to_bytes(), buffer_offset, &mut new_buffers);
148                    *views.get_unchecked_mut(*idx as usize) = view;
149                    mut_validity.set_unchecked(*idx as usize, true);
150                } else {
151                    mut_validity.set_unchecked(*idx as usize, false);
152                }
153            }
154        });
155        arr.set_validity(mut_validity.into())
156    } else {
157        let mut null_idx = vec![];
158        arr.with_views_mut(|views| {
159            for (idx, val) in idx.iter().zip(values_iter) {
160                if let Some(v) = val {
161                    let view =
162                        View::new_with_buffers(v.to_bytes(), buffer_offset, &mut new_buffers);
163                    *views.get_unchecked_mut(*idx as usize) = view;
164                } else {
165                    null_idx.push(*idx);
166                }
167            }
168        });
169
170        // Only make a validity bitmap when null values are set.
171        if !null_idx.is_empty() {
172            let mut validity = MutableBitmap::with_capacity(arr.len());
173            validity.extend_constant(arr.len(), true);
174            for idx in null_idx {
175                validity.set_unchecked(idx as usize, false)
176            }
177            arr.set_validity(Some(validity.into()))
178        }
179    }
180
181    let mut buffers = Buffer::make_mut(core::mem::take(arr.data_buffers_mut()));
182    buffers.extend(new_buffers.into_iter().map(Buffer::from));
183    *arr.data_buffers_mut() = Buffer::from(buffers);
184}
185
186impl<T: PolarsOpsNumericType> ChunkedSet<T::Native> for &mut ChunkedArray<T> {
187    fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
188    where
189        V: IntoIterator<Item = Option<T::Native>>,
190    {
191        check_bounds(idx, self.len() as IdxSize)?;
192        let mut ca = std::mem::take(self);
193
194        // SAFETY: we will not modify the length and we unset the sorted flag,
195        // making sure to update the null count as well.
196        unsafe {
197            ca.rechunk_mut();
198            let arr = ca.downcast_iter_mut().next().unwrap();
199            scatter_primitive_impl(values, arr, idx);
200            let null_count = arr.null_count();
201            ca.set_sorted_flag(IsSorted::Not);
202            ca.set_null_count(null_count);
203        }
204
205        Ok(ca.into_series())
206    }
207}
208
209impl<'a> ChunkedSet<&'a [u8]> for &mut BinaryChunked {
210    fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
211    where
212        V: IntoIterator<Item = Option<&'a [u8]>>,
213    {
214        check_bounds(idx, self.len() as IdxSize)?;
215        let mut ca = std::mem::take(self);
216
217        unsafe {
218            ca.rechunk_mut();
219            let arr = ca.downcast_iter_mut().next().unwrap();
220            scatter_binview_impl(values, arr, idx);
221            let null_count = arr.null_count();
222            ca.set_sorted_flag(IsSorted::Not);
223            ca.set_null_count(null_count);
224        }
225
226        Ok(ca.into_series())
227    }
228}
229
230impl<'a> ChunkedSet<&'a str> for &mut StringChunked {
231    fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
232    where
233        V: IntoIterator<Item = Option<&'a str>>,
234    {
235        check_bounds(idx, self.len() as IdxSize)?;
236        let mut ca = std::mem::take(self);
237
238        unsafe {
239            ca.rechunk_mut();
240            let arr = ca.downcast_iter_mut().next().unwrap();
241            scatter_binview_impl(values, arr, idx);
242            let null_count = arr.null_count();
243            ca.set_sorted_flag(IsSorted::Not);
244            ca.set_null_count(null_count);
245        }
246
247        Ok(ca.into_series())
248    }
249}
250impl ChunkedSet<bool> for &mut BooleanChunked {
251    fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
252    where
253        V: IntoIterator<Item = Option<bool>>,
254    {
255        check_bounds(idx, self.len() as IdxSize)?;
256        let mut ca = std::mem::take(self);
257
258        unsafe {
259            ca.rechunk_mut();
260            let arr = ca.downcast_iter_mut().next().unwrap();
261            scatter_bool_impl(values, arr, idx);
262            let null_count = arr.null_count();
263            ca.set_sorted_flag(IsSorted::Not);
264            ca.set_null_count(null_count);
265        }
266
267        Ok(ca.into_series())
268    }
269}