1use arrow::array::Array;
2use polars_core::chunked_array::ops::float_sorted_arg_max::{
3 float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
4};
5use polars_core::series::IsSorted;
6use polars_core::with_match_categorical_physical_type;
7use polars_utils::arg_min_max::ArgMinMax;
8
9use super::*;
10
/// Aggregations that report the *position* of an extreme value rather than
/// the value itself.
pub trait ArgAgg {
    /// Returns the index of the minimal value, or `None` when no index can be
    /// reported.
    fn arg_min(&self) -> Option<usize>;
    /// Returns the index of the maximal value, or `None` when no index can be
    /// reported.
    fn arg_max(&self) -> Option<usize>;
}
18
// Dispatches on a numeric `DataType` value and expands the caller-supplied body
// once per supported dtype, with the caller's type metavariable bound to the
// matching `*Type` marker (e.g. `Int32` -> `Int32Type`).
//
// Invocation shape: `with_match_physical_numeric_polars_type!(dtype_expr, |$T| { body })`.
//
// Panics at runtime when the dtype is not one of the listed (and feature-enabled)
// numeric variants.
macro_rules! with_match_physical_numeric_polars_type {(
    // `$_:tt` captures the literal `$` the caller writes in `|$T|`, so that the
    // nested macro below can declare a metavariable of its own under the
    // caller's chosen name.
    $key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
    // Helper macro: receives one concrete type identifier and pastes the body
    // with the caller's metavariable replaced by it.
    macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
    use DataType::*;
    match $key_type {
        // Arms guarded by `cfg(feature = ...)` only exist when that dtype
        // feature is compiled in; otherwise the dtype falls through to the panic.
        #[cfg(feature = "dtype-i8")]
        Int8 => __with_ty__! { Int8Type },
        #[cfg(feature = "dtype-i16")]
        Int16 => __with_ty__! { Int16Type },
        Int32 => __with_ty__! { Int32Type },
        Int64 => __with_ty__! { Int64Type },
        #[cfg(feature = "dtype-i128")]
        Int128 => __with_ty__! { Int128Type },
        #[cfg(feature = "dtype-u8")]
        UInt8 => __with_ty__! { UInt8Type },
        #[cfg(feature = "dtype-u16")]
        UInt16 => __with_ty__! { UInt16Type },
        UInt32 => __with_ty__! { UInt32Type },
        UInt64 => __with_ty__! { UInt64Type },
        #[cfg(feature = "dtype-u128")]
        UInt128 => __with_ty__! { UInt128Type },
        #[cfg(feature = "dtype-f16")]
        Float16 => __with_ty__! { Float16Type },
        Float32 => __with_ty__! { Float32Type },
        Float64 => __with_ty__! { Float64Type },
        // Anything else is treated as a caller bug rather than a recoverable error.
        dt => panic!("not implemented for dtype {:?}", dt),
    }
})}
48
49impl ArgAgg for Series {
50 fn arg_min(&self) -> Option<usize> {
51 use DataType::*;
52 let phys_s = self.to_physical_repr();
53 match self.dtype() {
54 #[cfg(feature = "dtype-categorical")]
55 Categorical(cats, _) => {
56 with_match_categorical_physical_type!(cats.physical(), |$C| {
57 let ca = self.cat::<$C>().unwrap();
58 if ca.null_count() == ca.len() {
59 return None;
60 }
61 ca.iter_str()
62 .enumerate()
63 .flat_map(|(idx, val)| val.map(|val| (idx, val)))
64 .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
65 .map(|tpl| tpl.0)
66 })
67 },
68 #[cfg(feature = "dtype-categorical")]
69 Enum(_, _) => phys_s.arg_min(),
70 Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_min(),
71 String => {
72 let ca = self.str().unwrap();
73 arg_min_str(ca)
74 },
75 Boolean => {
76 let ca = self.bool().unwrap();
77 arg_min_bool(ca)
78 },
79 dt if dt.is_primitive_numeric() => {
80 with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
81 let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
82 arg_min_numeric_dispatch(ca)
83 })
84 },
85 _ => None,
86 }
87 }
88
89 fn arg_max(&self) -> Option<usize> {
90 use DataType::*;
91 let phys_s = self.to_physical_repr();
92 match self.dtype() {
93 #[cfg(feature = "dtype-categorical")]
94 Categorical(cats, _) => {
95 with_match_categorical_physical_type!(cats.physical(), |$C| {
96 let ca = self.cat::<$C>().unwrap();
97 if ca.null_count() == ca.len() {
98 return None;
99 }
100 ca.iter_str()
101 .enumerate()
102 .flat_map(|(idx, val)| val.map(|val| (idx, val)))
103 .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
104 .map(|tpl| tpl.0)
105 })
106 },
107 #[cfg(feature = "dtype-categorical")]
108 Enum(_, _) => phys_s.arg_max(),
109 Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_max(),
110 String => {
111 let ca = self.str().unwrap();
112 arg_max_str(ca)
113 },
114 Boolean => {
115 let ca = self.bool().unwrap();
116 arg_max_bool(ca)
117 },
118 #[cfg(feature = "dtype-f16")]
119 Float16 => {
120 phys_s.cast(&DataType::Float32).unwrap().arg_max()
122 },
123 dt if dt.is_primitive_numeric() => {
124 with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
125 let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
126 arg_max_numeric_dispatch(ca)
127 })
128 },
129 _ => None,
130 }
131 }
132}
133
134fn arg_max_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
135where
136 T: PolarsNumericType,
137 for<'b> &'b [T::Native]: ArgMinMax,
138{
139 if ca.null_count() == ca.len() {
140 None
141 } else if T::get_static_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) {
142 arg_max_float_sorted(ca)
143 } else if let Ok(vals) = ca.cont_slice() {
144 arg_max_numeric_slice(vals, ca.is_sorted_flag())
145 } else {
146 arg_max_numeric(ca)
147 }
148}
149
150fn arg_min_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
151where
152 T: PolarsNumericType,
153 for<'b> &'b [T::Native]: ArgMinMax,
154{
155 if ca.null_count() == ca.len() {
156 None
157 } else if let Ok(vals) = ca.cont_slice() {
158 arg_min_numeric_slice(vals, ca.is_sorted_flag())
159 } else {
160 arg_min_numeric(ca)
161 }
162}
163
164fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
165 ca.first_true_idx().or_else(|| ca.first_false_idx())
166}
167
168fn arg_max_float_sorted<T>(ca: &ChunkedArray<T>) -> Option<usize>
171where
172 T: PolarsNumericType,
173{
174 let out = match ca.is_sorted_flag() {
175 IsSorted::Ascending => float_arg_max_sorted_ascending(ca),
176 IsSorted::Descending => float_arg_max_sorted_descending(ca),
177 _ => unreachable!(),
178 };
179 Some(out)
180}
181
182fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
183 ca.first_false_idx().or_else(|| ca.first_true_idx())
184}
185
186fn arg_min_str(ca: &StringChunked) -> Option<usize> {
187 if ca.null_count() == ca.len() {
188 return None;
189 }
190 match ca.is_sorted_flag() {
191 IsSorted::Ascending => ca.first_non_null(),
192 IsSorted::Descending => ca.last_non_null(),
193 IsSorted::Not => ca
194 .iter()
195 .enumerate()
196 .flat_map(|(idx, val)| val.map(|val| (idx, val)))
197 .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
198 .map(|tpl| tpl.0),
199 }
200}
201
202fn arg_max_str(ca: &StringChunked) -> Option<usize> {
203 if ca.null_count() == ca.len() {
204 return None;
205 }
206 match ca.is_sorted_flag() {
207 IsSorted::Ascending => ca.last_non_null(),
208 IsSorted::Descending => ca.first_non_null(),
209 IsSorted::Not => ca
210 .iter()
211 .enumerate()
212 .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
213 .map(|tpl| tpl.0),
214 }
215}
216
217fn arg_min_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
218where
219 T: PolarsNumericType,
220 for<'b> &'b [T::Native]: ArgMinMax,
221{
222 match ca.is_sorted_flag() {
223 IsSorted::Ascending => ca.first_non_null(),
224 IsSorted::Descending => ca.last_non_null(),
225 IsSorted::Not => {
226 ca.downcast_iter()
227 .fold((None, None, 0), |acc, arr| {
228 if arr.len() == 0 {
229 return acc;
230 }
231 let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {
232 arr.into_iter()
233 .enumerate()
234 .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
235 .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
236 } else {
237 let min_idx: usize = arr.values().as_slice().argmin();
239 Some((min_idx, arr.value(min_idx)))
240 };
241
242 let new_offset: usize = acc.2 + arr.len();
243 match acc {
244 (Some(_), Some(acc_v), offset) => match chunk_min {
245 Some((idx, val)) if val < acc_v => {
246 (Some(idx + offset), Some(val), new_offset)
247 },
248 _ => (acc.0, acc.1, new_offset),
249 },
250 (None, None, offset) => match chunk_min {
251 Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
252 None => (None, None, new_offset),
253 },
254 _ => unreachable!(),
255 }
256 })
257 .0
258 },
259 }
260}
261
262fn arg_max_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
263where
264 T: PolarsNumericType,
265 for<'b> &'b [T::Native]: ArgMinMax,
266{
267 match ca.is_sorted_flag() {
268 IsSorted::Ascending => ca.last_non_null(),
269 IsSorted::Descending => ca.first_non_null(),
270 IsSorted::Not => {
271 ca.downcast_iter()
272 .fold((None, None, 0), |acc, arr| {
273 if arr.len() == 0 {
274 return acc;
275 }
276 let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {
277 arr.into_iter()
279 .enumerate()
280 .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
281 .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
282 } else {
283 let max_idx: usize = arr.values().as_slice().argmax();
285 Some((max_idx, arr.value(max_idx)))
286 };
287
288 let new_offset: usize = acc.2 + arr.len();
289 match acc {
290 (Some(_), Some(acc_v), offset) => match chunk_max {
291 Some((idx, val)) if acc_v < val => {
292 (Some(idx + offset), Some(val), new_offset)
293 },
294 _ => (acc.0, acc.1, new_offset),
295 },
296 (None, None, offset) => match chunk_max {
297 Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
298 None => (None, None, new_offset),
299 },
300 _ => unreachable!(),
301 }
302 })
303 .0
304 },
305 }
306}
307
308fn arg_min_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
309where
310 for<'a> &'a [T]: ArgMinMax,
311{
312 match is_sorted {
313 IsSorted::Ascending => Some(0),
315 IsSorted::Descending => Some(vals.len() - 1),
317 IsSorted::Not => Some(vals.argmin()), }
319}
320
321fn arg_max_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
322where
323 for<'a> &'a [T]: ArgMinMax,
324{
325 match is_sorted {
326 IsSorted::Ascending => Some(vals.len() - 1),
328 IsSorted::Descending => Some(0),
330 IsSorted::Not => Some(vals.argmax()), }
332}