use argminmax::ArgMinMax;
use arrow::array::Array;
use arrow::legacy::bit_util::*;
use polars_core::chunked_array::ops::float_sorted_arg_max::{
float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
};
use polars_core::series::IsSorted;
use polars_core::with_match_physical_numeric_polars_type;
use super::*;
pub trait ArgAgg {
fn arg_min(&self) -> Option<usize>;
fn arg_max(&self) -> Option<usize>;
}
macro_rules! with_match_physical_numeric_polars_type {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use DataType::*;
match $key_type {
#[cfg(feature = "dtype-i8")]
Int8 => __with_ty__! { Int8Type },
#[cfg(feature = "dtype-i16")]
Int16 => __with_ty__! { Int16Type },
Int32 => __with_ty__! { Int32Type },
Int64 => __with_ty__! { Int64Type },
#[cfg(feature = "dtype-u8")]
UInt8 => __with_ty__! { UInt8Type },
#[cfg(feature = "dtype-u16")]
UInt16 => __with_ty__! { UInt16Type },
UInt32 => __with_ty__! { UInt32Type },
UInt64 => __with_ty__! { UInt64Type },
Float32 => __with_ty__! { Float32Type },
Float64 => __with_ty__! { Float64Type },
dt => panic!("not implemented for dtype {:?}", dt),
}
})}
impl ArgAgg for Series {
fn arg_min(&self) -> Option<usize> {
use DataType::*;
let s = self.to_physical_repr();
match self.dtype() {
#[cfg(feature = "dtype-categorical")]
Categorical(_, _) => {
let ca = self.categorical().unwrap();
if ca.null_count() == ca.len() {
return None;
}
if ca.uses_lexical_ordering() {
ca.iter_str()
.enumerate()
.flat_map(|(idx, val)| val.map(|val| (idx, val)))
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
.map(|tpl| tpl.0)
} else {
let ca = s.u32().unwrap();
arg_min_numeric_dispatch(ca)
}
},
String => {
let ca = self.str().unwrap();
arg_min_str(ca)
},
Boolean => {
let ca = self.bool().unwrap();
arg_min_bool(ca)
},
Date => {
let ca = s.i32().unwrap();
arg_min_numeric_dispatch(ca)
},
Datetime(_, _) | Duration(_) | Time => {
let ca = s.i64().unwrap();
arg_min_numeric_dispatch(ca)
},
dt if dt.is_numeric() => {
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
arg_min_numeric_dispatch(ca)
})
},
_ => None,
}
}
fn arg_max(&self) -> Option<usize> {
use DataType::*;
let s = self.to_physical_repr();
match self.dtype() {
#[cfg(feature = "dtype-categorical")]
Categorical(_, _) => {
let ca = self.categorical().unwrap();
if ca.null_count() == ca.len() {
return None;
}
if ca.uses_lexical_ordering() {
ca.iter_str()
.enumerate()
.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
.map(|tpl| tpl.0)
} else {
let ca_phys = s.u32().unwrap();
arg_max_numeric_dispatch(ca_phys)
}
},
String => {
let ca = self.str().unwrap();
arg_max_str(ca)
},
Boolean => {
let ca = self.bool().unwrap();
arg_max_bool(ca)
},
Date => {
let ca = s.i32().unwrap();
arg_max_numeric_dispatch(ca)
},
Datetime(_, _) | Duration(_) | Time => {
let ca = s.i64().unwrap();
arg_max_numeric_dispatch(ca)
},
dt if dt.is_numeric() => {
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
arg_max_numeric_dispatch(ca)
})
},
_ => None,
}
}
}
fn arg_max_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
where
T: PolarsNumericType,
for<'b> &'b [T::Native]: ArgMinMax,
{
if ca.null_count() == ca.len() {
None
} else if T::get_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) {
arg_max_float_sorted(ca)
} else if let Ok(vals) = ca.cont_slice() {
arg_max_numeric_slice(vals, ca.is_sorted_flag())
} else {
arg_max_numeric(ca)
}
}
fn arg_min_numeric_dispatch<T>(ca: &ChunkedArray<T>) -> Option<usize>
where
T: PolarsNumericType,
for<'b> &'b [T::Native]: ArgMinMax,
{
if ca.null_count() == ca.len() {
None
} else if let Ok(vals) = ca.cont_slice() {
arg_min_numeric_slice(vals, ca.is_sorted_flag())
} else {
arg_min_numeric(ca)
}
}
pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option<usize> {
if ca.null_count() == ca.len() {
None
}
else if ca.null_count() == 0 && ca.chunks().len() == 1 {
let arr = ca.downcast_iter().next().unwrap();
let mask = arr.values();
Some(first_set_bit(mask))
} else {
let mut first_false_idx: Option<usize> = None;
ca.iter()
.enumerate()
.find_map(|(idx, val)| match val {
Some(true) => Some(idx),
Some(false) if first_false_idx.is_none() => {
first_false_idx = Some(idx);
None
},
_ => None,
})
.or(first_false_idx)
}
}
fn arg_max_float_sorted<T>(ca: &ChunkedArray<T>) -> Option<usize>
where
T: PolarsNumericType,
{
let out = match ca.is_sorted_flag() {
IsSorted::Ascending => float_arg_max_sorted_ascending(ca),
IsSorted::Descending => float_arg_max_sorted_descending(ca),
_ => unreachable!(),
};
Some(out)
}
fn arg_min_bool(ca: &BooleanChunked) -> Option<usize> {
if ca.null_count() == ca.len() {
None
} else if ca.null_count() == 0 && ca.chunks().len() == 1 {
let arr = ca.downcast_iter().next().unwrap();
let mask = arr.values();
Some(first_unset_bit(mask))
} else {
let mut first_true_idx: Option<usize> = None;
ca.iter()
.enumerate()
.find_map(|(idx, val)| match val {
Some(false) => Some(idx),
Some(true) if first_true_idx.is_none() => {
first_true_idx = Some(idx);
None
},
_ => None,
})
.or(first_true_idx)
}
}
fn arg_min_str(ca: &StringChunked) -> Option<usize> {
if ca.null_count() == ca.len() {
return None;
}
match ca.is_sorted_flag() {
IsSorted::Ascending => ca.first_non_null(),
IsSorted::Descending => ca.last_non_null(),
IsSorted::Not => ca
.iter()
.enumerate()
.flat_map(|(idx, val)| val.map(|val| (idx, val)))
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
.map(|tpl| tpl.0),
}
}
fn arg_max_str(ca: &StringChunked) -> Option<usize> {
if ca.null_count() == ca.len() {
return None;
}
match ca.is_sorted_flag() {
IsSorted::Ascending => ca.last_non_null(),
IsSorted::Descending => ca.first_non_null(),
IsSorted::Not => ca
.iter()
.enumerate()
.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
.map(|tpl| tpl.0),
}
}
fn arg_min_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
where
T: PolarsNumericType,
for<'b> &'b [T::Native]: ArgMinMax,
{
match ca.is_sorted_flag() {
IsSorted::Ascending => ca.first_non_null(),
IsSorted::Descending => ca.last_non_null(),
IsSorted::Not => {
ca.downcast_iter()
.fold((None, None, 0), |acc, arr| {
if arr.len() == 0 {
return acc;
}
let chunk_min: Option<(usize, T::Native)> = if arr.null_count() > 0 {
arr.into_iter()
.enumerate()
.flat_map(|(idx, val)| val.map(|val| (idx, *val)))
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
} else {
let min_idx: usize = arr.values().as_slice().argmin();
Some((min_idx, arr.value(min_idx)))
};
let new_offset: usize = acc.2 + arr.len();
match acc {
(Some(_), Some(acc_v), offset) => match chunk_min {
Some((idx, val)) if val < acc_v => {
(Some(idx + offset), Some(val), new_offset)
},
_ => (acc.0, acc.1, new_offset),
},
(None, None, offset) => match chunk_min {
Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
None => (None, None, new_offset),
},
_ => unreachable!(),
}
})
.0
},
}
}
fn arg_max_numeric<'a, T>(ca: &'a ChunkedArray<T>) -> Option<usize>
where
T: PolarsNumericType,
for<'b> &'b [T::Native]: ArgMinMax,
{
match ca.is_sorted_flag() {
IsSorted::Ascending => ca.last_non_null(),
IsSorted::Descending => ca.first_non_null(),
IsSorted::Not => {
ca.downcast_iter()
.fold((None, None, 0), |acc, arr| {
if arr.len() == 0 {
return acc;
}
let chunk_max: Option<(usize, T::Native)> = if arr.null_count() > 0 {
arr.into_iter()
.enumerate()
.flat_map(|(idx, val)| val.map(|val| (idx, *val)))
.reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc })
} else {
let max_idx: usize = arr.values().as_slice().argmax();
Some((max_idx, arr.value(max_idx)))
};
let new_offset: usize = acc.2 + arr.len();
match acc {
(Some(_), Some(acc_v), offset) => match chunk_max {
Some((idx, val)) if acc_v < val => {
(Some(idx + offset), Some(val), new_offset)
},
_ => (acc.0, acc.1, new_offset),
},
(None, None, offset) => match chunk_max {
Some((idx, val)) => (Some(idx + offset), Some(val), new_offset),
None => (None, None, new_offset),
},
_ => unreachable!(),
}
})
.0
},
}
}
fn arg_min_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
where
for<'a> &'a [T]: ArgMinMax,
{
match is_sorted {
IsSorted::Ascending => Some(0),
IsSorted::Descending => Some(vals.len() - 1),
IsSorted::Not => Some(vals.argmin()), }
}
fn arg_max_numeric_slice<T>(vals: &[T], is_sorted: IsSorted) -> Option<usize>
where
for<'a> &'a [T]: ArgMinMax,
{
match is_sorted {
IsSorted::Ascending => Some(vals.len() - 1),
IsSorted::Descending => Some(0),
IsSorted::Not => Some(vals.argmax()), }
}