1use argminmax::ArgMinMax;
2use arrow::array::Array;
3use arrow::legacy::bit_util::*;
4use polars_core::chunked_array::ops::float_sorted_arg_max::{
5 float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
6};
7use polars_core::series::IsSorted;
8use polars_core::with_match_physical_numeric_polars_type;
9
10use super::*;
11
/// Positional min/max aggregations: report *where* the extreme value sits
/// rather than the value itself.
pub trait ArgAgg {
    /// Returns the index of the smallest value, or `None` when there is nothing
    /// to pick. The `Series` implementation in this module returns `None` for
    /// all-null input and for dtypes it does not handle.
    fn arg_min(&self) -> Option<usize>;

    /// Returns the index of the largest value, or `None` under the same
    /// conditions as [`ArgAgg::arg_min`].
    fn arg_max(&self) -> Option<usize>;
}
19
// Expands `$body` once per physical numeric dtype and binds `$T` to the matching
// polars type (e.g. `Int32Type`), so generic code can be monomorphised per dtype.
//
// Call shape: `with_match_physical_numeric_polars_type!(dtype, |$T| { ... })`.
// The caller's literal `$` is captured as `$_`, which lets the inner
// `__with_ty__` helper declare its own `$T` metavariable.
//
// NOTE(review): this local definition has the same name as the `polars_core`
// macro imported at the top of the file. Textual macro scope takes precedence
// over path-based scope, so invocations below use this copy and the import
// appears unused. Confirm which one is intended.
macro_rules! with_match_physical_numeric_polars_type {(
    $key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
    // Re-emits the caller's body with `$T` substituted by the concrete type.
    macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
    // Not `$crate`-qualified: `DataType` must be in scope at the call site.
    use DataType::*;
    match $key_type {
        #[cfg(feature = "dtype-i8")]
        Int8 => __with_ty__! { Int8Type },
        #[cfg(feature = "dtype-i16")]
        Int16 => __with_ty__! { Int16Type },
        Int32 => __with_ty__! { Int32Type },
        Int64 => __with_ty__! { Int64Type },
        #[cfg(feature = "dtype-u8")]
        UInt8 => __with_ty__! { UInt8Type },
        #[cfg(feature = "dtype-u16")]
        UInt16 => __with_ty__! { UInt16Type },
        UInt32 => __with_ty__! { UInt32Type },
        UInt64 => __with_ty__! { UInt64Type },
        Float32 => __with_ty__! { Float32Type },
        Float64 => __with_ty__! { Float64Type },
        // Any other dtype panics; callers are expected to filter first.
        // NOTE(review): confirm this list covers every dtype for which
        // `is_primitive_numeric()` returns true. Otherwise the guarded call
        // sites can still reach this panic.
        dt => panic!("not implemented for dtype {:?}", dt),
    }
})}
43
44impl ArgAgg for Series {
45 fn arg_min(&self) -> Option<usize> {
46 use DataType::*;
47 let s = self.to_physical_repr();
48 match self.dtype() {
49 #[cfg(feature = "dtype-categorical")]
50 Categorical(_, _) => {
51 let ca = self.categorical().unwrap();
52 if ca.null_count() == ca.len() {
53 return None;
54 }
55 if ca.uses_lexical_ordering() {
56 ca.iter_str()
57 .enumerate()
58 .flat_map(|(idx, val)| val.map(|val| (idx, val)))
59 .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
60 .map(|tpl| tpl.0)
61 } else {
62 let ca = s.u32().unwrap();
63 arg_min_numeric_dispatch(ca)
64 }
65 },
66 String => {
67 let ca = self.str().unwrap();
68 arg_min_str(ca)
69 },
70 Boolean => {
71 let ca = self.bool().unwrap();
72 arg_min_bool(ca)
73 },
74 Date => {
75 let ca = s.i32().unwrap();
76 arg_min_numeric_dispatch(ca)
77 },
78 Datetime(_, _) | Duration(_) | Time => {
79 let ca = s.i64().unwrap();
80 arg_min_numeric_dispatch(ca)
81 },
82 dt if dt.is_primitive_numeric() => {
83 with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
84 let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
85 arg_min_numeric_dispatch(ca)
86 })
87 },
88 _ => None,
89 }
90 }
91
92 fn arg_max(&self) -> Option<usize> {
93 use DataType::*;
94 let s = self.to_physical_repr();
95 match self.dtype() {
96 #[cfg(feature = "dtype-categorical")]
97 Categorical(_, _) => {
98 let ca = self.categorical().unwrap();
99 if ca.null_count() == ca.len() {
100 return None;
101 }
102 if ca.uses_lexical_ordering() {
103 ca.iter_str()
104 .enumerate()
105 .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
106 .map(|tpl| tpl.0)
107 } else {
108 let ca_phys = s.u32().unwrap();
109 arg_max_numeric_dispatch(ca_phys)
110 }
111 },
112 String => {
113 let ca = self.str().unwrap();
114 arg_max_str(ca)
115 },
116 Boolean => {
117 let ca = self.bool().unwrap();
118 arg_max_bool(ca)
119 },
120 Date => {
121 let ca = s.i32().unwrap();
122 arg_max_numeric_dispatch(ca)
123 },
124 Datetime(_, _) | Duration(_) | Time => {
125 let ca = s.i64().unwrap();
126 arg_max_numeric_dispatch(ca)
127 },
128 dt if dt.is_primitive_numeric() => {
129 with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
130 let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
131 arg_max_numeric_dispatch(ca)
132 })
133 },
134 _ => None,
135 }
136 }
137}
138
139fn arg_max_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
140where
141 T: PolarsNumericType,
142 for<'b> &'b [T::Native]: ArgMinMax,
143{
144 if ca.null_count() == ca.len() {
145 None
146 } else if T::get_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) {
147 arg_max_float_sorted(ca)
148 } else if let Ok(vals) = ca.cont_slice() {
149 arg_max_numeric_slice(vals, ca.is_sorted_flag())
150 } else {
151 arg_max_numeric(ca)
152 }
153}
154
155fn arg_min_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
156where
157 T: PolarsNumericType,
158 for<'b> &'b [T::Native]: ArgMinMax,
159{
160 if ca.null_count() == ca.len() {
161 None
162 } else if let Ok(vals) = ca.cont_slice() {
163 arg_min_numeric_slice(vals, ca.is_sorted_flag())
164 } else {
165 arg_min_numeric(ca)
166 }
167}
168
169pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
170 if ca.null_count() == ca.len() {
171 None
172 }
173 else if ca.null_count() == 0 && ca.chunks().len() == 1 {
175 let arr = ca.downcast_iter().next().unwrap();
176 let mask = arr.values();
177 Some(first_set_bit(mask))
178 } else {
179 let mut first_false_idx: Option<usize> = None;
180 ca.iter()
181 .enumerate()
182 .find_map(|(idx, val)| match val {
183 Some(true) => Some(idx),
184 Some(false) if first_false_idx.is_none() => {
185 first_false_idx = Some(idx);
186 None
187 },
188 _ => None,
189 })
190 .or(first_false_idx)
191 }
192}
193
194fn arg_max_float_sorted<T>(ca: &ChunkedArray<T>) -> Option<usize>
197where
198 T: PolarsNumericType,
199{
200 let out = match ca.is_sorted_flag() {
201 IsSorted::Ascending => float_arg_max_sorted_ascending(ca),
202 IsSorted::Descending => float_arg_max_sorted_descending(ca),
203 _ => unreachable!(),
204 };
205
206 Some(out)
207}
208
209fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
210 if ca.null_count() == ca.len() {
211 None
212 } else if ca.null_count() == 0 && ca.chunks().len() == 1 {
213 let arr = ca.downcast_iter().next().unwrap();
214 let mask = arr.values();
215 Some(first_unset_bit(mask))
216 } else {
217 let mut first_true_idx: Option<usize> = None;
218 ca.iter()
219 .enumerate()
220 .find_map(|(idx, val)| match val {
221 Some(false) => Some(idx),
222 Some(true) if first_true_idx.is_none() => {
223 first_true_idx = Some(idx);
224 None
225 },
226 _ => None,
227 })
228 .or(first_true_idx)
229 }
230}
231
232fn arg_min_str(ca: &StringChunked) -> Option<usize> {
233 if ca.null_count() == ca.len() {
234 return None;
235 }
236 match ca.is_sorted_flag() {
237 IsSorted::Ascending => ca.first_non_null(),
238 IsSorted::Descending => ca.last_non_null(),
239 IsSorted::Not => ca
240 .iter()
241 .enumerate()
242 .flat_map(|(idx, val)| val.map(|val| (idx, val)))
243 .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
244 .map(|tpl| tpl.0),
245 }
246}
247
248fn arg_max_str(ca: &StringChunked) -> Option<usize> {
249 if ca.null_count() == ca.len() {
250 return None;
251 }
252 match ca.is_sorted_flag() {
253 IsSorted::Ascending => ca.last_non_null(),
254 IsSorted::Descending => ca.first_non_null(),
255 IsSorted::Not => ca
256 .iter()
257 .enumerate()
258 .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
259 .map(|tpl| tpl.0),
260 }
261}
262
263fn arg_min_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
264where
265 T: PolarsNumericType,
266 for<'b> &'b [T::Native]: ArgMinMax,
267{
268 match ca.is_sorted_flag() {
269 IsSorted::Ascending => ca.first_non_null(),
270 IsSorted::Descending => ca.last_non_null(),
271 IsSorted::Not => {
272 ca.downcast_iter()
273 .fold((None, None, 0), |acc, arr| {
274 if arr.len() == 0 {
275 return acc;
276 }
277 let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {
278 arr.into_iter()
279 .enumerate()
280 .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
281 .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
282 } else {
283 let min_idx: usize = arr.values().as_slice().argmin();
285 Some((min_idx, arr.value(min_idx)))
286 };
287
288 let new_offset: usize = acc.2 + arr.len();
289 match acc {
290 (Some(_), Some(acc_v), offset) => match chunk_min {
291 Some((idx, val)) if val < acc_v => {
292 (Some(idx + offset), Some(val), new_offset)
293 },
294 _ => (acc.0, acc.1, new_offset),
295 },
296 (None, None, offset) => match chunk_min {
297 Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
298 None => (None, None, new_offset),
299 },
300 _ => unreachable!(),
301 }
302 })
303 .0
304 },
305 }
306}
307
308fn arg_max_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
309where
310 T: PolarsNumericType,
311 for<'b> &'b [T::Native]: ArgMinMax,
312{
313 match ca.is_sorted_flag() {
314 IsSorted::Ascending => ca.last_non_null(),
315 IsSorted::Descending => ca.first_non_null(),
316 IsSorted::Not => {
317 ca.downcast_iter()
318 .fold((None, None, 0), |acc, arr| {
319 if arr.len() == 0 {
320 return acc;
321 }
322 let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {
323 arr.into_iter()
325 .enumerate()
326 .flat_map(|(idx, val)| val.map(|val| (idx, *val)))
327 .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
328 } else {
329 let max_idx: usize = arr.values().as_slice().argmax();
331 Some((max_idx, arr.value(max_idx)))
332 };
333
334 let new_offset: usize = acc.2 + arr.len();
335 match acc {
336 (Some(_), Some(acc_v), offset) => match chunk_max {
337 Some((idx, val)) if acc_v < val => {
338 (Some(idx + offset), Some(val), new_offset)
339 },
340 _ => (acc.0, acc.1, new_offset),
341 },
342 (None, None, offset) => match chunk_max {
343 Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
344 None => (None, None, new_offset),
345 },
346 _ => unreachable!(),
347 }
348 })
349 .0
350 },
351 }
352}
353
354fn arg_min_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
355where
356 for<'a> &'a [T]: ArgMinMax,
357{
358 match is_sorted {
359 IsSorted::Ascending => Some(0),
361 IsSorted::Descending => Some(vals.len() - 1),
363 IsSorted::Not => Some(vals.argmin()), }
365}
366
367fn arg_max_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
368where
369 for<'a> &'a [T]: ArgMinMax,
370{
371 match is_sorted {
372 IsSorted::Ascending => Some(vals.len() - 1),
374 IsSorted::Descending => Some(0),
376 IsSorted::Not => Some(vals.argmax()), }
378}