polars_ops/series/ops/
arg_min_max.rs

1use argminmax::ArgMinMax;
2use arrow::array::Array;
3use arrow::legacy::bit_util::*;
4use polars_core::chunked_array::ops::float_sorted_arg_max::{
5    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
6};
7use polars_core::series::IsSorted;
8use polars_core::with_match_physical_numeric_polars_type;
9
10use super::*;
11
/// Argmin/ Argmax
///
/// Positional lookup of the extreme values of a container.
/// Both methods return an index rather than the value itself, and `None` when
/// no such index exists. The `Series` implementation in this file returns
/// `None` when every value is null and for dtypes it has no match arm for.
pub trait ArgAgg {
    /// Get the index of the minimal value
    ///
    /// Returns `None` if there is no minimal value to point at.
    fn arg_min(&self) -> Option<usize>;
    /// Get the index of the maximal value
    ///
    /// Returns `None` if there is no maximal value to point at.
    fn arg_max(&self) -> Option<usize>;
}
19
/// Run `$body` once with `$T` bound to the Polars type marker that corresponds
/// to a physical numeric dtype.
///
/// Invocation shape: `with_match_physical_numeric_polars_type!(dtype, |$T| { ... })`.
/// The literal `$` written by the caller is captured as a token and re-emitted
/// in the header of a helper macro, which is how the caller's `$T` becomes a
/// real metavariable inside the body.
///
/// # Panics
/// Panics with `not implemented for dtype ...` for any dtype that has no arm
/// below (including the feature-gated ones when their feature is disabled).
macro_rules! with_match_physical_numeric_polars_type {(
    $dtype:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
    // Helper that substitutes the concrete type marker for `$T` in the body.
    macro_rules! __dispatch_ty__ {( $_ $T:ident ) => ( $($body)* )}
    use DataType::*;
    match $dtype {
        #[cfg(feature = "dtype-i8")]
        Int8 => __dispatch_ty__! { Int8Type },
        #[cfg(feature = "dtype-i16")]
        Int16 => __dispatch_ty__! { Int16Type },
        Int32 => __dispatch_ty__! { Int32Type },
        Int64 => __dispatch_ty__! { Int64Type },
        #[cfg(feature = "dtype-u8")]
        UInt8 => __dispatch_ty__! { UInt8Type },
        #[cfg(feature = "dtype-u16")]
        UInt16 => __dispatch_ty__! { UInt16Type },
        UInt32 => __dispatch_ty__! { UInt32Type },
        UInt64 => __dispatch_ty__! { UInt64Type },
        Float32 => __dispatch_ty__! { Float32Type },
        Float64 => __dispatch_ty__! { Float64Type },
        other => panic!("not implemented for dtype {:?}", other),
    }
})}
43
44impl ArgAgg for Series {
45    fn arg_min(&self) -> Option<usize> {
46        use DataType::*;
47        let s = self.to_physical_repr();
48        match self.dtype() {
49            #[cfg(feature = "dtype-categorical")]
50            Categorical(_, _) => {
51                let ca = self.categorical().unwrap();
52                if ca.null_count() == ca.len() {
53                    return None;
54                }
55                if ca.uses_lexical_ordering() {
56                    ca.iter_str()
57                        .enumerate()
58                        .flat_map(|(idx, val)| val.map(|val| (idx, val)))
59                        .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
60                        .map(|tpl| tpl.0)
61                } else {
62                    let ca = s.u32().unwrap();
63                    arg_min_numeric_dispatch(ca)
64                }
65            },
66            String => {
67                let ca = self.str().unwrap();
68                arg_min_str(ca)
69            },
70            Boolean => {
71                let ca = self.bool().unwrap();
72                arg_min_bool(ca)
73            },
74            Date => {
75                let ca = s.i32().unwrap();
76                arg_min_numeric_dispatch(ca)
77            },
78            Datetime(_, _) | Duration(_) | Time => {
79                let ca = s.i64().unwrap();
80                arg_min_numeric_dispatch(ca)
81            },
82            dt if dt.is_primitive_numeric() => {
83                with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
84                    let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
85                    arg_min_numeric_dispatch(ca)
86                })
87            },
88            _ => None,
89        }
90    }
91
92    fn arg_max(&self) -> Option<usize> {
93        use DataType::*;
94        let s = self.to_physical_repr();
95        match self.dtype() {
96            #[cfg(feature = "dtype-categorical")]
97            Categorical(_, _) => {
98                let ca = self.categorical().unwrap();
99                if ca.null_count() == ca.len() {
100                    return None;
101                }
102                if ca.uses_lexical_ordering() {
103                    ca.iter_str()
104                        .enumerate()
105                        .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
106                        .map(|tpl| tpl.0)
107                } else {
108                    let ca_phys = s.u32().unwrap();
109                    arg_max_numeric_dispatch(ca_phys)
110                }
111            },
112            String => {
113                let ca = self.str().unwrap();
114                arg_max_str(ca)
115            },
116            Boolean => {
117                let ca = self.bool().unwrap();
118                arg_max_bool(ca)
119            },
120            Date => {
121                let ca = s.i32().unwrap();
122                arg_max_numeric_dispatch(ca)
123            },
124            Datetime(_, _) | Duration(_) | Time => {
125                let ca = s.i64().unwrap();
126                arg_max_numeric_dispatch(ca)
127            },
128            dt if dt.is_primitive_numeric() => {
129                with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
130                    let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
131                    arg_max_numeric_dispatch(ca)
132                })
133            },
134            _ => None,
135        }
136    }
137}
138
139fn arg_max_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
140where
141    T: PolarsNumericType,
142    for<'b> &'b [T::Native]: ArgMinMax,
143{
144    if ca.null_count() == ca.len() {
145        None
146    } else if T::get_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) {
147        arg_max_float_sorted(ca)
148    } else if let Ok(vals) = ca.cont_slice() {
149        arg_max_numeric_slice(vals, ca.is_sorted_flag())
150    } else {
151        arg_max_numeric(ca)
152    }
153}
154
155fn arg_min_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
156where
157    T: PolarsNumericType,
158    for<'b> &'b [T::Native]: ArgMinMax,
159{
160    if ca.null_count() == ca.len() {
161        None
162    } else if let Ok(vals) = ca.cont_slice() {
163        arg_min_numeric_slice(vals, ca.is_sorted_flag())
164    } else {
165        arg_min_numeric(ca)
166    }
167}
168
169pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
170    if ca.null_count() == ca.len() {
171        None
172    }
173    // don't check for any, that on itself is already an argmax search
174    else if ca.null_count() == 0 && ca.chunks().len() == 1 {
175        let arr = ca.downcast_iter().next().unwrap();
176        let mask = arr.values();
177        Some(first_set_bit(mask))
178    } else {
179        let mut first_false_idx: Option<usize> = None;
180        ca.iter()
181            .enumerate()
182            .find_map(|(idx, val)| match val {
183                Some(true) => Some(idx),
184                Some(false) if first_false_idx.is_none() => {
185                    first_false_idx = Some(idx);
186                    None
187                },
188                _ => None,
189            })
190            .or(first_false_idx)
191    }
192}
193
194/// # Safety
195/// `ca` has a float dtype, has at least one non-null value and is sorted.
196fn arg_max_float_sorted<T>(ca: &ChunkedArray<T>) -> Option<usize>
197where
198    T: PolarsNumericType,
199{
200    let out = match ca.is_sorted_flag() {
201        IsSorted::Ascending => float_arg_max_sorted_ascending(ca),
202        IsSorted::Descending => float_arg_max_sorted_descending(ca),
203        _ => unreachable!(),
204    };
205
206    Some(out)
207}
208
209fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
210    if ca.null_count() == ca.len() {
211        None
212    } else if ca.null_count() == 0 && ca.chunks().len() == 1 {
213        let arr = ca.downcast_iter().next().unwrap();
214        let mask = arr.values();
215        Some(first_unset_bit(mask))
216    } else {
217        let mut first_true_idx: Option<usize> = None;
218        ca.iter()
219            .enumerate()
220            .find_map(|(idx, val)| match val {
221                Some(false) => Some(idx),
222                Some(true) if first_true_idx.is_none() => {
223                    first_true_idx = Some(idx);
224                    None
225                },
226                _ => None,
227            })
228            .or(first_true_idx)
229    }
230}
231
232fn arg_min_str(ca: &StringChunked) -> Option<usize> {
233    if ca.null_count() == ca.len() {
234        return None;
235    }
236    match ca.is_sorted_flag() {
237        IsSorted::Ascending => ca.first_non_null(),
238        IsSorted::Descending => ca.last_non_null(),
239        IsSorted::Not => ca
240            .iter()
241            .enumerate()
242            .flat_map(|(idx, val)| val.map(|val| (idx, val)))
243            .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
244            .map(|tpl| tpl.0),
245    }
246}
247
248fn arg_max_str(ca: &StringChunked) -> Option<usize> {
249    if ca.null_count() == ca.len() {
250        return None;
251    }
252    match ca.is_sorted_flag() {
253        IsSorted::Ascending => ca.last_non_null(),
254        IsSorted::Descending => ca.first_non_null(),
255        IsSorted::Not => ca
256            .iter()
257            .enumerate()
258            .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
259            .map(|tpl| tpl.0),
260    }
261}
262
263fn arg_min_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
264where
265    T: PolarsNumericType,
266    for<'b> &'b [T::Native]: ArgMinMax,
267{
268    match ca.is_sorted_flag() {
269        IsSorted::Ascending => ca.first_non_null(),
270        IsSorted::Descending => ca.last_non_null(),
271        IsSorted::Not => {
272            ca.downcast_iter()
273                .fold((None, None, 0), |acc, arr| {
274                    if arr.len() == 0 {
275                        return acc;
276                    }
277                    let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {
278                        arr.into_iter()
279                            .enumerate()
280                            .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
281                            .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
282                    } else {
283                        // When no nulls & array not empty => we can use fast argminmax
284                        let min_idx: usize = arr.values().as_slice().argmin();
285                        Some((min_idx, arr.value(min_idx)))
286                    };
287
288                    let new_offset: usize = acc.2 + arr.len();
289                    match acc {
290                        (Some(_), Some(acc_v), offset) => match chunk_min {
291                            Some((idx, val)) if val < acc_v => {
292                                (Some(idx + offset), Some(val), new_offset)
293                            },
294                            _ => (acc.0, acc.1, new_offset),
295                        },
296                        (None, None, offset) => match chunk_min {
297                            Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
298                            None => (None, None, new_offset),
299                        },
300                        _ => unreachable!(),
301                    }
302                })
303                .0
304        },
305    }
306}
307
308fn arg_max_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
309where
310    T: PolarsNumericType,
311    for<'b> &'b [T::Native]: ArgMinMax,
312{
313    match ca.is_sorted_flag() {
314        IsSorted::Ascending => ca.last_non_null(),
315        IsSorted::Descending => ca.first_non_null(),
316        IsSorted::Not => {
317            ca.downcast_iter()
318                .fold((None, None, 0), |acc, arr| {
319                    if arr.len() == 0 {
320                        return acc;
321                    }
322                    let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {
323                        // When there are nulls, we should compare Option<T::Native>
324                        arr.into_iter()
325                            .enumerate()
326                            .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
327                            .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
328                    } else {
329                        // When no nulls & array not empty => we can use fast argminmax
330                        let max_idx: usize = arr.values().as_slice().argmax();
331                        Some((max_idx, arr.value(max_idx)))
332                    };
333
334                    let new_offset: usize = acc.2 + arr.len();
335                    match acc {
336                        (Some(_), Some(acc_v), offset) => match chunk_max {
337                            Some((idx, val)) if acc_v < val => {
338                                (Some(idx + offset), Some(val), new_offset)
339                            },
340                            _ => (acc.0, acc.1, new_offset),
341                        },
342                        (None, None, offset) => match chunk_max {
343                            Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
344                            None => (None, None, new_offset),
345                        },
346                        _ => unreachable!(),
347                    }
348                })
349                .0
350        },
351    }
352}
353
354fn arg_min_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
355where
356    for<'a> &'a [T]: ArgMinMax,
357{
358    match is_sorted {
359        // all vals are not null guarded by cont_slice
360        IsSorted::Ascending => Some(0),
361        // all vals are not null guarded by cont_slice
362        IsSorted::Descending => Some(vals.len() - 1),
363        IsSorted::Not => Some(vals.argmin()), // assumes not empty
364    }
365}
366
367fn arg_max_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
368where
369    for<'a> &'a [T]: ArgMinMax,
370{
371    match is_sorted {
372        // all vals are not null guarded by cont_slice
373        IsSorted::Ascending => Some(vals.len() - 1),
374        // all vals are not null guarded by cont_slice
375        IsSorted::Descending => Some(0),
376        IsSorted::Not => Some(vals.argmax()), // assumes not empty
377    }
378}