polars_ops/series/ops/
arg_min_max.rs

1use arrow::array::Array;
2use polars_core::chunked_array::ops::float_sorted_arg_max::{
3    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
4};
5use polars_core::series::IsSorted;
6use polars_core::with_match_categorical_physical_type;
7use polars_utils::arg_min_max::ArgMinMax;
8
9use super::*;
10
/// Arg-min / arg-max: locate the position of the smallest or largest value.
pub trait ArgAgg {
    /// Returns the index of the minimal value, or `None` when no such index exists.
    fn arg_min(&self) -> Option<usize>;
    /// Returns the index of the maximal value, or `None` when no such index exists.
    fn arg_max(&self) -> Option<usize>;
}
18
/// Expands `$body` once per supported physical numeric `DataType`.
///
/// Invoked as `with_match_physical_numeric_polars_type!(dtype, |$T| { ... })`. In each match
/// arm the identifier chosen by the caller (`$T` here) is bound to the corresponding Polars
/// type marker (`Int32Type`, `Float64Type`, ...). Arms guarded by a `dtype-*` feature only
/// exist when that feature is enabled.
///
/// # Panics
/// Panics with `not implemented for dtype ...` for any dtype without an arm.
macro_rules! with_match_physical_numeric_polars_type {(
    $dtype:expr, | $dollar:tt $ty:ident | $($body:tt)*
) => ({
    // `$dollar` holds the caller's literal `$`, so this helper declares a metavariable with
    // the caller's chosen name and `$body` can refer to it.
    macro_rules! __with_ty__ {( $dollar $ty:ident ) => ( $($body)* )}
    use DataType::*;
    match $dtype {
        #[cfg(feature = "dtype-i8")]
        Int8 => __with_ty__! { Int8Type },
        #[cfg(feature = "dtype-i16")]
        Int16 => __with_ty__! { Int16Type },
        Int32 => __with_ty__! { Int32Type },
        Int64 => __with_ty__! { Int64Type },
        #[cfg(feature = "dtype-i128")]
        Int128 => __with_ty__! { Int128Type },
        #[cfg(feature = "dtype-u8")]
        UInt8 => __with_ty__! { UInt8Type },
        #[cfg(feature = "dtype-u16")]
        UInt16 => __with_ty__! { UInt16Type },
        UInt32 => __with_ty__! { UInt32Type },
        UInt64 => __with_ty__! { UInt64Type },
        #[cfg(feature = "dtype-u128")]
        UInt128 => __with_ty__! { UInt128Type },
        #[cfg(feature = "dtype-f16")]
        Float16 => __with_ty__! { Float16Type },
        Float32 => __with_ty__! { Float32Type },
        Float64 => __with_ty__! { Float64Type },
        other => panic!("not implemented for dtype {:?}", other),
    }
})}
48
49impl ArgAgg for Series {
50    fn arg_min(&self) -> Option<usize> {
51        use DataType::*;
52        let phys_s = self.to_physical_repr();
53        match self.dtype() {
54            #[cfg(feature = "dtype-categorical")]
55            Categorical(cats, _) => {
56                with_match_categorical_physical_type!(cats.physical(), |$C| {
57                    let ca = self.cat::<$C>().unwrap();
58                    if ca.null_count() == ca.len() {
59                        return None;
60                    }
61                    ca.iter_str()
62                        .enumerate()
63                        .flat_map(|(idx, val)| val.map(|val| (idx, val)))
64                        .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
65                        .map(|tpl| tpl.0)
66                })
67            },
68            #[cfg(feature = "dtype-categorical")]
69            Enum(_, _) => phys_s.arg_min(),
70            Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_min(),
71            String => {
72                let ca = self.str().unwrap();
73                arg_min_str(ca)
74            },
75            Boolean => {
76                let ca = self.bool().unwrap();
77                arg_min_bool(ca)
78            },
79            dt if dt.is_primitive_numeric() => {
80                with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
81                    let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
82                    arg_min_numeric_dispatch(ca)
83                })
84            },
85            _ => None,
86        }
87    }
88
89    fn arg_max(&self) -> Option<usize> {
90        use DataType::*;
91        let phys_s = self.to_physical_repr();
92        match self.dtype() {
93            #[cfg(feature = "dtype-categorical")]
94            Categorical(cats, _) => {
95                with_match_categorical_physical_type!(cats.physical(), |$C| {
96                    let ca = self.cat::<$C>().unwrap();
97                    if ca.null_count() == ca.len() {
98                        return None;
99                    }
100                    ca.iter_str()
101                        .enumerate()
102                        .flat_map(|(idx, val)| val.map(|val| (idx, val)))
103                        .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
104                        .map(|tpl| tpl.0)
105                })
106            },
107            #[cfg(feature = "dtype-categorical")]
108            Enum(_, _) => phys_s.arg_max(),
109            Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_max(),
110            String => {
111                let ca = self.str().unwrap();
112                arg_max_str(ca)
113            },
114            Boolean => {
115                let ca = self.bool().unwrap();
116                arg_max_bool(ca)
117            },
118            #[cfg(feature = "dtype-f16")]
119            Float16 => {
120                // TODO: Dispatch to ArgMinMax directly
121                phys_s.cast(&DataType::Float32).unwrap().arg_max()
122            },
123            dt if dt.is_primitive_numeric() => {
124                with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
125                    let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
126                    arg_max_numeric_dispatch(ca)
127                })
128            },
129            _ => None,
130        }
131    }
132}
133
134fn arg_max_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
135where
136    T: PolarsNumericType,
137    for<'b> &'b [T::Native]: ArgMinMax,
138{
139    if ca.null_count() == ca.len() {
140        None
141    } else if T::get_static_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) {
142        arg_max_float_sorted(ca)
143    } else if let Ok(vals) = ca.cont_slice() {
144        arg_max_numeric_slice(vals, ca.is_sorted_flag())
145    } else {
146        arg_max_numeric(ca)
147    }
148}
149
150fn arg_min_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
151where
152    T: PolarsNumericType,
153    for<'b> &'b [T::Native]: ArgMinMax,
154{
155    if ca.null_count() == ca.len() {
156        None
157    } else if let Ok(vals) = ca.cont_slice() {
158        arg_min_numeric_slice(vals, ca.is_sorted_flag())
159    } else {
160        arg_min_numeric(ca)
161    }
162}
163
164fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
165    ca.first_true_idx().or_else(|| ca.first_false_idx())
166}
167
168/// # Safety
169/// `ca` has a float dtype, has at least one non-null value and is sorted.
170fn arg_max_float_sorted<T>(ca: &ChunkedArray<T>) -> Option<usize>
171where
172    T: PolarsNumericType,
173{
174    let out = match ca.is_sorted_flag() {
175        IsSorted::Ascending => float_arg_max_sorted_ascending(ca),
176        IsSorted::Descending => float_arg_max_sorted_descending(ca),
177        _ => unreachable!(),
178    };
179    Some(out)
180}
181
182fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
183    ca.first_false_idx().or_else(|| ca.first_true_idx())
184}
185
186fn arg_min_str(ca: &StringChunked) -> Option<usize> {
187    if ca.null_count() == ca.len() {
188        return None;
189    }
190    match ca.is_sorted_flag() {
191        IsSorted::Ascending => ca.first_non_null(),
192        IsSorted::Descending => ca.last_non_null(),
193        IsSorted::Not => ca
194            .iter()
195            .enumerate()
196            .flat_map(|(idx, val)| val.map(|val| (idx, val)))
197            .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
198            .map(|tpl| tpl.0),
199    }
200}
201
202fn arg_max_str(ca: &StringChunked) -> Option<usize> {
203    if ca.null_count() == ca.len() {
204        return None;
205    }
206    match ca.is_sorted_flag() {
207        IsSorted::Ascending => ca.last_non_null(),
208        IsSorted::Descending => ca.first_non_null(),
209        IsSorted::Not => ca
210            .iter()
211            .enumerate()
212            .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
213            .map(|tpl| tpl.0),
214    }
215}
216
217fn arg_min_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
218where
219    T: PolarsNumericType,
220    for<'b> &'b [T::Native]: ArgMinMax,
221{
222    match ca.is_sorted_flag() {
223        IsSorted::Ascending => ca.first_non_null(),
224        IsSorted::Descending => ca.last_non_null(),
225        IsSorted::Not => {
226            ca.downcast_iter()
227                .fold((None, None, 0), |acc, arr| {
228                    if arr.len() == 0 {
229                        return acc;
230                    }
231                    let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {
232                        arr.into_iter()
233                            .enumerate()
234                            .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
235                            .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
236                    } else {
237                        // When no nulls & array not empty => we can use fast argminmax
238                        let min_idx: usize = arr.values().as_slice().argmin();
239                        Some((min_idx, arr.value(min_idx)))
240                    };
241
242                    let new_offset: usize = acc.2 + arr.len();
243                    match acc {
244                        (Some(_), Some(acc_v), offset) => match chunk_min {
245                            Some((idx, val)) if val < acc_v => {
246                                (Some(idx + offset), Some(val), new_offset)
247                            },
248                            _ => (acc.0, acc.1, new_offset),
249                        },
250                        (None, None, offset) => match chunk_min {
251                            Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
252                            None => (None, None, new_offset),
253                        },
254                        _ => unreachable!(),
255                    }
256                })
257                .0
258        },
259    }
260}
261
262fn arg_max_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
263where
264    T: PolarsNumericType,
265    for<'b> &'b [T::Native]: ArgMinMax,
266{
267    match ca.is_sorted_flag() {
268        IsSorted::Ascending => ca.last_non_null(),
269        IsSorted::Descending => ca.first_non_null(),
270        IsSorted::Not => {
271            ca.downcast_iter()
272                .fold((None, None, 0), |acc, arr| {
273                    if arr.len() == 0 {
274                        return acc;
275                    }
276                    let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {
277                        // When there are nulls, we should compare Option<T::Native>
278                        arr.into_iter()
279                            .enumerate()
280                            .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
281                            .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
282                    } else {
283                        // When no nulls & array not empty => we can use fast argminmax
284                        let max_idx: usize = arr.values().as_slice().argmax();
285                        Some((max_idx, arr.value(max_idx)))
286                    };
287
288                    let new_offset: usize = acc.2 + arr.len();
289                    match acc {
290                        (Some(_), Some(acc_v), offset) => match chunk_max {
291                            Some((idx, val)) if acc_v < val => {
292                                (Some(idx + offset), Some(val), new_offset)
293                            },
294                            _ => (acc.0, acc.1, new_offset),
295                        },
296                        (None, None, offset) => match chunk_max {
297                            Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
298                            None => (None, None, new_offset),
299                        },
300                        _ => unreachable!(),
301                    }
302                })
303                .0
304        },
305    }
306}
307
308fn arg_min_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
309where
310    for<'a> &'a [T]: ArgMinMax,
311{
312    match is_sorted {
313        // all vals are not null guarded by cont_slice
314        IsSorted::Ascending => Some(0),
315        // all vals are not null guarded by cont_slice
316        IsSorted::Descending => Some(vals.len() - 1),
317        IsSorted::Not => Some(vals.argmin()), // assumes not empty
318    }
319}
320
321fn arg_max_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
322where
323    for<'a> &'a [T]: ArgMinMax,
324{
325    match is_sorted {
326        // all vals are not null guarded by cont_slice
327        IsSorted::Ascending => Some(vals.len() - 1),
328        // all vals are not null guarded by cont_slice
329        IsSorted::Descending => Some(0),
330        IsSorted::Not => Some(vals.argmax()), // assumes not empty
331    }
332}