polars_core/chunked_array/ops/
rolling_window.rs

1use std::hash::{Hash, Hasher};
2
3use polars_compute::rolling::RollingFnParams;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
/// Options describing a rolling computation over a window with a fixed number of elements.
///
/// `PartialEq` is only derived when the `rolling_window` feature is enabled, and
/// `Serialize`/`Deserialize` only when the `serde` feature is enabled.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "rolling_window", derive(PartialEq))]
pub struct RollingOptionsFixedWindow {
    /// The length of the window, in number of elements.
    pub window_size: usize,
    /// Number of elements in the window that must be present before a result is computed.
    pub min_periods: usize,
    /// An optional list of weights with the same length as the window. The weights are
    /// multiplied elementwise with the values in the window.
    pub weights: Option<Vec<f64>>,
    /// Set the labels at the center of the window.
    pub center: bool,
    /// Optional extra parameters for the rolling function.
    ///
    /// With `serde` enabled, a missing field deserializes to the default (`None`).
    #[cfg_attr(feature = "serde", serde(default))]
    pub fn_params: Option<RollingFnParams>,
}
24
25impl Hash for RollingOptionsFixedWindow {
26    fn hash<H: Hasher>(&self, state: &mut H) {
27        self.window_size.hash(state);
28        self.min_periods.hash(state);
29        self.center.hash(state);
30        self.weights.is_some().hash(state);
31    }
32}
33
34impl Default for RollingOptionsFixedWindow {
35    fn default() -> Self {
36        RollingOptionsFixedWindow {
37            window_size: 3,
38            min_periods: 1,
39            weights: None,
40            center: false,
41            fn_params: None,
42        }
43    }
44}
45
46#[cfg(feature = "rolling_window")]
47mod inner_mod {
48    use std::ops::SubAssign;
49
50    use arrow::bitmap::MutableBitmap;
51    use arrow::bitmap::utils::set_bit_unchecked;
52    use arrow::legacy::trusted_len::TrustedLenPush;
53    use num_traits::pow::Pow;
54    use num_traits::{Float, Zero};
55    use polars_utils::float::IsFloat;
56
57    use crate::chunked_array::cast::CastOptions;
58    use crate::prelude::*;
59
60    /// utility
61    fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> {
62        polars_ensure!(
63            min_periods <= window_size,
64            ComputeError: "`window_size`: {} should be >= `min_periods`: {}",
65            window_size, min_periods
66        );
67        Ok(())
68    }
69
70    /// utility
71    fn window_edges(idx: usize, len: usize, window_size: usize, center: bool) -> (usize, usize) {
72        let (start, end) = if center {
73            let right_window = window_size.div_ceil(2);
74            (
75                idx.saturating_sub(window_size - right_window),
76                len.min(idx + right_window),
77            )
78        } else {
79            (idx.saturating_sub(window_size - 1), idx + 1)
80        };
81
82        (start, end - start)
83    }
84
    impl<T> ChunkRollApply for ChunkedArray<T>
    where
        T: PolarsNumericType,
        Self: IntoSeries,
    {
        /// Apply a rolling custom function. This is pretty slow because of dynamic dispatch.
        ///
        /// For every position a window is determined by `window_edges`. The window is
        /// handed to `f` as a `Series` and element `0` of the returned `Series` becomes
        /// the output at that position. A null is emitted instead (and `f` is not called)
        /// when the window holds fewer than `options.min_periods` elements, or fewer than
        /// that many non-null elements.
        ///
        /// When `options.weights` is set the window is multiplied elementwise by the
        /// weights before `f` is called. Non-float inputs are first cast to `Float64`
        /// in that case.
        ///
        /// # Errors
        ///
        /// Returns an error when `check_input` rejects the options, when the cast to
        /// `Float64` fails, or when the `Series` returned by `f` cannot be unpacked into
        /// this array's type.
        ///
        /// # Panics
        ///
        /// The weighted path calls `unwrap` on the cast of the weights and on the
        /// multiplication of the window with the weights.
        fn rolling_map(
            &self,
            f: &dyn Fn(&Series) -> Series,
            mut options: RollingOptionsFixedWindow,
        ) -> PolarsResult<Series> {
            check_input(options.window_size, options.min_periods)?;

            // NOTE(review): the rechunk happens before the early return below, so its
            // result is unused on the cast path.
            let ca = self.rechunk();
            // Weights are `f64`; a non-float input is cast to `Float64` and the call is
            // delegated to the resulting series.
            if options.weights.is_some()
                && !matches!(self.dtype(), DataType::Float64 | DataType::Float32)
            {
                let s = self.cast_with_options(&DataType::Float64, CastOptions::NonStrict)?;
                return s.rolling_map(f, options);
            }

            // A window can never span more elements than the array holds.
            options.window_size = std::cmp::min(self.len(), options.window_size);

            let len = self.len();
            let arr = ca.downcast_as_array();
            // One-element placeholder container. Its single chunk is overwritten with the
            // current window on every iteration via `ptr`, so no new container has to be
            // built per window.
            let mut ca = ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
            let ptr = ca.chunks[0].as_mut() as *mut dyn Array as *mut PrimitiveArray<T::Native>;
            let mut series_container = ca.into_series();

            let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), self.len());

            if let Some(weights) = options.weights {
                let weights_series =
                    Float64Chunked::new(PlSmallStr::from_static("weights"), &weights).into_series();

                // Reaching this point with weights means the dtype is `Float64` or
                // `Float32` (see the early return above); match the weights to it.
                let weights_series = weights_series.cast(self.dtype()).unwrap();

                for idx in 0..len {
                    let (start, size) = window_edges(idx, len, options.window_size, options.center);

                    if size < options.min_periods {
                        builder.append_null();
                    } else {
                        // SAFETY:
                        // we are in bounds
                        // (relies on `window_edges` returning a range inside `0..len`)
                        let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };

                        // ensure we still meet window size criteria after removing null values
                        if size - arr_window.null_count() < options.min_periods {
                            builder.append_null();
                            continue;
                        }

                        // SAFETY:
                        // ptr is not dropped as we are in scope
                        // We are also the only owner of the contents of the Arc
                        // we do this to reduce heap allocs.
                        unsafe {
                            *ptr = arr_window;
                        }
                        // reset flags as we reuse this container
                        series_container.clear_flags();
                        // ensure the length is correct
                        series_container._get_inner_mut().compute_len();
                        let s = if size == options.window_size {
                            // Full window: weights and window have the same length.
                            f(&series_container.multiply(&weights_series).unwrap())
                        } else {
                            // Partial window: use only the first `series_container.len()`
                            // weights so both operands have the same length.
                            let weights_cutoff: Series = match self.dtype() {
                                DataType::Float64 => weights_series
                                    .f64()
                                    .unwrap()
                                    .into_iter()
                                    .take(series_container.len())
                                    .collect(),
                                _ => weights_series // Float32 case
                                    .f32()
                                    .unwrap()
                                    .into_iter()
                                    .take(series_container.len())
                                    .collect(),
                            };
                            f(&series_container.multiply(&weights_cutoff).unwrap())
                        };

                        // Only the first element of the function result is kept.
                        let out = self.unpack_series_matching_type(&s)?;
                        builder.append_option(out.get(0));
                    }
                }

                Ok(builder.finish().into_series())
            } else {
                for idx in 0..len {
                    let (start, size) = window_edges(idx, len, options.window_size, options.center);

                    if size < options.min_periods {
                        builder.append_null();
                    } else {
                        // SAFETY:
                        // we are in bounds
                        // (relies on `window_edges` returning a range inside `0..len`)
                        let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };

                        // ensure we still meet window size criteria after removing null values
                        if size - arr_window.null_count() < options.min_periods {
                            builder.append_null();
                            continue;
                        }

                        // SAFETY:
                        // ptr is not dropped as we are in scope
                        // We are also the only owner of the contents of the Arc
                        // we do this to reduce heap allocs.
                        unsafe {
                            *ptr = arr_window;
                        }
                        // reset flags as we reuse this container
                        series_container.clear_flags();
                        // ensure the length is correct
                        series_container._get_inner_mut().compute_len();
                        let s = f(&series_container);
                        // Only the first element of the function result is kept.
                        let out = self.unpack_series_matching_type(&s)?;
                        builder.append_option(out.get(0));
                    }
                }

                Ok(builder.finish().into_series())
            }
        }
    }
