polars_core/chunked_array/ops/
rolling_window.rs1use std::hash::{Hash, Hasher};
2
3use polars_compute::rolling::RollingFnParams;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
/// Options for a rolling computation over a fixed-size window of elements.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "rolling_window", derive(PartialEq))]
pub struct RollingOptionsFixedWindow {
    /// Number of elements in a window.
    pub window_size: usize,
    /// Minimum number of non-null values a window must contain to produce a value;
    /// in `rolling_map`, windows with fewer yield null.
    pub min_periods: usize,
    /// Optional weights multiplied element-wise with the window values before the
    /// window is evaluated (used by `rolling_map`).
    pub weights: Option<Vec<f64>>,
    /// When `true`, the window is centred on the current element instead of ending at it.
    pub center: bool,
    /// Optional extra parameters for rolling kernels; falls back to `Default` (`None`)
    /// when missing during deserialization. Not read by `rolling_map`.
    #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(default))]
    pub fn_params: Option<RollingFnParams>,
}
25
26impl Hash for RollingOptionsFixedWindow {
27 fn hash<H: Hasher>(&self, state: &mut H) {
28 self.window_size.hash(state);
29 self.min_periods.hash(state);
30 self.center.hash(state);
31 self.weights.is_some().hash(state);
32 }
33}
34
35impl Default for RollingOptionsFixedWindow {
36 fn default() -> Self {
37 RollingOptionsFixedWindow {
38 window_size: 3,
39 min_periods: 1,
40 weights: None,
41 center: false,
42 fn_params: None,
43 }
44 }
45}
46
#[cfg(feature = "rolling_window")]
mod inner_mod {
    use num_traits::Zero;

    use crate::chunked_array::cast::CastOptions;
    use crate::prelude::*;

    /// Ensures a window of `window_size` elements can satisfy `min_periods`.
    ///
    /// # Errors
    /// `ComputeError` when `min_periods > window_size`.
    fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> {
        polars_ensure!(
            min_periods <= window_size,
            ComputeError: "`window_size`: {} should be >= `min_periods`: {}",
            window_size, min_periods
        );
        Ok(())
    }

    /// Returns `(start, size)` of the window whose result is written at `idx`.
    ///
    /// A trailing window ends at `idx` (inclusive). A centred window holds
    /// `ceil(window_size / 2)` elements at and after `idx` and the remainder before it.
    /// Windows are clipped to `0..len`, so near the edges `size` may be smaller than
    /// `window_size`. The returned range always satisfies `start + size <= len`.
    fn window_edges(idx: usize, len: usize, window_size: usize, center: bool) -> (usize, usize) {
        let (start, end) = if center {
            let right_window = window_size.div_ceil(2);
            (
                idx.saturating_sub(window_size - right_window),
                len.min(idx + right_window),
            )
        } else {
            // Equal to `idx.saturating_sub(window_size - 1)` for every `window_size >= 1`,
            // but does not underflow when `window_size == 0` (allowed by `check_input`
            // when `min_periods == 0`); that case yields an empty window.
            ((idx + 1).saturating_sub(window_size), idx + 1)
        };

        (start, end - start)
    }

    /// Selects the `size` weights that line up with a window beginning at `start`.
    ///
    /// A clipped window that starts at index 0 lost elements on its left, so it pairs
    /// with the *last* `size` weights. A clipped window that starts later can only have
    /// lost elements on its right, so it pairs with the *first* `size` weights. For an
    /// unclipped window with `size == weights.len()`, both choices select every weight.
    fn window_weights(weights: &Series, start: usize, size: usize) -> Series {
        let offset = if start == 0 {
            // Saturate so weights shorter than the window cannot underflow the offset;
            // a resulting length mismatch surfaces as an error from `multiply`.
            weights.len().saturating_sub(size)
        } else {
            0
        };
        weights.slice(offset as i64, size)
    }

