Skip to main content

polars_ops/series/ops/
arg_min_max.rs

1use polars_core::chunked_array::arg_min_max::{
2    arg_max_binary, arg_max_binary_offset, arg_max_bool, arg_max_numeric, arg_max_str,
3    arg_min_binary, arg_min_binary_offset, arg_min_bool, arg_min_numeric, arg_min_str,
4};
5#[cfg(feature = "dtype-categorical")]
6use polars_core::chunked_array::arg_min_max::{arg_max_cat, arg_min_cat};
7#[cfg(feature = "dtype-categorical")]
8use polars_core::with_match_categorical_physical_type;
9
10use super::*;
11
/// Arg-min / arg-max aggregation: report *where* the extreme values sit rather than what
/// they are.
pub trait ArgAgg {
    /// Returns the index of the smallest value.
    ///
    /// `None` signals that the implementor has no index to report; the exact conditions are
    /// defined by each implementation.
    fn arg_min(&self) -> Option<usize>;

    /// Returns the index of the largest value.
    ///
    /// `None` signals that the implementor has no index to report; the exact conditions are
    /// defined by each implementation.
    fn arg_max(&self) -> Option<usize>;
}
19
/// Expands `body` once with `$T` bound to the Polars type marker that corresponds to a
/// numeric `DataType`.
///
/// Invocation shape (as used in this file):
///
/// ```ignore
/// with_match_physical_numeric_polars_type!(series.dtype(), |$T| {
///     let ca: &ChunkedArray<$T> = /* ... */;
///     /* ... */
/// })
/// ```
///
/// The caller's `$` token is captured as `$_` so that it can be re-emitted inside the nested
/// helper macro, which is what lets the body refer to `$T`.
///
/// Arms guarded by `#[cfg(feature = ...)]` only exist when that feature is enabled.
///
/// # Panics
///
/// Panics with `not implemented for dtype ...` when the dtype matches none of the compiled-in
/// arms.
macro_rules! with_match_physical_numeric_polars_type {
    ($dtype:expr, | $_:tt $T:ident | $($arm_body:tt)*) => {{
        // Helper that substitutes the selected type marker for `$T` in the caller's body.
        macro_rules! __dispatch_on__ {
            ($_ $T:ident) => {
                $($arm_body)*
            };
        }

        use DataType::*;
        match $dtype {
            // Signed integers.
            #[cfg(feature = "dtype-i8")]
            Int8 => __dispatch_on__! { Int8Type },
            #[cfg(feature = "dtype-i16")]
            Int16 => __dispatch_on__! { Int16Type },
            Int32 => __dispatch_on__! { Int32Type },
            Int64 => __dispatch_on__! { Int64Type },
            #[cfg(feature = "dtype-i128")]
            Int128 => __dispatch_on__! { Int128Type },

            // Unsigned integers.
            #[cfg(feature = "dtype-u8")]
            UInt8 => __dispatch_on__! { UInt8Type },
            #[cfg(feature = "dtype-u16")]
            UInt16 => __dispatch_on__! { UInt16Type },
            UInt32 => __dispatch_on__! { UInt32Type },
            UInt64 => __dispatch_on__! { UInt64Type },
            #[cfg(feature = "dtype-u128")]
            UInt128 => __dispatch_on__! { UInt128Type },

            // Floats.
            #[cfg(feature = "dtype-f16")]
            Float16 => __dispatch_on__! { Float16Type },
            Float32 => __dispatch_on__! { Float32Type },
            Float64 => __dispatch_on__! { Float64Type },

            unsupported => panic!("not implemented for dtype {:?}", unsupported),
        }
    }};
}
49
50impl ArgAgg for Series {
51    fn arg_min(&self) -> Option<usize> {
52        use DataType::*;
53        let phys_s = self.to_physical_repr();
54        match self.dtype() {
55            #[cfg(feature = "dtype-categorical")]
56            Categorical(cats, _) => {
57                with_match_categorical_physical_type!(cats.physical(), |$C| {
58                    arg_min_cat(self.cat::<$C>().unwrap())
59                })
60            },
61            #[cfg(feature = "dtype-categorical")]
62            Enum(_, _) => phys_s.arg_min(),
63            #[cfg(feature = "dtype-decimal")]
64            Decimal(_, _) => phys_s.arg_min(),
65            Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_min(),
66            String => arg_min_str(self.str().unwrap()),
67            Binary => arg_min_binary(self.binary().unwrap()),
68            BinaryOffset => arg_min_binary_offset(self.binary_offset().unwrap()),
69            Boolean => arg_min_bool(self.bool().unwrap()),
70            dt if dt.is_primitive_numeric() => {
71                with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
72                    let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
73                    arg_min_numeric(ca)
74                })
75            },
76            dt if dt.is_nested() => self
77                .row_encode_ordered(false, false)
78                .ok()?
79                .into_series()
80                .arg_min(),
81            _ => None,
82        }
83    }
84
85    fn arg_max(&self) -> Option<usize> {
86        use DataType::*;
87        let phys_s = self.to_physical_repr();
88        match self.dtype() {
89            #[cfg(feature = "dtype-categorical")]
90            Categorical(cats, _) => {
91                with_match_categorical_physical_type!(cats.physical(), |$C| {
92                    arg_max_cat(self.cat::<$C>().unwrap())
93                })
94            },
95            #[cfg(feature = "dtype-categorical")]
96            Enum(_, _) => phys_s.arg_max(),
97            #[cfg(feature = "dtype-decimal")]
98            Decimal(_, _) => phys_s.arg_max(),
99            Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_max(),
100            String => arg_max_str(self.str().unwrap()),
101            Binary => arg_max_binary(self.binary().unwrap()),
102            BinaryOffset => arg_max_binary_offset(self.binary_offset().unwrap()),
103            Boolean => arg_max_bool(self.bool().unwrap()),
104            dt if dt.is_primitive_numeric() => {
105                with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
106                    let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
107                    arg_max_numeric(ca)
108                })
109            },
110            dt if dt.is_nested() => self
111                .row_encode_ordered(false, false)
112                .ok()?
113                .into_series()
114                .arg_max(),
115            _ => None,
116        }
117    }
118}