// polars_core/chunked_array/ops/search_sorted.rs

use std::fmt::Debug;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::prelude::*;

// Which insertion point a sorted search reports when the needle equals one or more
// elements. `binary_search_ca` treats `Any` exactly like `Left`.
// (Plain `//` comments on purpose: `///` docs would leak into the derived JSON schema.)
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum SearchSortedSide {
    // No preference; the default.
    #[default]
    Any,
    // Leftmost valid insertion point.
    Left,
    // Rightmost valid insertion point.
    Right,
}

18impl SearchSortedSide {
19    pub fn flip(self) -> Self {
20        use SearchSortedSide::*;
21        match self {
22            Any => Any,
23            Left => Right,
24            Right => Left,
25        }
26    }
27}
28
/// Finds the first position in `[lo, hi)` at which `f` is true, or `hi` if there is none.
///
/// `f` must be monotone over the range: false on some (possibly empty) prefix and true
/// afterwards. `f(hi)` is assumed to be true but is never evaluated, so `hi` may be a
/// one-past-the-end position that must not be probed.
///
/// `midpoint(lo, hi)` must return some `i` with `lo < i < hi` when such a position exists,
/// and `lo` otherwise.
fn lower_bound<I, F, M>(mut lo: I, mut hi: I, midpoint: M, f: F) -> I
where
    I: Eq,
    M: Fn(&I, &I) -> I,
    F: Fn(&I) -> bool,
{
    // An empty range has only one possible answer. Returning early keeps `f` from being
    // evaluated at `hi` below (`f(lo)` with `lo == hi`), e.g. index 0 of an empty chunk.
    if lo == hi {
        return hi;
    }

    loop {
        let m = midpoint(&lo, &hi);
        if m == lo {
            // No strictly interior point is left: the answer is `lo` itself or `hi`.
            return if f(&lo) { lo } else { hi };
        }

        // Narrow the range while keeping the answer inside `[lo, hi]`.
        if f(&m) {
            hi = m;
        } else {
            lo = m;
        }
    }
}

