polars_core/chunked_array/ops/
rolling_window.rs

1use std::hash::{Hash, Hasher};
2
3use polars_compute::rolling::RollingFnParams;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7#[derive(Clone, Debug)]
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
10#[cfg_attr(feature = "rolling_window", derive(PartialEq))]
11pub struct RollingOptionsFixedWindow {
12    /// The length of the window.
13    pub window_size: usize,
14    /// Amount of elements in the window that should be filled before computing a result.
15    pub min_periods: usize,
16    /// An optional slice with the same length as the window that will be multiplied
17    ///              elementwise with the values in the window.
18    pub weights: Option<Vec<f64>>,
19    /// Set the labels at the center of the window.
20    pub center: bool,
21    /// Optional parameters for the rolling
22    #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(default))]
23    pub fn_params: Option<RollingFnParams>,
24}
25
26impl Hash for RollingOptionsFixedWindow {
27    fn hash<H: Hasher>(&self, state: &mut H) {
28        self.window_size.hash(state);
29        self.min_periods.hash(state);
30        self.center.hash(state);
31        self.weights.is_some().hash(state);
32    }
33}
34
35impl Default for RollingOptionsFixedWindow {
36    fn default() -> Self {
37        RollingOptionsFixedWindow {
38            window_size: 3,
39            min_periods: 1,
40            weights: None,
41            center: false,
42            fn_params: None,
43        }
44    }
45}
46
47#[cfg(feature = "rolling_window")]
48mod inner_mod {
49    use num_traits::Zero;
50
51    use crate::chunked_array::cast::CastOptions;
52    use crate::prelude::*;
53
54    /// utility
55    fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> {
56        polars_ensure!(
57            min_periods <= window_size,
58            ComputeError: "`window_size`: {} should be >= `min_periods`: {}",
59            window_size, min_periods
60        );
61        Ok(())
62    }
63
64    /// utility
65    fn window_edges(idx: usize, len: usize, window_size: usize, center: bool) -> (usize, usize) {
66        let (start, end) = if center {
67            let right_window = window_size.div_ceil(2);
68            (
69                idx.saturating_sub(window_size - right_window),
70                len.min(idx + right_window),
71            )
72        } else {
73            (idx.saturating_sub(window_size - 1), idx + 1)
74        };
75
76        (start, end - start)
77    }
78
79    impl<T: PolarsNumericType> ChunkRollApply for ChunkedArray<T> {
80        /// Apply a rolling custom function. This is pretty slow because of dynamic dispatch.
81        fn rolling_map(
82            &self,
83            f: &dyn Fn(&Series) -> PolarsResult<Series>,
84            mut options: RollingOptionsFixedWindow,
85        ) -> PolarsResult<Series> {
86            check_input(options.window_size, options.min_periods)?;
87
88            let ca = self.rechunk();
89            if options.weights.is_some()
90                && !matches!(self.dtype(), DataType::Float64 | DataType::Float32)
91            {
92                let s = self.cast_with_options(&DataType::Float64, CastOptions::NonStrict)?;
93                return s.rolling_map(f, options);
94            }
95
96            options.window_size = std::cmp::min(self.len(), options.window_size);
97
98            let len = self.len();
99            let arr = ca.downcast_as_array();
100            let mut ca = ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
101            let ptr = ca.chunks[0].as_mut() as *mut dyn Array as *mut PrimitiveArray<T::Native>;
102            let mut series_container = ca.into_series();
103
104            let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), self.len());
105
106            if let Some(weights) = options.weights {
107                let weights_series =
108                    Float64Chunked::new(PlSmallStr::from_static("weights"), &weights).into_series();
109
110                let weights_series = weights_series.cast(self.dtype()).unwrap();
111
112                for idx in 0..len {
113                    let (start, size) = window_edges(idx, len, options.window_size, options.center);
114
115                    if size < options.min_periods {
116                        builder.append_null();
117                    } else {
118                        // SAFETY:
119                        // we are in bounds
120                        let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };
121
122                        // ensure we still meet window size criteria after removing null values
123                        if size - arr_window.null_count() < options.min_periods {
124                            builder.append_null();
125                            continue;
126                        }
127
128                        // SAFETY.
129                        // ptr is not dropped as we are in scope
130                        // We are also the only owner of the contents of the Arc
131                        // we do this to reduce heap allocs.
132                        unsafe {
133                            *ptr = arr_window;
134                        }
135                        // reset flags as we reuse this container
136                        series_container.clear_flags();
137                        // ensure the length is correct
138                        series_container._get_inner_mut().compute_len();
139                        let s = if size == options.window_size {
140                            f(&series_container.multiply(&weights_series).unwrap())?
141                        } else {
142                            // Determine which side to slice weights from
143                            let weights_cutoff: Series = match self.dtype() {
144                                DataType::Float64 => {
145                                    let ws = weights_series.f64().unwrap();
146                                    if start == 0 {
147                                        ws.slice(
148                                            (ws.len() - series_container.len()) as i64,
149                                            series_container.len(),
150                                        )
151                                        .into_series()
152                                    } else {
153                                        ws.slice(0, series_container.len()).into_series()
154                                    }
155                                },
156                                _ => {
157                                    let ws = weights_series.f32().unwrap();
158                                    if start == 0 {
159                                        ws.slice(
160                                            (ws.len() - series_container.len()) as i64,
161                                            series_container.len(),
162                                        )
163                                        .into_series()
164                                    } else {
165                                        ws.slice(0, series_container.len()).into_series()
166                                    }
167                                },
168                            };
169                            f(&series_container.multiply(&weights_cutoff).unwrap())?
170                        };
171
172                        let out = self.unpack_series_matching_type(&s)?;
173                        builder.append_option(out.get(0));
174                    }
175                }
176
177                Ok(builder.finish().into_series())
178            } else {
179                for idx in 0..len {
180                    let (start, size) = window_edges(idx, len, options.window_size, options.center);
181
182                    if size < options.min_periods {
183                        builder.append_null();
184                    } else {
185                        // SAFETY:
186                        // we are in bounds
187                        let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };
188
189                        // ensure we still meet window size criteria after removing null values
190                        if size - arr_window.null_count() < options.min_periods {
191                            builder.append_null();
192                            continue;
193                        }
194
195                        // SAFETY.
196                        // ptr is not dropped as we are in scope
197                        // We are also the only owner of the contents of the Arc
198                        // we do this to reduce heap allocs.
199                        unsafe {
200                            *ptr = arr_window;
201                        }
202                        // reset flags as we reuse this container
203                        series_container.clear_flags();
204                        // ensure the length is correct
205                        series_container._get_inner_mut().compute_len();
206                        let s = f(&series_container)?;
207                        let out = self.unpack_series_matching_type(&s)?;
208                        builder.append_option(out.get(0));
209                    }
210                }
211
212                Ok(builder.finish().into_series())
213            }
214        }
215    }
216}