// polars_ops/chunked_array/list/sum_mean.rs

1use std::ops::Div;
2
3use arrow::array::{Array, PrimitiveArray};
4use arrow::bitmap::Bitmap;
5use arrow::compute::utils::combine_validities_and;
6use arrow::temporal_conversions::MICROSECONDS_IN_DAY as US_IN_DAY;
7use arrow::types::NativeType;
8use num_traits::{NumCast, ToPrimitive};
9use polars_utils::float16::pf16;
10
11use super::*;
12use crate::chunked_array::sum::sum_slice;
13
14fn sum_between_offsets<T, S>(values: &[T], offset: &[i64]) -> Vec<S>
15where
16    T: NativeType + ToPrimitive,
17    S: NumCast + std::iter::Sum,
18{
19    offset
20        .windows(2)
21        .map(|w| {
22            values
23                .get(w[0] as usize..w[1] as usize)
24                .map(sum_slice)
25                .unwrap_or(S::from(0).unwrap())
26        })
27        .collect()
28}
29
30fn dispatch_sum<T, S>(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef
31where
32    T: NativeType + ToPrimitive,
33    S: NativeType + NumCast + std::iter::Sum,
34{
35    let values = arr.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
36    let values = values.values().as_slice();
37    Box::new(PrimitiveArray::from_data_default(
38        sum_between_offsets::<_, S>(values, offsets).into(),
39        validity.cloned(),
40    )) as ArrayRef
41}
42
43pub(super) fn sum_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series {
44    use DataType::*;
45    let chunks = ca
46        .downcast_iter()
47        .map(|arr| {
48            let offsets = arr.offsets().as_slice();
49            let values = arr.values().as_ref();
50
51            match inner_type {
52                Int8 => dispatch_sum::<i8, i64>(values, offsets, arr.validity()),
53                Int16 => dispatch_sum::<i16, i64>(values, offsets, arr.validity()),
54                Int32 => dispatch_sum::<i32, i32>(values, offsets, arr.validity()),
55                Int64 => dispatch_sum::<i64, i64>(values, offsets, arr.validity()),
56                Int128 => dispatch_sum::<i128, i128>(values, offsets, arr.validity()),
57                UInt8 => dispatch_sum::<u8, i64>(values, offsets, arr.validity()),
58                UInt16 => dispatch_sum::<u16, i64>(values, offsets, arr.validity()),
59                UInt32 => dispatch_sum::<u32, u32>(values, offsets, arr.validity()),
60                UInt64 => dispatch_sum::<u64, u64>(values, offsets, arr.validity()),
61                UInt128 => dispatch_sum::<u128, u128>(values, offsets, arr.validity()),
62                Float16 => dispatch_sum::<pf16, pf16>(values, offsets, arr.validity()),
63                Float32 => dispatch_sum::<f32, f32>(values, offsets, arr.validity()),
64                Float64 => dispatch_sum::<f64, f64>(values, offsets, arr.validity()),
65                _ => unimplemented!(),
66            }
67        })
68        .collect::<Vec<_>>();
69
70    Series::try_from((ca.name().clone(), chunks)).unwrap()
71}
72
73pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> PolarsResult<Series> {
74    use DataType::*;
75    let mut out = match inner_dtype {
76        Boolean => {
77            let out: IdxCa =
78                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<IdxSize>().unwrap()));
79            out.into_series()
80        },
81        UInt8 => {
82            let out: Int64Chunked =
83                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<i64>().unwrap()));
84            out.into_series()
85        },
86        UInt16 => {
87            let out: Int64Chunked =
88                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<i64>().unwrap()));
89            out.into_series()
90        },
91        UInt32 => {
92            let out: UInt32Chunked =
93                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<u32>().unwrap()));
94            out.into_series()
95        },
96        UInt64 => {
97            let out: UInt64Chunked =
98                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<u64>().unwrap()));
99            out.into_series()
100        },
101        Int8 => {
102            let out: Int64Chunked =
103                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<i64>().unwrap()));
104            out.into_series()
105        },
106        Int16 => {
107            let out: Int64Chunked =
108                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<i64>().unwrap()));
109            out.into_series()
110        },
111        Int32 => {
112            let out: Int32Chunked =
113                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<i32>().unwrap()));
114            out.into_series()
115        },
116        Int64 => {
117            let out: Int64Chunked =
118                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<i64>().unwrap()));
119            out.into_series()
120        },
121        #[cfg(feature = "dtype-f16")]
122        Float16 => {
123            let out: Float16Chunked =
124                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<pf16>().unwrap()));
125            out.into_series()
126        },
127        Float32 => {
128            let out: Float32Chunked =
129                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<f32>().unwrap()));
130            out.into_series()
131        },
132        Float64 => {
133            let out: Float64Chunked =
134                ca.apply_amortized_generic(|s| s.map(|s| s.as_ref().sum::<f64>().unwrap()));
135            out.into_series()
136        },
137        // slowest sum_as_series path
138        dt => unsafe {
139            // SAFETY: `sum_reduce` doesn't change the dtype
140            ca.try_apply_amortized_same_type(|s| {
141                s.as_ref()
142                    .sum_reduce()
143                    .map(|sc| sc.into_series(PlSmallStr::EMPTY))
144            })?
145        }
146        .explode(ExplodeOptions {
147            empty_as_null: true,
148            keep_nulls: true,
149        })
150        .unwrap()
151        .into_series()
152        .cast(dt)?,
153    };
154    out.rename(ca.name().clone());
155    Ok(out)
156}
157
158fn mean_between_offsets<T, S>(values: &[T], offset: &[i64]) -> PrimitiveArray<S>
159where
160    T: NativeType + ToPrimitive,
161    S: NativeType + NumCast + std::iter::Sum + Div<Output = S>,
162{
163    offset
164        .windows(2)
165        .map(|w| {
166            values
167                .get(w[0] as usize..w[1] as usize)
168                .filter(|sl| !sl.is_empty())
169                .map(|sl| sum_slice::<_, S>(sl) / NumCast::from(sl.len()).unwrap())
170        })
171        .collect()
172}
173
174fn dispatch_mean<T, S>(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef
175where
176    T: NativeType + ToPrimitive,
177    S: NativeType + NumCast + std::iter::Sum + Div<Output = S>,
178{
179    let values = arr.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
180    let values = values.values().as_slice();
181    let out = mean_between_offsets::<_, S>(values, offsets);
182    let new_validity = combine_validities_and(out.validity(), validity);
183    out.with_validity(new_validity).to_boxed()
184}
185
186pub(super) fn mean_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series {
187    use DataType::*;
188    let chunks = ca
189        .downcast_iter()
190        .map(|arr| {
191            let offsets = arr.offsets().as_slice();
192            let values = arr.values().as_ref();
193
194            match inner_type {
195                Int8 => dispatch_mean::<i8, f64>(values, offsets, arr.validity()),
196                Int16 => dispatch_mean::<i16, f64>(values, offsets, arr.validity()),
197                Int32 => dispatch_mean::<i32, f64>(values, offsets, arr.validity()),
198                Int64 => dispatch_mean::<i64, f64>(values, offsets, arr.validity()),
199                Int128 => dispatch_mean::<i128, f64>(values, offsets, arr.validity()),
200                UInt8 => dispatch_mean::<u8, f64>(values, offsets, arr.validity()),
201                UInt16 => dispatch_mean::<u16, f64>(values, offsets, arr.validity()),
202                UInt32 => dispatch_mean::<u32, f64>(values, offsets, arr.validity()),
203                UInt64 => dispatch_mean::<u64, f64>(values, offsets, arr.validity()),
204                UInt128 => dispatch_mean::<u128, f64>(values, offsets, arr.validity()),
205                Float32 => dispatch_mean::<f32, f32>(values, offsets, arr.validity()),
206                Float64 => dispatch_mean::<f64, f64>(values, offsets, arr.validity()),
207                _ => unimplemented!(),
208            }
209        })
210        .collect::<Vec<_>>();
211
212    Series::try_from((ca.name().clone(), chunks)).unwrap()
213}
214
215pub(super) fn mean_with_nulls(ca: &ListChunked) -> Series {
216    match ca.inner_dtype() {
217        #[cfg(feature = "dtype-f16")]
218        DataType::Float16 => {
219            let out: Float16Chunked = ca
220                .apply_amortized_generic(|s| {
221                    use num_traits::FromPrimitive;
222
223                    s.and_then(|s| s.as_ref().mean().map(|v| pf16::from_f64(v).unwrap()))
224                })
225                .with_name(ca.name().clone());
226            out.into_series()
227        },
228        DataType::Float32 => {
229            let out: Float32Chunked = ca
230                .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean().map(|v| v as f32)))
231                .with_name(ca.name().clone());
232            out.into_series()
233        },
234        #[cfg(feature = "dtype-datetime")]
235        DataType::Date => {
236            let out: Int64Chunked = ca
237                .apply_amortized_generic(|s| {
238                    s.and_then(|s| s.as_ref().mean().map(|v| (v * (US_IN_DAY as f64)) as i64))
239                })
240                .with_name(ca.name().clone());
241            out.into_datetime(TimeUnit::Microseconds, None)
242                .into_series()
243        },
244        dt if dt.is_temporal() => {
245            let out: Int64Chunked = ca
246                .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean().map(|v| v as i64)))
247                .with_name(ca.name().clone());
248            out.cast(dt).unwrap()
249        },
250        _ => {
251            let out: Float64Chunked = ca
252                .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean()))
253                .with_name(ca.name().clone());
254            out.into_series()
255        },
256    }
257}