polars_ops/series/ops/
arg_min_max.rs

1use argminmax::ArgMinMax;
2use arrow::array::Array;
3use polars_core::chunked_array::ops::float_sorted_arg_max::{
4    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
5};
6use polars_core::series::IsSorted;
7use polars_core::with_match_categorical_physical_type;
8
9use super::*;
10
/// Positional arg-min / arg-max aggregations.
pub trait ArgAgg {
    /// Returns the position of the smallest value, or `None` if no position can be reported.
    fn arg_min(&self) -> Option<usize>;
    /// Returns the position of the largest value, or `None` if no position can be reported.
    fn arg_max(&self) -> Option<usize>;
}
18
/// Matches `$key_type` against the physical numeric [`DataType`] variants and, in the
/// selected arm, expands `$body` with `$T` bound to the corresponding polars type
/// (`Int32Type`, `Float64Type`, ...).
///
/// Invoke as `with_match_physical_numeric_polars_type!(dtype, |$T| { ... })`.
/// Any dtype without an arm (including variants whose `dtype-*` feature is disabled)
/// panics.
macro_rules! with_match_physical_numeric_polars_type {(
    $key_type:expr, | $dollar:tt $T:ident | $($body:tt)*
) => ({
    // `$dollar` is the caller's literal `$` token. Re-emitting it lets the helper
    // macro declare its own `$T` metavariable, so `$T` inside `$body` names the type.
    macro_rules! __with_numeric_ty__ {( $dollar $T:ident ) => ( $($body)* )}
    use DataType::*;
    match $key_type {
        #[cfg(feature = "dtype-i8")]
        Int8 => __with_numeric_ty__! { Int8Type },
        #[cfg(feature = "dtype-i16")]
        Int16 => __with_numeric_ty__! { Int16Type },
        Int32 => __with_numeric_ty__! { Int32Type },
        Int64 => __with_numeric_ty__! { Int64Type },
        #[cfg(feature = "dtype-u8")]
        UInt8 => __with_numeric_ty__! { UInt8Type },
        #[cfg(feature = "dtype-u16")]
        UInt16 => __with_numeric_ty__! { UInt16Type },
        UInt32 => __with_numeric_ty__! { UInt32Type },
        UInt64 => __with_numeric_ty__! { UInt64Type },
        Float32 => __with_numeric_ty__! { Float32Type },
        Float64 => __with_numeric_ty__! { Float64Type },
        dt => panic!("not implemented for dtype {:?}", dt),
    }
})}
42
43impl ArgAgg for Series {
44    fn arg_min(&self) -> Option<usize> {
45        use DataType::*;
46        let phys_s = self.to_physical_repr();
47        match self.dtype() {
48            #[cfg(feature = "dtype-categorical")]
49            Categorical(cats, _) => {
50                with_match_categorical_physical_type!(cats.physical(), |$C| {
51                    let ca = self.cat::<$C>().unwrap();
52                    if ca.null_count() == ca.len() {
53                        return None;
54                    }
55                    ca.iter_str()
56                        .enumerate()
57                        .flat_map(|(idx, val)| val.map(|val| (idx, val)))
58                        .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
59                        .map(|tpl| tpl.0)
60                })
61            },
62            #[cfg(feature = "dtype-categorical")]
63            Enum(_, _) => phys_s.arg_min(),
64            Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_min(),
65            String => {
66                let ca = self.str().unwrap();
67                arg_min_str(ca)
68            },
69            Boolean => {
70                let ca = self.bool().unwrap();
71                arg_min_bool(ca)
72            },
73            dt if dt.is_primitive_numeric() => {
74                with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
75                    let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
76                    arg_min_numeric_dispatch(ca)
77                })
78            },
79            _ => None,
80        }
81    }
82
83    fn arg_max(&self) -> Option<usize> {
84        use DataType::*;
85        let phys_s = self.to_physical_repr();
86        match self.dtype() {
87            #[cfg(feature = "dtype-categorical")]
88            Categorical(cats, _) => {
89                with_match_categorical_physical_type!(cats.physical(), |$C| {
90                    let ca = self.cat::<$C>().unwrap();
91                    if ca.null_count() == ca.len() {
92                        return None;
93                    }
94                    ca.iter_str()
95                        .enumerate()
96                        .flat_map(|(idx, val)| val.map(|val| (idx, val)))
97                        .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
98                        .map(|tpl| tpl.0)
99                })
100            },
101            #[cfg(feature = "dtype-categorical")]
102            Enum(_, _) => phys_s.arg_max(),
103            Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_max(),
104            String => {
105                let ca = self.str().unwrap();
106                arg_max_str(ca)
107            },
108            Boolean => {
109                let ca = self.bool().unwrap();
110                arg_max_bool(ca)
111            },
112            dt if dt.is_primitive_numeric() => {
113                with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
114                    let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
115                    arg_max_numeric_dispatch(ca)
116                })
117            },
118            _ => None,
119        }
120    }
121}
122
123fn arg_max_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
124where
125    T: PolarsNumericType,
126    for<'b> &'b [T::Native]: ArgMinMax,
127{
128    if ca.null_count() == ca.len() {
129        None
130    } else if T::get_static_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) {
131        arg_max_float_sorted(ca)
132    } else if let Ok(vals) = ca.cont_slice() {
133        arg_max_numeric_slice(vals, ca.is_sorted_flag())
134    } else {
135        arg_max_numeric(ca)
136    }
137}
138
139fn arg_min_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
140where
141    T: PolarsNumericType,
142    for<'b> &'b [T::Native]: ArgMinMax,
143{
144    if ca.null_count() == ca.len() {
145        None
146    } else if let Ok(vals) = ca.cont_slice() {
147        arg_min_numeric_slice(vals, ca.is_sorted_flag())
148    } else {
149        arg_min_numeric(ca)
150    }
151}
152
153fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
154    ca.first_true_idx().or_else(|| ca.first_false_idx())
155}
156
157/// # Safety
158/// `ca` has a float dtype, has at least one non-null value and is sorted.
159fn arg_max_float_sorted<T>(ca: &ChunkedArray<T>) -> Option<usize>
160where
161    T: PolarsNumericType,
162{
163    let out = match ca.is_sorted_flag() {
164        IsSorted::Ascending => float_arg_max_sorted_ascending(ca),
165        IsSorted::Descending => float_arg_max_sorted_descending(ca),
166        _ => unreachable!(),
167    };
168    Some(out)
169}
170
171fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
172    ca.first_false_idx().or_else(|| ca.first_true_idx())
173}
174
175fn arg_min_str(ca: &StringChunked) -> Option<usize> {
176    if ca.null_count() == ca.len() {
177        return None;
178    }
179    match ca.is_sorted_flag() {
180        IsSorted::Ascending => ca.first_non_null(),
181        IsSorted::Descending => ca.last_non_null(),
182        IsSorted::Not => ca
183            .iter()
184            .enumerate()
185            .flat_map(|(idx, val)| val.map(|val| (idx, val)))
186            .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
187            .map(|tpl| tpl.0),
188    }
189}
190
191fn arg_max_str(ca: &StringChunked) -> Option<usize> {
192    if ca.null_count() == ca.len() {
193        return None;
194    }
195    match ca.is_sorted_flag() {
196        IsSorted::Ascending => ca.last_non_null(),
197        IsSorted::Descending => ca.first_non_null(),
198        IsSorted::Not => ca
199            .iter()
200            .enumerate()
201            .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
202            .map(|tpl| tpl.0),
203    }
204}
205
206fn arg_min_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
207where
208    T: PolarsNumericType,
209    for<'b> &'b [T::Native]: ArgMinMax,
210{
211    match ca.is_sorted_flag() {
212        IsSorted::Ascending => ca.first_non_null(),
213        IsSorted::Descending => ca.last_non_null(),
214        IsSorted::Not => {
215            ca.downcast_iter()
216                .fold((None, None, 0), |acc, arr| {
217                    if arr.len() == 0 {
218                        return acc;
219                    }
220                    let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {
221                        arr.into_iter()
222                            .enumerate()
223                            .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
224                            .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
225                    } else {
226                        // When no nulls & array not empty => we can use fast argminmax
227                        let min_idx: usize = arr.values().as_slice().argmin();
228                        Some((min_idx, arr.value(min_idx)))
229                    };
230
231                    let new_offset: usize = acc.2 + arr.len();
232                    match acc {
233                        (Some(_), Some(acc_v), offset) => match chunk_min {
234                            Some((idx, val)) if val < acc_v => {
235                                (Some(idx + offset), Some(val), new_offset)
236                            },
237                            _ => (acc.0, acc.1, new_offset),
238                        },
239                        (None, None, offset) => match chunk_min {
240                            Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
241                            None => (None, None, new_offset),
242                        },
243                        _ => unreachable!(),
244                    }
245                })
246                .0
247        },
248    }
249}
250
251fn arg_max_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
252where
253    T: PolarsNumericType,
254    for<'b> &'b [T::Native]: ArgMinMax,
255{
256    match ca.is_sorted_flag() {
257        IsSorted::Ascending => ca.last_non_null(),
258        IsSorted::Descending => ca.first_non_null(),
259        IsSorted::Not => {
260            ca.downcast_iter()
261                .fold((None, None, 0), |acc, arr| {
262                    if arr.len() == 0 {
263                        return acc;
264                    }
265                    let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {
266                        // When there are nulls, we should compare Option<T::Native>
267                        arr.into_iter()
268                            .enumerate()
269                            .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
270                            .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
271                    } else {
272                        // When no nulls & array not empty => we can use fast argminmax
273                        let max_idx: usize = arr.values().as_slice().argmax();
274                        Some((max_idx, arr.value(max_idx)))
275                    };
276
277                    let new_offset: usize = acc.2 + arr.len();
278                    match acc {
279                        (Some(_), Some(acc_v), offset) => match chunk_max {
280                            Some((idx, val)) if acc_v < val => {
281                                (Some(idx + offset), Some(val), new_offset)
282                            },
283                            _ => (acc.0, acc.1, new_offset),
284                        },
285                        (None, None, offset) => match chunk_max {
286                            Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
287                            None => (None, None, new_offset),
288                        },
289                        _ => unreachable!(),
290                    }
291                })
292                .0
293        },
294    }
295}
296
297fn arg_min_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
298where
299    for<'a> &'a [T]: ArgMinMax,
300{
301    match is_sorted {
302        // all vals are not null guarded by cont_slice
303        IsSorted::Ascending => Some(0),
304        // all vals are not null guarded by cont_slice
305        IsSorted::Descending => Some(vals.len() - 1),
306        IsSorted::Not => Some(vals.argmin()), // assumes not empty
307    }
308}
309
310fn arg_max_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
311where
312    for<'a> &'a [T]: ArgMinMax,
313{
314    match is_sorted {
315        // all vals are not null guarded by cont_slice
316        IsSorted::Ascending => Some(vals.len() - 1),
317        // all vals are not null guarded by cont_slice
318        IsSorted::Descending => Some(0),
319        IsSorted::Not => Some(vals.argmax()), // assumes not empty
320    }
321}