// polars_ops/chunked_array/scatter.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::array::{Array, PrimitiveArray};
3use polars_core::prelude::*;
4use polars_core::series::IsSorted;
5use polars_core::utils::arrow::bitmap::MutableBitmap;
6use polars_core::utils::arrow::types::NativeType;
7use polars_utils::index::check_bounds;
8
9pub trait ChunkedSet<T: Copy> {
10    /// Invariant for implementations: if the scatter() fails, typically because
11    /// of bad indexes, then self should remain unmodified.
12    fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
13    where
14        V: IntoIterator<Item = Option<T>>;
15}
16fn check_sorted(idx: &[IdxSize]) -> PolarsResult<()> {
17    if idx.is_empty() {
18        return Ok(());
19    }
20    let mut sorted = true;
21    let mut previous = idx[0];
22    for &i in &idx[1..] {
23        if i < previous {
24            // we will not break here as that prevents SIMD
25            sorted = false;
26        }
27        previous = i;
28    }
29    polars_ensure!(sorted, ComputeError: "set indices must be sorted");
30    Ok(())
31}
32
33trait PolarsOpsNumericType: PolarsNumericType {}
34
35impl PolarsOpsNumericType for UInt8Type {}
36impl PolarsOpsNumericType for UInt16Type {}
37impl PolarsOpsNumericType for UInt32Type {}
38impl PolarsOpsNumericType for UInt64Type {}
39#[cfg(feature = "dtype-u128")]
40impl PolarsOpsNumericType for UInt128Type {}
41impl PolarsOpsNumericType for Int8Type {}
42impl PolarsOpsNumericType for Int16Type {}
43impl PolarsOpsNumericType for Int32Type {}
44impl PolarsOpsNumericType for Int64Type {}
45#[cfg(feature = "dtype-i128")]
46impl PolarsOpsNumericType for Int128Type {}
47#[cfg(feature = "dtype-f16")]
48impl PolarsOpsNumericType for Float16Type {}
49impl PolarsOpsNumericType for Float32Type {}
50impl PolarsOpsNumericType for Float64Type {}
51
52unsafe fn scatter_impl<V, T: NativeType>(
53    new_values_slice: &mut [T],
54    set_values: V,
55    arr: &mut PrimitiveArray<T>,
56    idx: &[IdxSize],
57    len: usize,
58) where
59    V: IntoIterator<Item = Option<T>>,
60{
61    let mut values_iter = set_values.into_iter();
62
63    if arr.null_count() > 0 {
64        arr.apply_validity(|v| {
65            let mut mut_validity = v.make_mut();
66
67            for (idx, val) in idx.iter().zip(&mut values_iter) {
68                match val {
69                    Some(value) => {
70                        mut_validity.set_unchecked(*idx as usize, true);
71                        *new_values_slice.get_unchecked_mut(*idx as usize) = value
72                    },
73                    None => mut_validity.set_unchecked(*idx as usize, false),
74                }
75            }
76            mut_validity.into()
77        })
78    } else {
79        let mut null_idx = vec![];
80        for (idx, val) in idx.iter().zip(values_iter) {
81            match val {
82                Some(value) => *new_values_slice.get_unchecked_mut(*idx as usize) = value,
83                None => {
84                    null_idx.push(*idx);
85                },
86            }
87        }
88        // only make a validity bitmap when null values are set
89        if !null_idx.is_empty() {
90            let mut validity = MutableBitmap::with_capacity(len);
91            validity.extend_constant(len, true);
92            for idx in null_idx {
93                validity.set_unchecked(idx as usize, false)
94            }
95            arr.set_validity(Some(validity.into()))
96        }
97    }
98}
99
100impl<T: PolarsOpsNumericType> ChunkedSet<T::Native> for &mut ChunkedArray<T> {
101    fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
102    where
103        V: IntoIterator<Item = Option<T::Native>>,
104    {
105        check_bounds(idx, self.len() as IdxSize)?;
106        let mut ca = std::mem::take(self);
107        ca.rechunk_mut();
108
109        // SAFETY:
110        // we will not modify the length
111        // and we unset the sorted flag.
112        ca.set_sorted_flag(IsSorted::Not);
113        let arr = unsafe { ca.downcast_iter_mut() }.next().unwrap();
114        let len = arr.len();
115
116        match arr.get_mut_values() {
117            Some(current_values) => {
118                let ptr = current_values.as_mut_ptr();
119
120                // reborrow because the bck does not allow it
121                let current_values = unsafe { &mut *std::slice::from_raw_parts_mut(ptr, len) };
122                // SAFETY:
123                // we checked bounds
124                unsafe { scatter_impl(current_values, values, arr, idx, len) };
125            },
126            None => {
127                let mut new_values = arr.values().as_slice().to_vec();
128                // SAFETY:
129                // we checked bounds
130                unsafe { scatter_impl(&mut new_values, values, arr, idx, len) };
131                arr.set_values(new_values.into());
132            },
133        };
134
135        // The null count may have changed - make sure to update the ChunkedArray
136        let new_null_count = arr.null_count();
137        unsafe { ca.set_null_count(new_null_count) };
138
139        Ok(ca.into_series())
140    }
141}
142
143impl<'a> ChunkedSet<&'a str> for &'a StringChunked {
144    fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
145    where
146        V: IntoIterator<Item = Option<&'a str>>,
147    {
148        check_bounds(idx, self.len() as IdxSize)?;
149        check_sorted(idx)?;
150        let mut ca_iter = self.into_iter().enumerate();
151        let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len());
152
153        for (current_idx, current_value) in idx.iter().zip(values) {
154            for (cnt_idx, opt_val_self) in &mut ca_iter {
155                if cnt_idx == *current_idx as usize {
156                    builder.append_option(current_value);
157                    break;
158                } else {
159                    builder.append_option(opt_val_self);
160                }
161            }
162        }
163        // the last idx is probably not the last value so we finish the iterator
164        for (_, opt_val_self) in ca_iter {
165            builder.append_option(opt_val_self);
166        }
167
168        let ca = builder.finish();
169        Ok(ca.into_series())
170    }
171}
172impl ChunkedSet<bool> for &BooleanChunked {
173    fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
174    where
175        V: IntoIterator<Item = Option<bool>>,
176    {
177        check_bounds(idx, self.len() as IdxSize)?;
178        check_sorted(idx)?;
179        let mut ca_iter = self.into_iter().enumerate();
180        let mut builder = BooleanChunkedBuilder::new(self.name().clone(), self.len());
181
182        for (current_idx, current_value) in idx.iter().zip(values) {
183            for (cnt_idx, opt_val_self) in &mut ca_iter {
184                if cnt_idx == *current_idx as usize {
185                    builder.append_option(current_value);
186                    break;
187                } else {
188                    builder.append_option(opt_val_self);
189                }
190            }
191        }
192        // the last idx is probably not the last value so we finish the iterator
193        for (_, opt_val_self) in ca_iter {
194            builder.append_option(opt_val_self);
195        }
196
197        let ca = builder.finish();
198        Ok(ca.into_series())
199    }
200}