polars_utils/
binary_search.rs

1use std::cmp::Ordering;
2use std::cmp::Ordering::{Greater, Less};
3
/// Find the index of the first element of `arr` that is greater
/// or equal to `val`.
///
/// Returns `arr.len()` when every element is smaller than `val`
/// (which includes the empty slice, where the result is `0`).
///
/// Assumes that `arr` is sorted.
pub fn find_first_ge_index<T>(arr: &[T], val: T) -> usize
where
    T: Ord,
{
    // `binary_search` may report *any* index of a run of equal elements, so it
    // cannot be used to locate the first one. `partition_point` returns the
    // index of the first element for which the predicate is false, i.e. the
    // first element that is not `< val`.
    arr.partition_point(|x| x < &val)
}
16
/// Find the index of the first element of `arr` that is greater
/// than `val`.
///
/// Returns `arr.len()` when no element is greater than `val`
/// (which includes the empty slice, where the result is `0`).
///
/// Assumes that `arr` is sorted.
pub fn find_first_gt_index<T>(arr: &[T], val: T) -> usize
where
    T: Ord,
{
    // `binary_search` may report *any* index of a run of equal elements, so
    // "match index + 1" is only correct when the match happens to be the last
    // duplicate. `partition_point` skips every element that is `<= val` and
    // lands on the first strictly greater one.
    arr.partition_point(|x| x <= &val)
}
29
/// Exponential search, see <https://en.wikipedia.org/wiki/Exponential_search>.
///
/// Use if you expect matches to be close by. Otherwise use binary search.
pub trait ExponentialSearch<T> {
    /// Search using the comparator `f`, which reports how an element orders
    /// relative to the target.
    ///
    /// The meaning of the `Ok`/`Err` index is defined by the implementor; the
    /// slice implementation in this file forwards the result of
    /// `binary_search_by`, shifted to an index into the whole slice.
    fn exponential_search_by<F>(&self, f: F) -> Result<usize, usize>
    where
        F: FnMut(&T) -> Ordering;

    /// Return the index produced by [`Self::exponential_search_by`] when
    /// elements satisfying `pred` are treated as `Less` and all others as
    /// `Greater`.
    ///
    /// The comparator built here never yields `Equal`; whichever variant the
    /// search returns, the index it carries is passed through unchanged.
    fn partition_point_exponential<P>(&self, mut pred: P) -> usize
    where
        P: FnMut(&T) -> bool,
    {
        let as_ordering = |item: &T| match pred(item) {
            true => Less,
            false => Greater,
        };

        match self.exponential_search_by(as_ordering) {
            Ok(idx) | Err(idx) => idx,
        }
    }
}
45
46impl<T: std::fmt::Debug> ExponentialSearch<T> for &[T] {
47    fn exponential_search_by<F>(&self, mut f: F) -> Result<usize, usize>
48    where
49        F: FnMut(&T) -> Ordering,
50    {
51        if self.is_empty() {
52            return Err(0);
53        }
54
55        let mut bound = 1;
56
57        while bound < self.len() {
58            // SAFETY
59            // Bound is always >=0 and < len.
60            let cmp = f(unsafe { self.get_unchecked(bound) });
61
62            if cmp == Greater {
63                break;
64            }
65            bound *= 2
66        }
67        let end_bound = std::cmp::min(self.len(), bound);
68        // SAFETY:
69        // We checked the end bound and previous bound was within slice as per the `while` condition.
70        let prev_bound = bound / 2;
71
72        let slice = unsafe { self.get_unchecked(prev_bound..end_bound) };
73
74        match slice.binary_search_by(f) {
75            Ok(i) => Ok(i + prev_bound),
76            Err(i) => Err(i + prev_bound),
77        }
78    }
79}
80
#[cfg(test)]
mod test {
    use super::*;

    /// With the predicate `x < 5`, the partition point is the index of the
    /// first element that is not below 5.
    #[test]
    fn test_partition_point() {
        let values = [1, 2, 3, 3, 5, 6, 7];
        let sorted: &[i32] = &values;
        let split = sorted.partition_point_exponential(|&x| x < 5);
        assert_eq!(split, 4);
    }
}
91}