polars_core/chunked_array/ops/
rolling_window.rs1use std::hash::{Hash, Hasher};
2
3use polars_compute::rolling::RollingFnParams;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7#[derive(Clone, Debug)]
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9#[cfg_attr(feature = "rolling_window", derive(PartialEq))]
10pub struct RollingOptionsFixedWindow {
11 pub window_size: usize,
13 pub min_periods: usize,
15 pub weights: Option<Vec<f64>>,
18 pub center: bool,
20 #[cfg_attr(feature = "serde", serde(default))]
22 pub fn_params: Option<RollingFnParams>,
23}
24
25impl Hash for RollingOptionsFixedWindow {
26 fn hash<H: Hasher>(&self, state: &mut H) {
27 self.window_size.hash(state);
28 self.min_periods.hash(state);
29 self.center.hash(state);
30 self.weights.is_some().hash(state);
31 }
32}
33
34impl Default for RollingOptionsFixedWindow {
35 fn default() -> Self {
36 RollingOptionsFixedWindow {
37 window_size: 3,
38 min_periods: 1,
39 weights: None,
40 center: false,
41 fn_params: None,
42 }
43 }
44}
45
46#[cfg(feature = "rolling_window")]
47mod inner_mod {
48 use std::ops::SubAssign;
49
50 use arrow::bitmap::MutableBitmap;
51 use arrow::bitmap::utils::set_bit_unchecked;
52 use arrow::legacy::trusted_len::TrustedLenPush;
53 use num_traits::pow::Pow;
54 use num_traits::{Float, Zero};
55 use polars_utils::float::IsFloat;
56
57 use crate::chunked_array::cast::CastOptions;
58 use crate::prelude::*;
59
60 fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> {
62 polars_ensure!(
63 min_periods <= window_size,
64 ComputeError: "`window_size`: {} should be >= `min_periods`: {}",
65 window_size, min_periods
66 );
67 Ok(())
68 }
69
/// Computes the bounds of the rolling window that belongs to position `idx`.
///
/// Returns `(start, size)`: the offset of the first element of the window and the number
/// of elements it spans.
///
/// * `center == false`: the window trails behind `idx` and ends at `idx` (inclusive); the
///   start is clipped to `0`.
/// * `center == true`: `window_size.div_ceil(2)` slots cover `idx` and the elements after
///   it, the remaining slots cover the elements before it; the start is clipped to `0` and
///   the end to `len`.
///
/// A `window_size` of zero yields an empty window (`size == 0`) in both modes.
///
/// Callers are expected to pass `idx < len`; the centered branch subtracts `start` from an
/// end that is clipped to `len`.
fn window_edges(idx: usize, len: usize, window_size: usize, center: bool) -> (usize, usize) {
    let (start, end) = if center {
        let right_window = window_size.div_ceil(2);
        (
            // `right_window <= window_size`, so this subtraction cannot underflow.
            idx.saturating_sub(window_size - right_window),
            len.min(idx + right_window),
        )
    } else {
        // Derive the start from the exclusive end instead of computing `window_size - 1`,
        // which underflows for `window_size == 0`. For `window_size >= 1` both forms agree.
        let end = idx + 1;
        (end.saturating_sub(window_size), end)
    };

    (start, end - start)
}
84
85 impl<T> ChunkRollApply for ChunkedArray<T>
86 where
87 T: PolarsNumericType,
88 Self: IntoSeries,
89 {
90 fn rolling_map(
92 &self,
93 f: &dyn Fn(&Series) -> Series,
94 mut options: RollingOptionsFixedWindow,
95 ) -> PolarsResult<Series> {
96 check_input(options.window_size, options.min_periods)?;
97
98 let ca = self.rechunk();
99 if options.weights.is_some()
100 && !matches!(self.dtype(), DataType::Float64 | DataType::Float32)
101 {
102 let s = self.cast_with_options(&DataType::Float64, CastOptions::NonStrict)?;
103 return s.rolling_map(f, options);
104 }
105
106 options.window_size = std::cmp::min(self.len(), options.window_size);
107
108 let len = self.len();
109 let arr = ca.downcast_as_array();
110 let mut ca = ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
111 let ptr = ca.chunks[0].as_mut() as *mut dyn Array as *mut PrimitiveArray<T::Native>;
112 let mut series_container = ca.into_series();
113
114 let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), self.len());
115
116 if let Some(weights) = options.weights {
117 let weights_series =
118 Float64Chunked::new(PlSmallStr::from_static("weights"), &weights).into_series();
119
120 let weights_series = weights_series.cast(self.dtype()).unwrap();
121
122 for idx in 0..len {
123 let (start, size) = window_edges(idx, len, options.window_size, options.center);
124
125 if size < options.min_periods {
126 builder.append_null();
127 } else {
128 let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };
131
132 if size - arr_window.null_count() < options.min_periods {
134 builder.append_null();
135 continue;
136 }
137
138 unsafe {
143 *ptr = arr_window;
144 }
145 series_container.clear_flags();
147 series_container._get_inner_mut().compute_len();
149 let s = if size == options.window_size {
150 f(&series_container.multiply(&weights_series).unwrap())
151 } else {
152 let weights_cutoff: Series = match self.dtype() {
153 DataType::Float64 => weights_series
154 .f64()
155 .unwrap()
156 .into_iter()
157 .take(series_container.len())
158 .collect(),
159 _ => weights_series .f32()
161 .unwrap()
162 .into_iter()
163 .take(series_container.len())
164 .collect(),
165 };
166 f(&series_container.multiply(&weights_cutoff).unwrap())
167 };
168
169 let out = self.unpack_series_matching_type(&s)?;
170 builder.append_option(out.get(0));
171 }
172 }
173
174 Ok(builder.finish().into_series())
175 } else {
176 for idx in 0..len {
177 let (start, size) = window_edges(idx, len, options.window_size, options.center);
178
179 if size < options.min_periods {
180 builder.append_null();
181 } else {
182 let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };
185
186 if size - arr_window.null_count() < options.min_periods {
188 builder.append_null();
189 continue;
190 }
191
192 unsafe {
197 *ptr = arr_window;
198 }
199 series_container.clear_flags();
201 series_container._get_inner_mut().compute_len();
203 let s = f(&series_container);
204 let out = self.unpack_series_matching_type(&s)?;
205 builder.append_option(out.get(0));
206 }
207 }
208
209 Ok(builder.finish().into_series())
210 }
211 }
212 }
213
214 impl<T> ChunkedArray<T>
215 where
216 ChunkedArray<T>: IntoSeries,
217 T: PolarsFloatType,
218 T::Native: Float + IsFloat + SubAssign + Pow<T::Native, Output = T::Native>,
219 {
220 pub fn rolling_map_float<F>(&self, window_size: usize, mut f: F) -> PolarsResult<Self>
222 where
223 F: FnMut(&mut ChunkedArray<T>) -> Option<T::Native>,
224 {
225 if window_size > self.len() {
226 return Ok(Self::full_null(self.name().clone(), self.len()));
227 }
228 let ca = self.rechunk();
229 let arr = ca.downcast_as_array();
230
231 let mut heap_container =
235 ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
236 let ptr = heap_container.chunks[0].as_mut() as *mut dyn Array
237 as *mut PrimitiveArray<T::Native>;
238
239 let mut validity = MutableBitmap::with_capacity(ca.len());
240 validity.extend_constant(window_size - 1, false);
241 validity.extend_constant(ca.len() - (window_size - 1), true);
242 let validity_slice = validity.as_mut_slice();
243
244 let mut values = Vec::with_capacity(ca.len());
245 values.extend(std::iter::repeat_n(T::Native::default(), window_size - 1));
246
247 for offset in 0..self.len() + 1 - window_size {
248 debug_assert!(offset + window_size <= arr.len());
249 let arr_window = unsafe { arr.slice_typed_unchecked(offset, window_size) };
250 heap_container.length = arr_window.len();
252
253 unsafe {
256 *ptr = arr_window;
257 }
258
259 let out = f(&mut heap_container);
260 match out {
261 Some(v) => {
262 unsafe { values.push_unchecked(v) }
264 },
265 None => {
266 unsafe {
269 values.push_unchecked(T::Native::default());
270 set_bit_unchecked(validity_slice, offset + window_size - 1, false);
271 }
272 },
273 }
274 }
275 let arr = PrimitiveArray::new(
276 T::get_static_dtype().to_arrow(CompatLevel::newest()),
277 values.into(),
278 Some(validity.into()),
279 );
280 Ok(Self::with_chunk(self.name().clone(), arr))
281 }
282 }
283}