polars_core/chunked_array/ops/
rolling_window.rs

1use polars_compute::rolling::RollingFnParams;
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4
5#[derive(Clone, Debug)]
6#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
7#[cfg_attr(feature = "rolling_window", derive(PartialEq))]
8pub struct RollingOptionsFixedWindow {
9    /// The length of the window.
10    pub window_size: usize,
11    /// Amount of elements in the window that should be filled before computing a result.
12    pub min_periods: usize,
13    /// An optional slice with the same length as the window that will be multiplied
14    ///              elementwise with the values in the window.
15    pub weights: Option<Vec<f64>>,
16    /// Set the labels at the center of the window.
17    pub center: bool,
18    /// Optional parameters for the rolling
19    #[cfg_attr(feature = "serde", serde(default))]
20    pub fn_params: Option<RollingFnParams>,
21}
22
23impl Default for RollingOptionsFixedWindow {
24    fn default() -> Self {
25        RollingOptionsFixedWindow {
26            window_size: 3,
27            min_periods: 1,
28            weights: None,
29            center: false,
30            fn_params: None,
31        }
32    }
33}
34
35#[cfg(feature = "rolling_window")]
36mod inner_mod {
37    use std::ops::SubAssign;
38
39    use arrow::bitmap::MutableBitmap;
40    use arrow::bitmap::utils::set_bit_unchecked;
41    use arrow::legacy::trusted_len::TrustedLenPush;
42    use num_traits::pow::Pow;
43    use num_traits::{Float, Zero};
44    use polars_utils::float::IsFloat;
45
46    use crate::chunked_array::cast::CastOptions;
47    use crate::prelude::*;
48
49    /// utility
50    fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> {
51        polars_ensure!(
52            min_periods <= window_size,
53            ComputeError: "`window_size`: {} should be >= `min_periods`: {}",
54            window_size, min_periods
55        );
56        Ok(())
57    }
58
59    /// utility
60    fn window_edges(idx: usize, len: usize, window_size: usize, center: bool) -> (usize, usize) {
61        let (start, end) = if center {
62            let right_window = window_size.div_ceil(2);
63            (
64                idx.saturating_sub(window_size - right_window),
65                len.min(idx + right_window),
66            )
67        } else {
68            (idx.saturating_sub(window_size - 1), idx + 1)
69        };
70
71        (start, end - start)
72    }
73
74    impl<T> ChunkRollApply for ChunkedArray<T>
75    where
76        T: PolarsNumericType,
77        Self: IntoSeries,
78    {
79        /// Apply a rolling custom function. This is pretty slow because of dynamic dispatch.
80        fn rolling_map(
81            &self,
82            f: &dyn Fn(&Series) -> Series,
83            mut options: RollingOptionsFixedWindow,
84        ) -> PolarsResult<Series> {
85            check_input(options.window_size, options.min_periods)?;
86
87            let ca = self.rechunk();
88            if options.weights.is_some()
89                && !matches!(self.dtype(), DataType::Float64 | DataType::Float32)
90            {
91                let s = self.cast_with_options(&DataType::Float64, CastOptions::NonStrict)?;
92                return s.rolling_map(f, options);
93            }
94
95            options.window_size = std::cmp::min(self.len(), options.window_size);
96
97            let len = self.len();
98            let arr = ca.downcast_as_array();
99            let mut ca = ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
100            let ptr = ca.chunks[0].as_mut() as *mut dyn Array as *mut PrimitiveArray<T::Native>;
101            let mut series_container = ca.into_series();
102
103            let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), self.len());
104
105            if let Some(weights) = options.weights {
106                let weights_series =
107                    Float64Chunked::new(PlSmallStr::from_static("weights"), &weights).into_series();
108
109                let weights_series = weights_series.cast(self.dtype()).unwrap();
110
111                for idx in 0..len {
112                    let (start, size) = window_edges(idx, len, options.window_size, options.center);
113
114                    if size < options.min_periods {
115                        builder.append_null();
116                    } else {
117                        // SAFETY:
118                        // we are in bounds
119                        let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };
120
121                        // ensure we still meet window size criteria after removing null values
122                        if size - arr_window.null_count() < options.min_periods {
123                            builder.append_null();
124                            continue;
125                        }
126
127                        // SAFETY.
128                        // ptr is not dropped as we are in scope
129                        // We are also the only owner of the contents of the Arc
130                        // we do this to reduce heap allocs.
131                        unsafe {
132                            *ptr = arr_window;
133                        }
134                        // reset flags as we reuse this container
135                        series_container.clear_flags();
136                        // ensure the length is correct
137                        series_container._get_inner_mut().compute_len();
138                        let s = if size == options.window_size {
139                            f(&series_container.multiply(&weights_series).unwrap())
140                        } else {
141                            let weights_cutoff: Series = match self.dtype() {
142                                DataType::Float64 => weights_series
143                                    .f64()
144                                    .unwrap()
145                                    .into_iter()
146                                    .take(series_container.len())
147                                    .collect(),
148                                _ => weights_series // Float32 case
149                                    .f32()
150                                    .unwrap()
151                                    .into_iter()
152                                    .take(series_container.len())
153                                    .collect(),
154                            };
155                            f(&series_container.multiply(&weights_cutoff).unwrap())
156                        };
157
158                        let out = self.unpack_series_matching_type(&s)?;
159                        builder.append_option(out.get(0));
160                    }
161                }
162
163                Ok(builder.finish().into_series())
164            } else {
165                for idx in 0..len {
166                    let (start, size) = window_edges(idx, len, options.window_size, options.center);
167
168                    if size < options.min_periods {
169                        builder.append_null();
170                    } else {
171                        // SAFETY:
172                        // we are in bounds
173                        let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };
174
175                        // ensure we still meet window size criteria after removing null values
176                        if size - arr_window.null_count() < options.min_periods {
177                            builder.append_null();
178                            continue;
179                        }
180
181                        // SAFETY.
182                        // ptr is not dropped as we are in scope
183                        // We are also the only owner of the contents of the Arc
184                        // we do this to reduce heap allocs.
185                        unsafe {
186                            *ptr = arr_window;
187                        }
188                        // reset flags as we reuse this container
189                        series_container.clear_flags();
190                        // ensure the length is correct
191                        series_container._get_inner_mut().compute_len();
192                        let s = f(&series_container);
193                        let out = self.unpack_series_matching_type(&s)?;
194                        builder.append_option(out.get(0));
195                    }
196                }
197
198                Ok(builder.finish().into_series())
199            }
200        }
201    }
202
203    impl<T> ChunkedArray<T>
204    where
205        ChunkedArray<T>: IntoSeries,
206        T: PolarsFloatType,
207        T::Native: Float + IsFloat + SubAssign + Pow<T::Native, Output = T::Native>,
208    {
209        /// Apply a rolling custom function. This is pretty slow because of dynamic dispatch.
210        pub fn rolling_map_float<F>(&self, window_size: usize, mut f: F) -> PolarsResult<Self>
211        where
212            F: FnMut(&mut ChunkedArray<T>) -> Option<T::Native>,
213        {
214            if window_size > self.len() {
215                return Ok(Self::full_null(self.name().clone(), self.len()));
216            }
217            let ca = self.rechunk();
218            let arr = ca.downcast_as_array();
219
220            // We create a temporary dummy ChunkedArray. This will be a
221            // container where we swap the window contents every iteration doing
222            // so will save a lot of heap allocations.
223            let mut heap_container =
224                ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
225            let ptr = heap_container.chunks[0].as_mut() as *mut dyn Array
226                as *mut PrimitiveArray<T::Native>;
227
228            let mut validity = MutableBitmap::with_capacity(ca.len());
229            validity.extend_constant(window_size - 1, false);
230            validity.extend_constant(ca.len() - (window_size - 1), true);
231            let validity_slice = validity.as_mut_slice();
232
233            let mut values = Vec::with_capacity(ca.len());
234            values.extend(std::iter::repeat_n(T::Native::default(), window_size - 1));
235
236            for offset in 0..self.len() + 1 - window_size {
237                debug_assert!(offset + window_size <= arr.len());
238                let arr_window = unsafe { arr.slice_typed_unchecked(offset, window_size) };
239                // The lengths are cached, so we must update them.
240                heap_container.length = arr_window.len();
241
242                // SAFETY: ptr is not dropped as we are in scope. We are also the only
243                // owner of the contents of the Arc (we do this to reduce heap allocs).
244                unsafe {
245                    *ptr = arr_window;
246                }
247
248                let out = f(&mut heap_container);
249                match out {
250                    Some(v) => {
251                        // SAFETY: we have pre-allocated.
252                        unsafe { values.push_unchecked(v) }
253                    },
254                    None => {
255                        // SAFETY: we allocated enough for both the `values` vec
256                        // and the `validity_ptr`.
257                        unsafe {
258                            values.push_unchecked(T::Native::default());
259                            set_bit_unchecked(validity_slice, offset + window_size - 1, false);
260                        }
261                    },
262                }
263            }
264            let arr = PrimitiveArray::new(
265                T::get_dtype().to_arrow(CompatLevel::newest()),
266                values.into(),
267                Some(validity.into()),
268            );
269            Ok(Self::with_chunk(self.name().clone(), arr))
270        }
271    }
272}