Skip to main content

polars_core/chunked_array/ops/
float_sorted_arg_max.rs

1use num_traits::Float;
2
3use self::search_sorted::{SearchSortedSide, binary_search_ca};
4use crate::prelude::*;
5
6impl<T> ChunkedArray<T>
7where
8    T: PolarsFloatType,
9    T::Native: Float,
10{
11    fn float_arg_max_sorted_ascending(&self) -> usize {
12        let ca = self;
13        debug_assert!(ca.is_sorted_ascending_flag());
14
15        let maybe_max_idx = ca.last_non_null().unwrap();
16        let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) };
17        if !maybe_max.is_nan() {
18            return maybe_max_idx;
19        }
20
21        let search_val = std::iter::once(Some(T::Native::nan()));
22        let idx = binary_search_ca(ca, search_val, SearchSortedSide::Left, false)[0] as usize;
23        idx.saturating_sub(1)
24    }
25
26    fn float_arg_max_sorted_descending(&self) -> usize {
27        let ca = self;
28        debug_assert!(ca.is_sorted_descending_flag());
29
30        let maybe_max_idx = ca.first_non_null().unwrap();
31
32        let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) };
33        if !maybe_max.is_nan() {
34            return maybe_max_idx;
35        }
36
37        let search_val = std::iter::once(Some(T::Native::nan()));
38        let idx = binary_search_ca(ca, search_val, SearchSortedSide::Right, true)[0] as usize;
39        if idx == ca.len() { idx - 1 } else { idx }
40    }
41}
42
43/// # Safety
44/// `ca` has a float dtype, has at least 1 non-null value and is sorted ascending
45pub fn float_arg_max_sorted_ascending<T>(ca: &ChunkedArray<T>) -> usize
46where
47    T: PolarsNumericType,
48{
49    with_match_physical_float_polars_type!(ca.dtype(), |$T| {
50        let ca: &ChunkedArray<$T> = unsafe {
51            &*(ca as *const ChunkedArray<T> as *const ChunkedArray<$T>)
52        };
53        ca.float_arg_max_sorted_ascending()
54    })
55}
56
57/// # Safety
58/// `ca` has a float dtype, has at least 1 non-null value and is sorted descending
59pub fn float_arg_max_sorted_descending<T>(ca: &ChunkedArray<T>) -> usize
60where
61    T: PolarsNumericType,
62{
63    with_match_physical_float_polars_type!(ca.dtype(), |$T| {
64        let ca: &ChunkedArray<$T> = unsafe {
65            &*(ca as *const ChunkedArray<T> as *const ChunkedArray<$T>)
66        };
67        ca.float_arg_max_sorted_descending()
68    })
69}