polars_core/chunked_array/ops/
search_sorted.rs

1use std::fmt::Debug;
2
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5
6use crate::prelude::*;
7
/// Selects which insertion index a sorted search reports when the searched
/// value has equal elements in the sorted data.
///
/// `binary_search_ca` in this file only distinguishes `Right` from the other
/// two variants: `Any` and `Left` take the same code path there.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum SearchSortedSide {
    /// No preference among equal elements (the default variant).
    #[default]
    Any,
    /// Report the position in front of any elements equal to the search value.
    Left,
    /// Report the position just past all elements equal to the search value.
    Right,
}
17
/// Binary-searches `[lo, hi)` for the first point at which `f` holds.
///
/// `f` must be monotone over the range: false for a (possibly empty) prefix
/// and true from then on. `hi` itself is never probed; it is treated as if
/// `f(hi)` were true and is the result when no point in the range satisfies `f`.
///
/// `midpoint(lo, hi)` must yield some point strictly between `lo` and `hi`
/// when such a point exists, and `lo` otherwise. Receiving `lo` back is the
/// signal that the range cannot be narrowed any further.
fn lower_bound<I, F, M>(mut lo: I, mut hi: I, midpoint: M, f: F) -> I
where
    I: PartialEq + Eq,
    M: Fn(&I, &I) -> I,
    F: Fn(&I) -> bool,
{
    let mut probe = midpoint(&lo, &hi);
    while probe != lo {
        // Shrink whichever side of the range the probe rules out.
        if f(&probe) {
            hi = probe;
        } else {
            lo = probe;
        }
        probe = midpoint(&lo, &hi);
    }

    // Nothing lies strictly between `lo` and `hi` any more, so the answer is
    // one of the two endpoints.
    if f(&lo) { lo } else { hi }
}
40
41/// Search through a series of chunks for the first position where f(x) is true,
42/// assuming it is first always false and then always true.
43///
44/// It repeats this for each value in search_values. If the search value is null null_idx is
45/// returned.
46///
47/// Assumes the chunks are non-empty.
48pub fn lower_bound_chunks<'a, T, F>(
49    chunks: &[&'a T::Array],
50    search_values: impl Iterator<Item = Option<T::Physical<'a>>>,
51    null_idx: IdxSize,
52    f: F,
53) -> Vec<IdxSize>
54where
55    T: PolarsDataType,
56    F: Fn(&'a T::Array, usize, &T::Physical<'a>) -> bool,
57{
58    if chunks.is_empty() {
59        return search_values.map(|_| 0).collect();
60    }
61
62    // Fast-path: only a single chunk.
63    if chunks.len() == 1 {
64        let chunk = &chunks[0];
65        return search_values
66            .map(|ov| {
67                if let Some(v) = ov {
68                    lower_bound(0, chunk.len(), |l, r| (l + r) / 2, |m| f(chunk, *m, &v)) as IdxSize
69                } else {
70                    null_idx
71                }
72            })
73            .collect();
74    }
75
76    // Multiple chunks, precompute prefix sum of lengths so we can look up
77    // in O(1) the global position of chunk i.
78    let mut sz = 0;
79    let mut chunk_len_prefix_sum = Vec::with_capacity(chunks.len() + 1);
80    for c in chunks {
81        chunk_len_prefix_sum.push(sz);
82        sz += c.len();
83    }
84    chunk_len_prefix_sum.push(sz);
85
86    // For each search value do a binary search on (chunk_idx, idx_in_chunk) pairs.
87    search_values
88        .map(|ov| {
89            let Some(v) = ov else {
90                return null_idx;
91            };
92            let left = (0, 0);
93            let right = (chunks.len(), 0);
94            let midpoint = |l: &(usize, usize), r: &(usize, usize)| {
95                if l.0 == r.0 {
96                    // Within same chunk.
97                    (l.0, (l.1 + r.1) / 2)
98                } else if l.0 + 1 == r.0 {
99                    // Two adjacent chunks, might have to be l or r.
100                    let left_len = chunks[l.0].len() - l.1;
101
102                    let logical_mid = (left_len + r.1) / 2;
103                    if logical_mid < left_len {
104                        (l.0, l.1 + logical_mid)
105                    } else {
106                        (r.0, logical_mid - left_len)
107                    }
108                } else {
109                    // Has a chunk in between.
110                    ((l.0 + r.0) / 2, 0)
111                }
112            };
113
114            let bound = lower_bound(left, right, midpoint, |m| {
115                f(unsafe { chunks.get_unchecked(m.0) }, m.1, &v)
116            });
117
118            (chunk_len_prefix_sum[bound.0] + bound.1) as IdxSize
119        })
120        .collect()
121}
122
123#[allow(clippy::collapsible_else_if)]
124pub fn binary_search_ca<'a, T>(
125    ca: &'a ChunkedArray<T>,
126    search_values: impl Iterator<Item = Option<T::Physical<'a>>>,
127    side: SearchSortedSide,
128    descending: bool,
129) -> Vec<IdxSize>
130where
131    T: PolarsDataType,
132    T::Physical<'a>: TotalOrd + Debug + Copy,
133{
134    let chunks: Vec<_> = ca.downcast_iter().filter(|c| c.len() > 0).collect();
135    let has_nulls = ca.null_count() > 0;
136    let nulls_last = has_nulls && chunks[0].get(0).is_some();
137    let null_idx = if nulls_last {
138        if side == SearchSortedSide::Right {
139            ca.len()
140        } else {
141            ca.len() - ca.null_count()
142        }
143    } else {
144        if side == SearchSortedSide::Right {
145            ca.null_count()
146        } else {
147            0
148        }
149    } as IdxSize;
150
151    if !descending {
152        if !has_nulls {
153            if side == SearchSortedSide::Right {
154                lower_bound_chunks::<T, _>(
155                    &chunks,
156                    search_values,
157                    null_idx,
158                    |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_gt(sv) },
159                )
160            } else {
161                lower_bound_chunks::<T, _>(
162                    &chunks,
163                    search_values,
164                    null_idx,
165                    |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_ge(sv) },
166                )
167            }
168        } else {
169            if side == SearchSortedSide::Right {
170                lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
171                    if let Some(v) = unsafe { chunk.get_unchecked(i) } {
172                        v.tot_gt(sv)
173                    } else {
174                        nulls_last
175                    }
176                })
177            } else {
178                lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
179                    if let Some(v) = unsafe { chunk.get_unchecked(i) } {
180                        v.tot_ge(sv)
181                    } else {
182                        nulls_last
183                    }
184                })
185            }
186        }
187    } else {
188        if !has_nulls {
189            if side == SearchSortedSide::Right {
190                lower_bound_chunks::<T, _>(
191                    &chunks,
192                    search_values,
193                    null_idx,
194                    |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_lt(sv) },
195                )
196            } else {
197                lower_bound_chunks::<T, _>(
198                    &chunks,
199                    search_values,
200                    null_idx,
201                    |chunk, i, sv| unsafe { chunk.value_unchecked(i).tot_le(sv) },
202                )
203            }
204        } else {
205            if side == SearchSortedSide::Right {
206                lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
207                    if let Some(v) = unsafe { chunk.get_unchecked(i) } {
208                        v.tot_lt(sv)
209                    } else {
210                        nulls_last
211                    }
212                })
213            } else {
214                lower_bound_chunks::<T, _>(&chunks, search_values, null_idx, |chunk, i, sv| {
215                    if let Some(v) = unsafe { chunk.get_unchecked(i) } {
216                        v.tot_le(sv)
217                    } else {
218                        nulls_last
219                    }
220                })
221            }
222        }
223    }
224}