polars_core/chunked_array/ops/
rolling_window.rs1use polars_compute::rolling::RollingFnParams;
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4
/// Options for rolling computations over a window with a fixed number of elements.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "rolling_window", derive(PartialEq))]
pub struct RollingOptionsFixedWindow {
    /// Number of elements in each window.
    pub window_size: usize,
    /// Minimum number of non-null values a window must contain before a result is produced;
    /// windows with fewer values yield null. Must not exceed `window_size`.
    pub min_periods: usize,
    /// Optional weights, multiplied element-wise with the values of a window before the
    /// rolling function is applied to it.
    pub weights: Option<Vec<f64>>,
    /// If `true`, the window is centered on the output position instead of ending at it.
    pub center: bool,
    /// Optional extra parameters for the rolling function; missing in serialized input means
    /// `None`. NOTE(review): the meaning of each variant is defined by `RollingFnParams` in
    /// `polars_compute`, which is not visible here.
    #[cfg_attr(feature = "serde", serde(default))]
    pub fn_params: Option<RollingFnParams>,
}
22
23impl Default for RollingOptionsFixedWindow {
24 fn default() -> Self {
25 RollingOptionsFixedWindow {
26 window_size: 3,
27 min_periods: 1,
28 weights: None,
29 center: false,
30 fn_params: None,
31 }
32 }
33}
34
// Fixed-window rolling kernels; only compiled with the `rolling_window` feature.
#[cfg(feature = "rolling_window")]
mod inner_mod {
    use std::ops::SubAssign;

    use arrow::bitmap::MutableBitmap;
    use arrow::bitmap::utils::set_bit_unchecked;
    use arrow::legacy::trusted_len::TrustedLenPush;
    use num_traits::pow::Pow;
    use num_traits::{Float, Zero};
    use polars_utils::float::IsFloat;

    use crate::chunked_array::cast::CastOptions;
    use crate::prelude::*;

    /// Validates fixed-window rolling arguments.
    ///
    /// # Errors
    /// Returns a `ComputeError` when `min_periods` exceeds `window_size`, since such a window
    /// could never hold enough values to produce a result.
    fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> {
        polars_ensure!(
            min_periods <= window_size,
            ComputeError: "`window_size`: {} should be >= `min_periods`: {}",
            window_size, min_periods
        );
        Ok(())
    }

    /// Returns `(start, size)` of the window for output position `idx` in an array of `len`
    /// elements.
    ///
    /// * Trailing (`center == false`): the window ends at `idx` (inclusive) and reaches back
    ///   at most `window_size - 1` elements.
    /// * Centered: `window_size.div_ceil(2)` elements starting at `idx` (inclusive) lie on the
    ///   right, the remaining `window_size - window_size.div_ceil(2)` on the left.
    ///
    /// Windows are truncated at the array edges, so `size` may be smaller than `window_size`.
    /// For `idx < len`, `start + size <= len` always holds.
    fn window_edges(idx: usize, len: usize, window_size: usize, center: bool) -> (usize, usize) {
        let (start, end) = if center {
            let right_window = window_size.div_ceil(2);
            (
                idx.saturating_sub(window_size - right_window),
                len.min(idx + right_window),
            )
        } else {
            // `(idx + 1) - window_size` instead of `idx - (window_size - 1)`: identical for
            // `window_size >= 1`, but a zero `window_size` yields an empty window rather than
            // overflowing.
            ((idx + 1).saturating_sub(window_size), idx + 1)
        };

        (start, end - start)
    }

    impl<T> ChunkRollApply for ChunkedArray<T>
    where
        T: PolarsNumericType,
        Self: IntoSeries,
    {
        /// Applies `f` to every fixed-size window and collects the first value of each result.
        ///
        /// For each output index the window is chosen by `window_edges` (with `window_size`
        /// capped at the array length). If the window has fewer than `options.min_periods`
        /// elements, or fewer than that many non-null elements, the output is null. Otherwise
        /// the window is passed to `f` as a `Series`, multiplied element-wise by the weights
        /// when `options.weights` is set. The first element of `f`'s result, unpacked as this
        /// array's type, becomes the output value.
        ///
        /// Weighted windows on non-float input are computed on a non-strict `Float64` cast.
        /// A window shorter than the weights uses the leading `size` weights.
        ///
        /// # Errors
        /// Propagates errors from argument validation, casting, multiplying by the weights
        /// (e.g. fewer weights than window elements) and unpacking `f`'s result.
        fn rolling_map(
            &self,
            f: &dyn Fn(&Series) -> Series,
            mut options: RollingOptionsFixedWindow,
        ) -> PolarsResult<Series> {
            check_input(options.window_size, options.min_periods)?;

            // Weighting needs a float dtype: cast and re-dispatch before doing any other work.
            if options.weights.is_some()
                && !matches!(self.dtype(), DataType::Float64 | DataType::Float32)
            {
                let s = self.cast_with_options(&DataType::Float64, CastOptions::NonStrict)?;
                return s.rolling_map(f, options);
            }

            let len = self.len();
            // Cap the window at the array length; this also determines how centered windows
            // are split.
            options.window_size = std::cmp::min(len, options.window_size);

            // Cast the weights to this array's float dtype once, outside the loop.
            let weights_series = options
                .weights
                .take()
                .map(|weights| {
                    Float64Chunked::new(PlSmallStr::from_static("weights"), &weights)
                        .into_series()
                        .cast(self.dtype())
                })
                .transpose()?;

            let ca = self.rechunk();
            let arr = ca.downcast_as_array();

            // Single-chunk container whose chunk is overwritten in place for every window, so
            // `f` receives a `Series` without a fresh container allocation per window.
            let mut container =
                ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
            let ptr =
                container.chunks[0].as_mut() as *mut dyn Array as *mut PrimitiveArray<T::Native>;
            let mut series_container = container.into_series();

            let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), len);

