// polars_ops/series/ops/arg_min_max.rs

1use arrow::array::Array;
2use polars_core::chunked_array::ops::float_sorted_arg_max::{
3    float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
4};
5use polars_core::series::IsSorted;
6use polars_core::with_match_categorical_physical_type;
7use polars_utils::arg_min_max::ArgMinMax;
8use polars_utils::min_max::{MaxIgnoreNan, MinIgnoreNan, MinMaxPolicy};
9
/// Returns the position of the smallest non-null item yielded by `iter`.
///
/// `None` items are skipped but still advance the position counter. When the
/// minimum occurs more than once, the earliest position is reported. Returns
/// `None` if `iter` yields no non-null item at all.
pub fn arg_min_opt_iter<T, I>(iter: I) -> Option<usize>
where
    I: IntoIterator<Item = Option<T>>,
    T: Ord,
{
    let present = iter
        .into_iter()
        .enumerate()
        .filter_map(|(pos, item)| item.map(|value| (pos, value)));
    present
        .min_by(|(_, lhs), (_, rhs)| lhs.cmp(rhs))
        .map(|(pos, _)| pos)
}
21
/// Returns the position of the largest non-null item yielded by `iter`.
///
/// `None` items are skipped but still advance the position counter. When the
/// maximum occurs more than once, the *earliest* position is reported. This matches
/// `arg_min_opt_iter` (earliest minimum) and `arg_max_bool` (first `true`).
/// Returns `None` if `iter` yields no non-null item at all.
pub fn arg_max_opt_iter<T, I>(iter: I) -> Option<usize>
where
    I: IntoIterator<Item = Option<T>>,
    T: Ord,
{
    iter.into_iter()
        .enumerate()
        .filter_map(|(idx, val)| Some((idx, val?)))
        // `Iterator::max_by` would return the *last* of several equal maxima.
        // Replacing only on a strictly greater value keeps the first occurrence.
        .reduce(|best, cand| if cand.1 > best.1 { cand } else { best })
        .map(|(idx, _)| idx)
}
33
34use super::*;
35
/// Positional arg-min / arg-max aggregations.
///
/// Both methods report an index rather than a value. `None` means no index could
/// be reported, e.g. because there is no non-null value to compare.
pub trait ArgAgg {
    /// Returns the index of the smallest value, if any.
    fn arg_min(&self) -> Option<usize>;
    /// Returns the index of the largest value, if any.
    fn arg_max(&self) -> Option<usize>;
}
43
// Dispatches on a physical numeric `DataType` value (`$key_type`).
//
// For each supported dtype, this binds the matching polars type (`Int32Type`,
// `Float64Type`, ...) to the identifier written in `|$T|`. It then expands
// `$($body)*` once for that type.
//
// The `$_:tt` fragment captures the literal `$` token from the call site. That lets
// the locally defined helper macro `__with_ty__` declare its own `$T` metavariable,
// so `$T` inside the caller's body is substituted per match arm.
//
// Arms for 8/16-bit ints, 128-bit ints and Float16 only exist when the corresponding
// cargo feature is enabled. Any dtype without an arm panics. Callers are therefore
// expected to check that the dtype is primitive numeric before invoking this macro.
macro_rules! with_match_physical_numeric_polars_type {(
    $key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
    // Helper macro: `__with_ty__! { SomeType }` expands the caller's body with
    // `$T` replaced by `SomeType`.
    macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
    use DataType::*;
    match $key_type {
        #[cfg(feature = "dtype-i8")]
        Int8 => __with_ty__! { Int8Type },
        #[cfg(feature = "dtype-i16")]
        Int16 => __with_ty__! { Int16Type },
        Int32 => __with_ty__! { Int32Type },
        Int64 => __with_ty__! { Int64Type },
        #[cfg(feature = "dtype-i128")]
        Int128 => __with_ty__! { Int128Type },
        #[cfg(feature = "dtype-u8")]
        UInt8 => __with_ty__! { UInt8Type },
        #[cfg(feature = "dtype-u16")]
        UInt16 => __with_ty__! { UInt16Type },
        UInt32 => __with_ty__! { UInt32Type },
        UInt64 => __with_ty__! { UInt64Type },
        #[cfg(feature = "dtype-u128")]
        UInt128 => __with_ty__! { UInt128Type },
        #[cfg(feature = "dtype-f16")]
        Float16 => __with_ty__! { Float16Type },
        Float32 => __with_ty__! { Float32Type },
        Float64 => __with_ty__! { Float64Type },
        // Non-numeric, or numeric but feature-gated off in this build.
        dt => panic!("not implemented for dtype {:?}", dt),
    }
})}
73
74impl ArgAgg for Series {
75    fn arg_min(&self) -> Option<usize> {
76        use DataType::*;
77        let phys_s = self.to_physical_repr();
78        match self.dtype() {
79            #[cfg(feature = "dtype-categorical")]
80            Categorical(cats, _) => {
81                with_match_categorical_physical_type!(cats.physical(), |$C| {
82                    arg_min_cat(self.cat::<$C>().unwrap())
83                })
84            },
85            #[cfg(feature = "dtype-categorical")]
86            Enum(_, _) => phys_s.arg_min(),
87            #[cfg(feature = "dtype-decimal")]
88            Decimal(_, _) => phys_s.arg_min(),
89            Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_min(),
90            String => arg_min_str(self.str().unwrap()),
91            Binary => arg_min_binary(self.binary().unwrap()),
92            Boolean => arg_min_bool(self.bool().unwrap()),
93            dt if dt.is_primitive_numeric() => {
94                with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
95                    let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
96                    arg_min_numeric(ca)
97                })
98            },
99            _ => None,
100        }
101    }
102
103    fn arg_max(&self) -> Option<usize> {
104        use DataType::*;
105        let phys_s = self.to_physical_repr();
106        match self.dtype() {
107            #[cfg(feature = "dtype-categorical")]
108            Categorical(cats, _) => {
109                with_match_categorical_physical_type!(cats.physical(), |$C| {
110                    arg_max_cat(self.cat::<$C>().unwrap())
111                })
112            },
113            #[cfg(feature = "dtype-categorical")]
114            Enum(_, _) => phys_s.arg_max(),
115            #[cfg(feature = "dtype-decimal")]
116            Decimal(_, _) => phys_s.arg_max(),
117            Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_max(),
118            String => arg_max_str(self.str().unwrap()),
119            Binary => arg_max_binary(self.binary().unwrap()),
120            Boolean => arg_max_bool(self.bool().unwrap()),
121            dt if dt.is_primitive_numeric() => {
122                with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
123                    let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
124                    arg_max_numeric(ca)
125                })
126            },
127            _ => None,
128        }
129    }
130}
131
132pub fn arg_min_numeric<T>(ca: &ChunkedArray<T>) -> Option<usize>
133where
134    T: PolarsNumericType,
135    for<'b> &'b [T::Native]: ArgMinMax,
136{
137    if ca.null_count() == ca.len() {
138        None
139    } else if let Ok(vals) = ca.cont_slice() {
140        arg_min_numeric_slice(vals, ca.is_sorted_flag())
141    } else {
142        arg_min_numeric_chunked(ca)
143    }
144}
145
146pub fn arg_max_numeric<T>(ca: &ChunkedArray<T>) -> Option<usize>
147where
148    T: PolarsNumericType,
149    for<'b> &'b [T::Native]: ArgMinMax,
150{
151    if ca.null_count() == ca.len() {
152        None
153    } else if T::get_static_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) {
154        arg_max_float_sorted(ca)
155    } else if let Ok(vals) = ca.cont_slice() {
156        arg_max_numeric_slice(vals, ca.is_sorted_flag())
157    } else {
158        arg_max_numeric_chunked(ca)
159    }
160}
161
162/// # Safety
163/// `ca` has a float dtype, has at least one non-null value and is sorted.
164fn arg_max_float_sorted<T>(ca: &ChunkedArray<T>) -> Option<usize>
165where
166    T: PolarsNumericType,
167{
168    let out = match ca.is_sorted_flag() {
169        IsSorted::Ascending => float_arg_max_sorted_ascending(ca),
170        IsSorted::Descending => float_arg_max_sorted_descending(ca),
171        _ => unreachable!(),
172    };
173    Some(out)
174}
175
/// Index of the categorical entry whose category string is smallest.
///
/// The comparison uses the string values, not the physical codes. Returns `None`
/// when every entry is null.
#[cfg(feature = "dtype-categorical")]
pub fn arg_min_cat<T: PolarsCategoricalType>(ca: &CategoricalChunked<T>) -> Option<usize> {
    // Cheap exit before iterating strings when there is nothing to compare.
    if ca.null_count() == ca.len() {
        None
    } else {
        arg_min_opt_iter(ca.iter_str())
    }
}
183
/// Index of the categorical entry whose category string is largest.
///
/// The comparison uses the string values, not the physical codes. Returns `None`
/// when every entry is null.
#[cfg(feature = "dtype-categorical")]
pub fn arg_max_cat<T: PolarsCategoricalType>(ca: &CategoricalChunked<T>) -> Option<usize> {
    // Cheap exit before iterating strings when there is nothing to compare.
    if ca.null_count() == ca.len() {
        None
    } else {
        arg_max_opt_iter(ca.iter_str())
    }
}
191
192pub fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
193    ca.first_false_idx().or_else(|| ca.first_true_idx())
194}
195
196pub fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
197    ca.first_true_idx().or_else(|| ca.first_false_idx())
198}
199
200pub fn arg_min_str(ca: &StringChunked) -> Option<usize> {
201    arg_min_physical_generic(ca)
202}
203
204pub fn arg_max_str(ca: &StringChunked) -> Option<usize> {
205    arg_max_physical_generic(ca)
206}
207
208pub fn arg_min_binary(ca: &BinaryChunked) -> Option<usize> {
209    arg_min_physical_generic(ca)
210}
211
212pub fn arg_max_binary(ca: &BinaryChunked) -> Option<usize> {
213    arg_max_physical_generic(ca)
214}
215
216fn arg_min_physical_generic<T>(ca: &ChunkedArray<T>) -> Option<usize>
217where
218    T: PolarsDataType,
219    for<'a> T::Physical<'a>: Ord,
220{
221    if ca.null_count() == ca.len() {
222        return None;
223    }
224    match ca.is_sorted_flag() {
225        IsSorted::Ascending => ca.first_non_null(),
226        IsSorted::Descending => ca.last_non_null(),
227        IsSorted::Not => arg_min_opt_iter(ca.iter()),
228    }
229}
230
231fn arg_max_physical_generic<T>(ca: &ChunkedArray<T>) -> Option<usize>
232where
233    T: PolarsDataType,
234    for<'a> T::Physical<'a>: Ord,
235{
236    if ca.null_count() == ca.len() {
237        return None;
238    }
239    match ca.is_sorted_flag() {
240        IsSorted::Ascending => ca.last_non_null(),
241        IsSorted::Descending => ca.first_non_null(),
242        IsSorted::Not => arg_max_opt_iter(ca.iter()),
243    }
244}
245
246fn arg_min_numeric_chunked<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
247where
248    T: PolarsNumericType,
249    for<'b> &'b [T::Native]: ArgMinMax,
250{
251    match ca.is_sorted_flag() {
252        IsSorted::Ascending => ca.first_non_null(),
253        IsSorted::Descending => ca.last_non_null(),
254        IsSorted::Not => {
255            let mut chunk_start_offset = 0;
256            let mut min_idx: Option<usize> = None;
257            let mut min_val: Option<T::Native> = None;
258            for arr in ca.downcast_iter() {
259                if arr.len() == arr.null_count() {
260                    chunk_start_offset += arr.len();
261                    continue;
262                }
263
264                let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {
265                    arr.into_iter()
266                        .enumerate()
267                        .flat_map(|(idx, val)| Some((idx, *(val?))))
268                        .reduce(|acc, (idx, val)| {
269                            if MinIgnoreNan::is_better(&val, &acc.1) {
270                                (idx, val)
271                            } else {
272                                acc
273                            }
274                        })
275                } else {
276                    // When no nulls & array not empty => we can use fast argmin.
277                    let min_idx: usize = arr.values().as_slice().argmin();
278                    Some((min_idx, arr.value(min_idx)))
279                };
280
281                if let Some((chunk_min_idx, chunk_min_val)) = chunk_min {
282                    if min_val.is_none()
283                        || MinIgnoreNan::is_better(&chunk_min_val, &min_val.unwrap())
284                    {
285                        min_val = Some(chunk_min_val);
286                        min_idx = Some(chunk_start_offset + chunk_min_idx);
287                    }
288                }
289                chunk_start_offset += arr.len();
290            }
291            min_idx
292        },
293    }
294}
295
296fn arg_max_numeric_chunked<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
297where
298    T: PolarsNumericType,
299    for<'b> &'b [T::Native]: ArgMinMax,
300{
301    match ca.is_sorted_flag() {
302        IsSorted::Ascending => ca.last_non_null(),
303        IsSorted::Descending => ca.first_non_null(),
304        IsSorted::Not => {
305            let mut chunk_start_offset = 0;
306            let mut max_idx: Option<usize> = None;
307            let mut max_val: Option<T::Native> = None;
308            for arr in ca.downcast_iter() {
309                if arr.len() == arr.null_count() {
310                    chunk_start_offset += arr.len();
311                    continue;
312                }
313
314                let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {
315                    arr.into_iter()
316                        .enumerate()
317                        .flat_map(|(idx, val)| Some((idx, *(val?))))
318                        .reduce(|acc, (idx, val)| {
319                            if MaxIgnoreNan::is_better(&val, &acc.1) {
320                                (idx, val)
321                            } else {
322                                acc
323                            }
324                        })
325                } else {
326                    // When no nulls & array not empty => we can use fast argmax.
327                    let max_idx: usize = arr.values().as_slice().argmax();
328                    Some((max_idx, arr.value(max_idx)))
329                };
330
331                if let Some((chunk_max_idx, chunk_max_val)) = chunk_max {
332                    if max_val.is_none()
333                        || MaxIgnoreNan::is_better(&chunk_max_val, &max_val.unwrap())
334                    {
335                        max_val = Some(chunk_max_val);
336                        max_idx = Some(chunk_start_offset + chunk_max_idx);
337                    }
338                }
339                chunk_start_offset += arr.len();
340            }
341            max_idx
342        },
343    }
344}
345
346fn arg_min_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
347where
348    for<'a> &'a [T]: ArgMinMax,
349{
350    match is_sorted {
351        // all vals are not null guarded by cont_slice
352        IsSorted::Ascending => Some(0),
353        // all vals are not null guarded by cont_slice
354        IsSorted::Descending => Some(vals.len() - 1),
355        IsSorted::Not => Some(vals.argmin()), // assumes not empty
356    }
357}
358
359fn arg_max_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
360where
361    for<'a> &'a [T]: ArgMinMax,
362{
363    match is_sorted {
364        // all vals are not null guarded by cont_slice
365        IsSorted::Ascending => Some(vals.len() - 1),
366        // all vals are not null guarded by cont_slice
367        IsSorted::Descending => Some(0),
368        IsSorted::Not => Some(vals.argmax()), // assumes not empty
369    }
370}