52/// Search through a series of chunks for the first position where f(x) is true,
53/// assuming it is first always false and then always true.
54///
55/// It repeats this for each value in search_values. If the search value is null null_idx is
56/// returned.
57///
58/// Assumes the chunks are non-empty.
59pub fn lower_bound_chunks<'a, T, F>(
60    chunks: &[&'a T::Array],
61    search_values: impl Iterator<Item = Option<T::Physical<'a>>>,
62    null_idx: IdxSize,
63    f: F,
64) -> Vec<IdxSize>
65where
66    T: PolarsDataType,
67    F: Fn(&'a T::Array, usize, &T::Physical<'a>) -> bool,
68{
69    if chunks.is_empty() {
70        return search_values.map(|_| 0).collect();
71    }
72
73    // Fast-path: only a single chunk.
74    if chunks.len() == 1 {
75        let chunk = &chunks[0];
76        return search_values
77            .map(|ov| {
78                if let Some(v) = ov {
79                    lower_bound(0, chunk.len(), |l, r| (l + r) / 2, |m| f(chunk, *m, &v)) as IdxSize
80                } else {
81                    null_idx
82                }
83            })
84            .collect();
85    }
86
87    // Multiple chunks, precompute prefix sum of lengths so we can look up
88    // in O(1) the global position of chunk i.
89    let mut sz = 0;
90    let mut chunk_len_prefix_sum = Vec::with_capacity(chunks.len() + 1);
91    for c in chunks {
92        chunk_len_prefix_sum.push(sz);
93        sz += c.len();
94    }
95    chunk_len_prefix_sum.push(sz);
96
97    // For each search value do a binary search on (chunk_idx, idx_in_chunk) pairs.
98    search_values
99        .map(|ov| {
100            let Some(v) = ov else {
101                return null_idx;
102            };
103            let left = (0, 0);
104            let right = (chunks.len(), 0);
105            let midpoint = |l: &(usize, usize), r: &(usize, usize)| {
106                if l.0 == r.0 {
107                    // Within same chunk.
108                    (l.0, (l.1 + r.1) / 2)
109                } else if l.0 + 1 == r.0 {
110                    // Two adjacent chunks, might have to be l or r.
111                    let left_len = chunks[l.0].len() - l.1;
112
113                    let logical_mid = (left_len + r.1) / 2;
114                    if logical_mid < left_len {
115                        (l.0, l.1 + logical_mid)
116                    } else {
117                        (r.0, logical_mid - left_len)
118                    }
119                } else {
120                    // Has a chunk in between.
121                    ((l.0 + r.0) / 2, 0)
122                }
123            };
124
125            let bound = lower_bound(left, right, midpoint, |m| {
126                f(unsafe { chunks.get_unchecked(m.0) }, m.1, &v)
127            });
128
129            (chunk_len_prefix_sum[bound.0] + bound.1) as IdxSize
130        })
131        .collect()
132}
133
134#[allow(clippy::collapsible_else_if)]
135pub fn binary_search_ca<'a, T>(
136    ca: &'a ChunkedArray<T>,
137    search_values: impl Iterator<Item = Option<T::Physical<'a>>>,
138    side: SearchSortedSide,
139    descending: bool,
140) -> Vec<IdxSize>
141where
142    T: PolarsDataType,
143    T::Physical<'a>: TotalOrd + Debug + Copy,
144{
145    let chunks: Vec<_> = ca.downcast_iter().filter(|c| c.len() > 0).collect();
146    let has_nulls = ca.null_count() > 0;
147    let nulls_last = has_nulls && chunks[0].get(0).is_some();
148    let null_idx = if nulls_last {
149        if side == SearchSortedSide::Right {
150            ca.len()
151        } else {
152            ca.len() - ca.null_count()
153        }
154    } else {
155        if side == SearchSortedSide::Right {
156            ca.null_count()
157        } else {
158            0
159        }
160    } as IdxSize;
161
162    if !descending {
163        if !has_nulls {
164            if side == SearchSortedSide::Right {
165                lower_bound_chunks::<T, _>(
166                    &chunks,
167                    search_values,
168                    null_idx,
169                    |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_gt(sv) },
170                )
171            } else {
172                lower_bound_chunks::<T, _>(
173                    &chunks,
174                    search_values,
175                    null_idx,
176                    |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_ge(sv) },
177                )
178            }
179        } else {
180            if side == SearchSortedSide::Right {
181                lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
182                    if let Some(v) = unsafe { chunk.get_unchecked(i) } {
183                        v.tot_gt(sv)
184                    } else {
185                        nulls_last
186                    }
187                })
188            } else {
189                lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
190                    if let Some(v) = unsafe { chunk.get_unchecked(i) } {
191                        v.tot_ge(sv)
192                    } else {
193                        nulls_last
194                    }
195                })
196            }
197        }
198    } else {
199        if !has_nulls {
200            if side == SearchSortedSide::Right {
201                lower_bound_chunks::<T, _>(
202                    &chunks,
203                    search_values,
204                    null_idx,
205                    |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_lt(sv) },
206                )
207            } else {
208                lower_bound_chunks::<T, _>(
209                    &chunks,
210                    search_values,
211                    null_idx,
212                    |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_le(sv) },
213                )
214            }
215        } else {
216            if side == SearchSortedSide::Right {
217                lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
218                    if let Some(v) = unsafe { chunk.get_unchecked(i) } {
219                        v.tot_lt(sv)
220                    } else {
221                        nulls_last
222                    }
223                })
224            } else {
225                lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
226                    if let Some(v) = unsafe { chunk.get_unchecked(i) } {
227                        v.tot_le(sv)
228                    } else {
229                        nulls_last
230                    }
231                })
232            }
233        }
234    }
235}