    impl<T: PolarsNumericType> ChunkRollApply for ChunkedArray<T> {
        /// Applies `f` to every window of `self`. The first value of each result is
        /// collected into a series with the same name and length as `self`.
        ///
        /// A window yields null when it has fewer than `min_periods` elements, or fewer
        /// than `min_periods` non-null elements. When `weights` are given, each window is
        /// multiplied element-wise with its aligned weights (see `window_weights`) before
        /// `f` sees it. Non-float input with weights is first cast to `Float64`
        /// (non-strict), and the call is repeated on the cast series. `window_size` is
        /// clamped to the length of `self`.
        ///
        /// # Errors
        /// - `min_periods > window_size` (from `check_input`).
        /// - Any error from casting, from `f`, or from multiplying a window by its
        ///   weights (e.g. when the weights cannot be aligned with the window).
        /// - Any error from `unpack_series_matching_type` on the output of `f`.
        fn rolling_map(
            &self,
            f: &dyn Fn(&Series) -> PolarsResult<Series>,
            mut options: RollingOptionsFixedWindow,
        ) -> PolarsResult<Series> {
            check_input(options.window_size, options.min_periods)?;

            // Weighted windows are evaluated on a float dtype: cast and retry.
            if options.weights.is_some()
                && !matches!(self.dtype(), DataType::Float64 | DataType::Float32)
            {
                let s = self.cast_with_options(&DataType::Float64, CastOptions::NonStrict)?;
                return s.rolling_map(f, options);
            }

            let len = self.len();
            options.window_size = std::cmp::min(len, options.window_size);

            // Rechunk only once we know this call does the work itself.
            let rechunked = self.rechunk();
            let arr = rechunked.downcast_as_array();

            // A one-element series whose single chunk is overwritten in place with each
            // window, so the loop does not allocate a new series per window.
            let mut container_ca =
                ChunkedArray::<T>::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]);
            let ptr = container_ca.chunks[0].as_mut() as *mut dyn Array
                as *mut PrimitiveArray<T::Native>;
            let mut series_container = container_ca.into_series();

            // Weights are cast to the (float) dtype of `self` once, up front.
            let weights = options
                .weights
                .as_ref()
                .map(|w| {
                    Float64Chunked::new(PlSmallStr::from_static("weights"), w)
                        .into_series()
                        .cast(self.dtype())
                })
                .transpose()?;

            let mut builder = PrimitiveChunkedBuilder::<T>::new(self.name().clone(), len);

            for idx in 0..len {
                let (start, size) = window_edges(idx, len, options.window_size, options.center);

                if size < options.min_periods {
                    builder.append_null();
                    continue;
                }

                // SAFETY: `window_edges` guarantees `start + size <= len`, and `arr` is the
                // single chunk of the rechunked `self` (assumed to hold `len` elements).
                let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };

                // Nulls do not count towards `min_periods`.
                if size - arr_window.null_count() < options.min_periods {
                    builder.append_null();
                    continue;
                }

                // SAFETY: `ptr` points into the boxed chunk owned by `series_container`,
                // which lives for the whole loop. That chunk is a
                // `PrimitiveArray<T::Native>` because it was built by
                // `ChunkedArray::<T>::from_slice`.
                // NOTE(review): this also assumes nothing else shares the container's
                // inner data, i.e. that `f` does not retain a clone of the series it is
                // given. Confirm for all callers.
                unsafe {
                    *ptr = arr_window;
                }
                // The container is reused, so refresh its flags and cached length.
                series_container.clear_flags();
                series_container._get_inner_mut().compute_len();

                let s = match &weights {
                    Some(weights) => {
                        let weighted =
                            series_container.multiply(&window_weights(weights, start, size))?;
                        f(&weighted)?
                    },
                    None => f(&series_container)?,
                };

                let out = self.unpack_series_matching_type(&s)?;
                builder.append_option(out.get(0));
            }

            Ok(builder.finish().into_series())
        }
    }
}