1use argminmax::ArgMinMax;
2use arrow::array::Array;
3use polars_core::chunked_array::ops::float_sorted_arg_max::{
4 float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
5};
6use polars_core::series::IsSorted;
7use polars_core::with_match_categorical_physical_type;
8
9use super::*;
10
/// Positional ("arg") aggregations: instead of the extreme value itself,
/// report the index at which it is located.
///
/// Both methods return `Option<usize>`; `None` means no index could be
/// produced (the `Series` implementation in this file returns it for
/// all-null input and for unsupported dtypes).
pub trait ArgAgg {
    /// Get the index of the minimal value.
    fn arg_min(&self) -> Option<usize>;

    /// Get the index of the maximal value.
    fn arg_max(&self) -> Option<usize>;
}
18
// Maps a runtime (physical, numeric) `DataType` onto the matching Polars
// type marker and expands `$body` once with `$T` bound to that marker.
//
// Usage: `with_match_physical_numeric_polars_type!(dtype, |$T| { ... })`.
//
// The leading `$_:tt` captures the literal `$` token written at the call site
// so that the nested helper macro can declare its own `$T` metavariable.
//
// Panics with "not implemented for dtype ..." for every dtype that is not
// listed below; the call sites in this file only reach it behind an
// `is_primitive_numeric()` guard.
macro_rules! with_match_physical_numeric_polars_type {(
    $dtype:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
    macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
    use DataType::*;
    match $dtype {
        // Signed integers (the narrow widths are feature gated).
        #[cfg(feature = "dtype-i8")]
        Int8 => __with_ty__! { Int8Type },
        #[cfg(feature = "dtype-i16")]
        Int16 => __with_ty__! { Int16Type },
        Int32 => __with_ty__! { Int32Type },
        Int64 => __with_ty__! { Int64Type },
        // Unsigned integers (the narrow widths are feature gated).
        #[cfg(feature = "dtype-u8")]
        UInt8 => __with_ty__! { UInt8Type },
        #[cfg(feature = "dtype-u16")]
        UInt16 => __with_ty__! { UInt16Type },
        UInt32 => __with_ty__! { UInt32Type },
        UInt64 => __with_ty__! { UInt64Type },
        // Floats.
        Float32 => __with_ty__! { Float32Type },
        Float64 => __with_ty__! { Float64Type },
        unsupported => panic!("not implemented for dtype {:?}", unsupported),
    }
})}
42
43impl ArgAgg for Series {
44 fn arg_min(&self) -> Option<usize> {
45 use DataType::*;
46 let phys_s = self.to_physical_repr();
47 match self.dtype() {
48 #[cfg(feature = "dtype-categorical")]
49 Categorical(cats, _) => {
50 with_match_categorical_physical_type!(cats.physical(), |$C| {
51 let ca = self.cat::<$C>().unwrap();
52 if ca.null_count() == ca.len() {
53 return None;
54 }
55 ca.iter_str()
56 .enumerate()
57 .flat_map(|(idx, val)| val.map(|val| (idx, val)))
58 .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
59 .map(|tpl| tpl.0)
60 })
61 },
62 #[cfg(feature = "dtype-categorical")]
63 Enum(_, _) => phys_s.arg_min(),
64 Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_min(),
65 String => {
66 let ca = self.str().unwrap();
67 arg_min_str(ca)
68 },
69 Boolean => {
70 let ca = self.bool().unwrap();
71 arg_min_bool(ca)
72 },
73 dt if dt.is_primitive_numeric() => {
74 with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
75 let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
76 arg_min_numeric_dispatch(ca)
77 })
78 },
79 _ => None,
80 }
81 }
82
83 fn arg_max(&self) -> Option<usize> {
84 use DataType::*;
85 let phys_s = self.to_physical_repr();
86 match self.dtype() {
87 #[cfg(feature = "dtype-categorical")]
88 Categorical(cats, _) => {
89 with_match_categorical_physical_type!(cats.physical(), |$C| {
90 let ca = self.cat::<$C>().unwrap();
91 if ca.null_count() == ca.len() {
92 return None;
93 }
94 ca.iter_str()
95 .enumerate()
96 .flat_map(|(idx, val)| val.map(|val| (idx, val)))
97 .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
98 .map(|tpl| tpl.0)
99 })
100 },
101 #[cfg(feature = "dtype-categorical")]
102 Enum(_, _) => phys_s.arg_max(),
103 Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_max(),
104 String => {
105 let ca = self.str().unwrap();
106 arg_max_str(ca)
107 },
108 Boolean => {
109 let ca = self.bool().unwrap();
110 arg_max_bool(ca)
111 },
112 dt if dt.is_primitive_numeric() => {
113 with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
114 let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
115 arg_max_numeric_dispatch(ca)
116 })
117 },
118 _ => None,
119 }
120 }
121}
122
123fn arg_max_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
124where
125 T: PolarsNumericType,
126 for<'b> &'b [T::Native]: ArgMinMax,
127{
128 if ca.null_count() == ca.len() {
129 None
130 } else if T::get_static_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) {
131 arg_max_float_sorted(ca)
132 } else if let Ok(vals) = ca.cont_slice() {
133 arg_max_numeric_slice(vals, ca.is_sorted_flag())
134 } else {
135 arg_max_numeric(ca)
136 }
137}
138
139fn arg_min_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
140where
141 T: PolarsNumericType,
142 for<'b> &'b [T::Native]: ArgMinMax,
143{
144 if ca.null_count() == ca.len() {
145 None
146 } else if let Ok(vals) = ca.cont_slice() {
147 arg_min_numeric_slice(vals, ca.is_sorted_flag())
148 } else {
149 arg_min_numeric(ca)
150 }
151}
152
153fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
154 ca.first_true_idx().or_else(|| ca.first_false_idx())
155}
156
157fn arg_max_float_sorted<T>(ca: &ChunkedArray<T>) -> Option<usize>
160where
161 T: PolarsNumericType,
162{
163 let out = match ca.is_sorted_flag() {
164 IsSorted::Ascending => float_arg_max_sorted_ascending(ca),
165 IsSorted::Descending => float_arg_max_sorted_descending(ca),
166 _ => unreachable!(),
167 };
168 Some(out)
169}
170
171fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
172 ca.first_false_idx().or_else(|| ca.first_true_idx())
173}
174
175fn arg_min_str(ca: &StringChunked) -> Option<usize> {
176 if ca.null_count() == ca.len() {
177 return None;
178 }
179 match ca.is_sorted_flag() {
180 IsSorted::Ascending => ca.first_non_null(),
181 IsSorted::Descending => ca.last_non_null(),
182 IsSorted::Not => ca
183 .iter()
184 .enumerate()
185 .flat_map(|(idx, val)| val.map(|val| (idx, val)))
186 .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
187 .map(|tpl| tpl.0),
188 }
189}
190
191fn arg_max_str(ca: &StringChunked) -> Option<usize> {
192 if ca.null_count() == ca.len() {
193 return None;
194 }
195 match ca.is_sorted_flag() {
196 IsSorted::Ascending => ca.last_non_null(),
197 IsSorted::Descending => ca.first_non_null(),
198 IsSorted::Not => ca
199 .iter()
200 .enumerate()
201 .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
202 .map(|tpl| tpl.0),
203 }
204}
205
206fn arg_min_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
207where
208 T: PolarsNumericType,
209 for<'b> &'b [T::Native]: ArgMinMax,
210{
211 match ca.is_sorted_flag() {
212 IsSorted::Ascending => ca.first_non_null(),
213 IsSorted::Descending => ca.last_non_null(),
214 IsSorted::Not => {
215 ca.downcast_iter()
216 .fold((None, None, 0), |acc, arr| {
217 if arr.len() == 0 {
218 return acc;
219 }
220 let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {
221 arr.into_iter()
222 .enumerate()
223 .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
224 .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
225 } else {
226 let min_idx: usize = arr.values().as_slice().argmin();
228 Some((min_idx, arr.value(min_idx)))
229 };
230
231 let new_offset: usize = acc.2 + arr.len();
232 match acc {
233 (Some(_), Some(acc_v), offset) => match chunk_min {
234 Some((idx, val)) if val < acc_v => {
235 (Some(idx + offset), Some(val), new_offset)
236 },
237 _ => (acc.0, acc.1, new_offset),
238 },
239 (None, None, offset) => match chunk_min {
240 Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
241 None => (None, None, new_offset),
242 },
243 _ => unreachable!(),
244 }
245 })
246 .0
247 },
248 }
249}
250
251fn arg_max_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
252where
253 T: PolarsNumericType,
254 for<'b> &'b [T::Native]: ArgMinMax,
255{
256 match ca.is_sorted_flag() {
257 IsSorted::Ascending => ca.last_non_null(),
258 IsSorted::Descending => ca.first_non_null(),
259 IsSorted::Not => {
260 ca.downcast_iter()
261 .fold((None, None, 0), |acc, arr| {
262 if arr.len() == 0 {
263 return acc;
264 }
265 let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {
266 arr.into_iter()
268 .enumerate()
269 .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
270 .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
271 } else {
272 let max_idx: usize = arr.values().as_slice().argmax();
274 Some((max_idx, arr.value(max_idx)))
275 };
276
277 let new_offset: usize = acc.2 + arr.len();
278 match acc {
279 (Some(_), Some(acc_v), offset) => match chunk_max {
280 Some((idx, val)) if acc_v < val => {
281 (Some(idx + offset), Some(val), new_offset)
282 },
283 _ => (acc.0, acc.1, new_offset),
284 },
285 (None, None, offset) => match chunk_max {
286 Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
287 None => (None, None, new_offset),
288 },
289 _ => unreachable!(),
290 }
291 })
292 .0
293 },
294 }
295}
296
297fn arg_min_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
298where
299 for<'a> &'a [T]: ArgMinMax,
300{
301 match is_sorted {
302 IsSorted::Ascending => Some(0),
304 IsSorted::Descending => Some(vals.len() - 1),
306 IsSorted::Not => Some(vals.argmin()), }
308}
309
310fn arg_max_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
311where
312 for<'a> &'a [T]: ArgMinMax,
313{
314 match is_sorted {
315 IsSorted::Ascending => Some(vals.len() - 1),
317 IsSorted::Descending => Some(0),
319 IsSorted::Not => Some(vals.argmax()), }
321}