1use arrow::array::Array;
2use polars_core::chunked_array::ops::float_sorted_arg_max::{
3 float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
4};
5use polars_core::series::IsSorted;
6use polars_core::with_match_categorical_physical_type;
7use polars_utils::arg_min_max::ArgMinMax;
8use polars_utils::min_max::{MaxIgnoreNan, MinIgnoreNan, MinMaxPolicy};
9
/// Returns the position of the smallest non-null item yielded by `iter`.
///
/// `None` items are skipped but still advance the position counter. If several
/// items are equally minimal, the position of the first one is returned.
/// Yields `None` when `iter` produces no `Some` item at all.
pub fn arg_min_opt_iter<T, I>(iter: I) -> Option<usize>
where
    I: IntoIterator<Item = Option<T>>,
    T: Ord,
{
    iter.into_iter()
        .enumerate()
        .filter_map(|(pos, item)| item.map(|value| (pos, value)))
        // Replace the running best only when it is strictly greater than the
        // candidate, so the earliest of equal minima survives.
        .reduce(|best, candidate| {
            if Ord::cmp(&best.1, &candidate.1).is_gt() {
                candidate
            } else {
                best
            }
        })
        .map(|(pos, _)| pos)
}
21
/// Returns the position of the largest non-null item yielded by `iter`.
///
/// `None` items are skipped but still advance the position counter. If several
/// items are equally maximal, the position of the **first** one is returned,
/// mirroring the tie-breaking of [`arg_min_opt_iter`]. (`Iterator::max_by`
/// would return the last of equal maxima, which is why it is not used here.)
/// Yields `None` when `iter` produces no `Some` item at all.
pub fn arg_max_opt_iter<T, I>(iter: I) -> Option<usize>
where
    I: IntoIterator<Item = Option<T>>,
    T: Ord,
{
    iter.into_iter()
        .enumerate()
        .filter_map(|(pos, item)| item.map(|value| (pos, value)))
        // Replace the running best only on a strictly greater candidate so the
        // earliest of equal maxima is kept.
        .reduce(|best, candidate| {
            if Ord::cmp(&candidate.1, &best.1).is_gt() {
                candidate
            } else {
                best
            }
        })
        .map(|(pos, _)| pos)
}
33
34use super::*;
35
/// Arg-min / arg-max aggregations: locate the position of an extreme value.
pub trait ArgAgg {
    /// Returns the index of the minimal value, or `None` if there is no such index.
    fn arg_min(&self) -> Option<usize>;
    /// Returns the index of the maximal value, or `None` if there is no such index.
    fn arg_max(&self) -> Option<usize>;
}
43
/// Matches `$key_type` against the physical numeric `DataType` variants and expands
/// `$body` for the matching one, with the identifier named in the `|$T|` header
/// replaced by the corresponding Polars type marker (`Int32Type`, `Float64Type`, ...).
///
/// Arms annotated with `#[cfg(feature = ...)]` only exist when that feature is enabled.
///
/// # Panics
///
/// Panics with `not implemented for dtype ...` when `$key_type` is none of the
/// dtypes listed below.
macro_rules! with_match_physical_numeric_polars_type {(
    $key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
    // Helper macro defined per invocation: it re-declares the caller's `$T`
    // metavariable (the captured `$_` token supplies the `$`) so that `$body`
    // is expanded with the concrete type marker passed in each arm.
    macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
    use DataType::*;
    match $key_type {
        #[cfg(feature = "dtype-i8")]
        Int8 => __with_ty__! { Int8Type },
        #[cfg(feature = "dtype-i16")]
        Int16 => __with_ty__! { Int16Type },
        Int32 => __with_ty__! { Int32Type },
        Int64 => __with_ty__! { Int64Type },
        #[cfg(feature = "dtype-i128")]
        Int128 => __with_ty__! { Int128Type },
        #[cfg(feature = "dtype-u8")]
        UInt8 => __with_ty__! { UInt8Type },
        #[cfg(feature = "dtype-u16")]
        UInt16 => __with_ty__! { UInt16Type },
        UInt32 => __with_ty__! { UInt32Type },
        UInt64 => __with_ty__! { UInt64Type },
        #[cfg(feature = "dtype-u128")]
        UInt128 => __with_ty__! { UInt128Type },
        #[cfg(feature = "dtype-f16")]
        Float16 => __with_ty__! { Float16Type },
        Float32 => __with_ty__! { Float32Type },
        Float64 => __with_ty__! { Float64Type },
        // Anything else (including dtypes whose feature is disabled) is unsupported.
        dt => panic!("not implemented for dtype {:?}", dt),
    }
})}
73
74impl ArgAgg for Series {
75 fn arg_min(&self) -> Option<usize> {
76 use DataType::*;
77 let phys_s = self.to_physical_repr();
78 match self.dtype() {
79 #[cfg(feature = "dtype-categorical")]
80 Categorical(cats, _) => {
81 with_match_categorical_physical_type!(cats.physical(), |$C| {
82 arg_min_cat(self.cat::<$C>().unwrap())
83 })
84 },
85 #[cfg(feature = "dtype-categorical")]
86 Enum(_, _) => phys_s.arg_min(),
87 #[cfg(feature = "dtype-decimal")]
88 Decimal(_, _) => phys_s.arg_min(),
89 Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_min(),
90 String => arg_min_str(self.str().unwrap()),
91 Binary => arg_min_binary(self.binary().unwrap()),
92 Boolean => arg_min_bool(self.bool().unwrap()),
93 dt if dt.is_primitive_numeric() => {
94 with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
95 let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
96 arg_min_numeric(ca)
97 })
98 },
99 _ => None,
100 }
101 }
102
103 fn arg_max(&self) -> Option<usize> {
104 use DataType::*;
105 let phys_s = self.to_physical_repr();
106 match self.dtype() {
107 #[cfg(feature = "dtype-categorical")]
108 Categorical(cats, _) => {
109 with_match_categorical_physical_type!(cats.physical(), |$C| {
110 arg_max_cat(self.cat::<$C>().unwrap())
111 })
112 },
113 #[cfg(feature = "dtype-categorical")]
114 Enum(_, _) => phys_s.arg_max(),
115 #[cfg(feature = "dtype-decimal")]
116 Decimal(_, _) => phys_s.arg_max(),
117 Date | Datetime(_, _) | Duration(_) | Time => phys_s.arg_max(),
118 String => arg_max_str(self.str().unwrap()),
119 Binary => arg_max_binary(self.binary().unwrap()),
120 Boolean => arg_max_bool(self.bool().unwrap()),
121 dt if dt.is_primitive_numeric() => {
122 with_match_physical_numeric_polars_type!(phys_s.dtype(), |$T| {
123 let ca: &ChunkedArray<$T> = phys_s.as_ref().as_ref().as_ref();
124 arg_max_numeric(ca)
125 })
126 },
127 _ => None,
128 }
129 }
130}
131
132pub fn arg_min_numeric<T>(ca: &ChunkedArray<T>) -> Option<usize>
133where
134 T: PolarsNumericType,
135 for<'b> &'b [T::Native]: ArgMinMax,
136{
137 if ca.null_count() == ca.len() {
138 None
139 } else if let Ok(vals) = ca.cont_slice() {
140 arg_min_numeric_slice(vals, ca.is_sorted_flag())
141 } else {
142 arg_min_numeric_chunked(ca)
143 }
144}
145
146pub fn arg_max_numeric<T>(ca: &ChunkedArray<T>) -> Option<usize>
147where
148 T: PolarsNumericType,
149 for<'b> &'b [T::Native]: ArgMinMax,
150{
151 if ca.null_count() == ca.len() {
152 None
153 } else if T::get_static_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) {
154 arg_max_float_sorted(ca)
155 } else if let Ok(vals) = ca.cont_slice() {
156 arg_max_numeric_slice(vals, ca.is_sorted_flag())
157 } else {
158 arg_max_numeric_chunked(ca)
159 }
160}
161
162fn arg_max_float_sorted<T>(ca: &ChunkedArray<T>) -> Option<usize>
165where
166 T: PolarsNumericType,
167{
168 let out = match ca.is_sorted_flag() {
169 IsSorted::Ascending => float_arg_max_sorted_ascending(ca),
170 IsSorted::Descending => float_arg_max_sorted_descending(ca),
171 _ => unreachable!(),
172 };
173 Some(out)
174}
175
/// Index of the minimal entry of a categorical array, or `None` if all entries are null.
///
/// Entries are compared through the items yielded by `ca.iter_str()`, with nulls
/// skipped by `arg_min_opt_iter`.
#[cfg(feature = "dtype-categorical")]
pub fn arg_min_cat<T: PolarsCategoricalType>(ca: &CategoricalChunked<T>) -> Option<usize> {
    let all_null = ca.null_count() == ca.len();
    if all_null {
        None
    } else {
        arg_min_opt_iter(ca.iter_str())
    }
}
183
/// Index of the maximal entry of a categorical array, or `None` if all entries are null.
///
/// Entries are compared through the items yielded by `ca.iter_str()`, with nulls
/// skipped by `arg_max_opt_iter`.
#[cfg(feature = "dtype-categorical")]
pub fn arg_max_cat<T: PolarsCategoricalType>(ca: &CategoricalChunked<T>) -> Option<usize> {
    let all_null = ca.null_count() == ca.len();
    if all_null {
        None
    } else {
        arg_max_opt_iter(ca.iter_str())
    }
}
191
192pub fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
193 ca.first_false_idx().or_else(|| ca.first_true_idx())
194}
195
196pub fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
197 ca.first_true_idx().or_else(|| ca.first_false_idx())
198}
199
/// Index of the minimal value of a string array.
///
/// Delegates to `arg_min_physical_generic`; returns `None` when all entries are null.
pub fn arg_min_str(ca: &StringChunked) -> Option<usize> {
    arg_min_physical_generic(ca)
}
203
/// Index of the maximal value of a string array.
///
/// Delegates to `arg_max_physical_generic`; returns `None` when all entries are null.
pub fn arg_max_str(ca: &StringChunked) -> Option<usize> {
    arg_max_physical_generic(ca)
}
207
/// Index of the minimal value of a binary array.
///
/// Delegates to `arg_min_physical_generic`; returns `None` when all entries are null.
pub fn arg_min_binary(ca: &BinaryChunked) -> Option<usize> {
    arg_min_physical_generic(ca)
}
211
/// Index of the maximal value of a binary array.
///
/// Delegates to `arg_max_physical_generic`; returns `None` when all entries are null.
pub fn arg_max_binary(ca: &BinaryChunked) -> Option<usize> {
    arg_max_physical_generic(ca)
}
215
216fn arg_min_physical_generic<T>(ca: &ChunkedArray<T>) -> Option<usize>
217where
218 T: PolarsDataType,
219 for<'a> T::Physical<'a>: Ord,
220{
221 if ca.null_count() == ca.len() {
222 return None;
223 }
224 match ca.is_sorted_flag() {
225 IsSorted::Ascending => ca.first_non_null(),
226 IsSorted::Descending => ca.last_non_null(),
227 IsSorted::Not => arg_min_opt_iter(ca.iter()),
228 }
229}
230
231fn arg_max_physical_generic<T>(ca: &ChunkedArray<T>) -> Option<usize>
232where
233 T: PolarsDataType,
234 for<'a> T::Physical<'a>: Ord,
235{
236 if ca.null_count() == ca.len() {
237 return None;
238 }
239 match ca.is_sorted_flag() {
240 IsSorted::Ascending => ca.last_non_null(),
241 IsSorted::Descending => ca.first_non_null(),
242 IsSorted::Not => arg_max_opt_iter(ca.iter()),
243 }
244}
245
246fn arg_min_numeric_chunked<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
247where
248 T: PolarsNumericType,
249 for<'b> &'b [T::Native]: ArgMinMax,
250{
251 match ca.is_sorted_flag() {
252 IsSorted::Ascending => ca.first_non_null(),
253 IsSorted::Descending => ca.last_non_null(),
254 IsSorted::Not => {
255 let mut chunk_start_offset = 0;
256 let mut min_idx: Option<usize> = None;
257 let mut min_val: Option<T::Native> = None;
258 for arr in ca.downcast_iter() {
259 if arr.len() == arr.null_count() {
260 chunk_start_offset += arr.len();
261 continue;
262 }
263
264 let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {
265 arr.into_iter()
266 .enumerate()
267 .flat_map(|(idx, val)| Some((idx, *(val?))))
268 .reduce(|acc, (idx, val)| {
269 if MinIgnoreNan::is_better(&val, &acc.1) {
270 (idx, val)
271 } else {
272 acc
273 }
274 })
275 } else {
276 let min_idx: usize = arr.values().as_slice().argmin();
278 Some((min_idx, arr.value(min_idx)))
279 };
280
281 if let Some((chunk_min_idx, chunk_min_val)) = chunk_min {
282 if min_val.is_none()
283 || MinIgnoreNan::is_better(&chunk_min_val, &min_val.unwrap())
284 {
285 min_val = Some(chunk_min_val);
286 min_idx = Some(chunk_start_offset + chunk_min_idx);
287 }
288 }
289 chunk_start_offset += arr.len();
290 }
291 min_idx
292 },
293 }
294}
295
296fn arg_max_numeric_chunked<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
297where
298 T: PolarsNumericType,
299 for<'b> &'b [T::Native]: ArgMinMax,
300{
301 match ca.is_sorted_flag() {
302 IsSorted::Ascending => ca.last_non_null(),
303 IsSorted::Descending => ca.first_non_null(),
304 IsSorted::Not => {
305 let mut chunk_start_offset = 0;
306 let mut max_idx: Option<usize> = None;
307 let mut max_val: Option<T::Native> = None;
308 for arr in ca.downcast_iter() {
309 if arr.len() == arr.null_count() {
310 chunk_start_offset += arr.len();
311 continue;
312 }
313
314 let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {
315 arr.into_iter()
316 .enumerate()
317 .flat_map(|(idx, val)| Some((idx, *(val?))))
318 .reduce(|acc, (idx, val)| {
319 if MaxIgnoreNan::is_better(&val, &acc.1) {
320 (idx, val)
321 } else {
322 acc
323 }
324 })
325 } else {
326 let max_idx: usize = arr.values().as_slice().argmax();
328 Some((max_idx, arr.value(max_idx)))
329 };
330
331 if let Some((chunk_max_idx, chunk_max_val)) = chunk_max {
332 if max_val.is_none()
333 || MaxIgnoreNan::is_better(&chunk_max_val, &max_val.unwrap())
334 {
335 max_val = Some(chunk_max_val);
336 max_idx = Some(chunk_start_offset + chunk_max_idx);
337 }
338 }
339 chunk_start_offset += arr.len();
340 }
341 max_idx
342 },
343 }
344}
345
346fn arg_min_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
347where
348 for<'a> &'a [T]: ArgMinMax,
349{
350 match is_sorted {
351 IsSorted::Ascending => Some(0),
353 IsSorted::Descending => Some(vals.len() - 1),
355 IsSorted::Not => Some(vals.argmin()), }
357}
358
359fn arg_max_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
360where
361 for<'a> &'a [T]: ArgMinMax,
362{
363 match is_sorted {
364 IsSorted::Ascending => Some(vals.len() - 1),
366 IsSorted::Descending => Some(0),
368 IsSorted::Not => Some(vals.argmax()), }
370}