            for idx in 0..len {
                let (start, size) = window_edges(idx, len, options.window_size, options.center);

                if size < options.min_periods {
                    builder.append_null();
                    continue;
                }

                // SAFETY: `window_edges` guarantees `start + size <= len == arr.len()`.
                let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };

                if size - arr_window.null_count() < options.min_periods {
                    builder.append_null();
                    continue;
                }

                // SAFETY: `ptr` targets the single chunk owned by `series_container`, which
                // lives for the whole loop and is not shared. NOTE(review): relies on the chunk
                // being heap-allocated, so moving the ChunkedArray into the Series did not move
                // the pointee.
                unsafe {
                    *ptr = arr_window;
                }
                // Cached flags, length and null count describe the previous window; refresh them.
                series_container.clear_flags();
                series_container._get_inner_mut().compute_len();

                let s = match &weights_series {
                    Some(weights) if weights.len() == size => {
                        f(&series_container.multiply(weights)?)
                    },
                    // Partial window, or a window capped at `len`: use the leading `size`
                    // weights so both operands have the same length.
                    Some(weights) => f(&series_container.multiply(&weights.slice(0, size))?),
                    None => f(&series_container),
                };

                let out = self.unpack_series_matching_type(&s)?;
                builder.append_option(out.get(0));
            }

            Ok(builder.finish().into_series())
        }
    }

    impl<T> ChunkedArray<T>
    where
        ChunkedArray<T>: IntoSeries,
        T: PolarsFloatType,
        T::Native: Float + IsFloat + SubAssign + Pow<T::Native, Output = T::Native>,
    {
        /// Applies `f` to every full, trailing window of `window_size` elements.
        ///
        /// Output position `i` holds `f` applied to elements `i + 1 - window_size ..= i`.
        /// The first `window_size - 1` positions, and every position where `f` returns `None`,
        /// are null. If `window_size` exceeds the array length, the result is entirely null.
        ///
        /// `f` is handed one reused container whose chunk, length and null count are replaced
        /// before each call.
        ///
        /// # Errors
        /// Returns a `ComputeError` if `window_size` is zero.
        pub fn rolling_map_float<F>(&self, window_size: usize, mut f: F) -> PolarsResult<Self>
        where
            F: FnMut(&mut ChunkedArray<T>) -> Option<T::Native>,
        {
            polars_ensure!(window_size > 0, ComputeError: "`window_size` should be > 0");

            let len = self.len();
            if window_size > len {
                return Ok(Self::full_null(self.name().clone(), len));
            }
            let ca = self.rechunk();
            let arr = ca.downcast_as_array();

            // Single-chunk container whose chunk is overwritten in place for every window.
            let mut heap_container =
                ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
            let ptr = heap_container.chunks[0].as_mut() as *mut dyn Array
                as *mut PrimitiveArray<T::Native>;

            // Positions before the first full window are null.
            let n_leading = window_size - 1;
            let mut validity = MutableBitmap::with_capacity(len);
            validity.extend_constant(n_leading, false);
            validity.extend_constant(len - n_leading, true);
            let validity_slice = validity.as_mut_slice();

            let mut values = Vec::with_capacity(len);
            values.extend(std::iter::repeat_n(T::Native::default(), n_leading));

            for offset in 0..=len - window_size {
                debug_assert!(offset + window_size <= arr.len());
                // SAFETY: `offset + window_size <= len == arr.len()` by the loop bound.
                let arr_window = unsafe { arr.slice_typed_unchecked(offset, window_size) };

                // SAFETY: `ptr` targets the single chunk owned by `heap_container`, which is
                // alive for the whole loop and exclusively owned here.
                unsafe {
                    *ptr = arr_window;
                }
                // Length and null count are cached on the ChunkedArray; recompute both so `f`
                // sees the nulls of the current window.
                heap_container.compute_len();

                match f(&mut heap_container) {
                    // SAFETY (both arms): `values` was allocated with capacity `len` and exactly
                    // `n_leading + (len - window_size + 1) == len` values are pushed in total;
                    // `offset + n_leading < len` indexes inside the validity buffer.
                    Some(v) => unsafe { values.push_unchecked(v) },
                    None => unsafe {
                        values.push_unchecked(T::Native::default());
                        set_bit_unchecked(validity_slice, offset + n_leading, false);
                    },
                }
            }

            let arr = PrimitiveArray::new(
                T::get_dtype().to_arrow(CompatLevel::newest()),
                values.into(),
                Some(validity.into()),
            );
            Ok(Self::with_chunk(self.name().clone(), arr))
        }
    }
}