213
214    impl<T> ChunkedArray<T>
215    where
216        ChunkedArray<T>: IntoSeries,
217        T: PolarsFloatType,
218        T::Native: Float + IsFloat + SubAssign + Pow<T::Native, Output = T::Native>,
219    {
220        /// Apply a rolling custom function. This is pretty slow because of dynamic dispatch.
221        pub fn rolling_map_float<F>(&self, window_size: usize, mut f: F) -> PolarsResult<Self>
222        where
223            F: FnMut(&mut ChunkedArray<T>) -> Option<T::Native>,
224        {
225            if window_size > self.len() {
226                return Ok(Self::full_null(self.name().clone(), self.len()));
227            }
228            let ca = self.rechunk();
229            let arr = ca.downcast_as_array();
230
231            // We create a temporary dummy ChunkedArray. This will be a
232            // container where we swap the window contents every iteration doing
233            // so will save a lot of heap allocations.
234            let mut heap_container =
235                ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
236            let ptr = heap_container.chunks[0].as_mut() as *mut dyn Array
237                as *mut PrimitiveArray<T::Native>;
238
239            let mut validity = MutableBitmap::with_capacity(ca.len());
240            validity.extend_constant(window_size - 1, false);
241            validity.extend_constant(ca.len() - (window_size - 1), true);
242            let validity_slice = validity.as_mut_slice();
243
244            let mut values = Vec::with_capacity(ca.len());
245            values.extend(std::iter::repeat_n(T::Native::default(), window_size - 1));
246
247            for offset in 0..self.len() + 1 - window_size {
248                debug_assert!(offset + window_size <= arr.len());
249                let arr_window = unsafe { arr.slice_typed_unchecked(offset, window_size) };
250                // The lengths are cached, so we must update them.
251                heap_container.length = arr_window.len();
252
253                // SAFETY: ptr is not dropped as we are in scope. We are also the only
254                // owner of the contents of the Arc (we do this to reduce heap allocs).
255                unsafe {
256                    *ptr = arr_window;
257                }
258
259                let out = f(&mut heap_container);
260                match out {
261                    Some(v) => {
262                        // SAFETY: we have pre-allocated.
263                        unsafe { values.push_unchecked(v) }
264                    },
265                    None => {
266                        // SAFETY: we allocated enough for both the `values` vec
267                        // and the `validity_ptr`.
268                        unsafe {
269                            values.push_unchecked(T::Native::default());
270                            set_bit_unchecked(validity_slice, offset + window_size - 1, false);
271                        }
272                    },
273                }
274            }
275            let arr = PrimitiveArray::new(
276                T::get_static_dtype().to_arrow(CompatLevel::newest()),
277                values.into(),
278                Some(validity.into()),
279            );
280            Ok(Self::with_chunk(self.name().clone(), arr))
281        }
282    }
283}