polars_core/chunked_array/comparison/scalar.rs

1use super::*;
2
/// An ordering comparison `x <op> rhs` against a fixed right-hand side.
///
/// `bitonic_mask` evaluates these with the total-order methods
/// (`tot_lt`, `tot_le`, `tot_gt`, `tot_ge`) while binary-searching a chunk.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CmpOp {
    /// `x < rhs`
    Lt,
    /// `x <= rhs`
    Le,
    /// `x > rhs`
    Gt,
    /// `x >= rhs`
    Ge,
}
10
11// Given two monotonic functions f_a and f_d where f_a is ascending
12// (f_a(x[0]) <= f_a(x[1]) <= .. <= f_a(x[n-1])) and f_d is descending
13// (f_d(x[0]) >= f_d(x[1]) >= .. >= f_d(x[n-1])),
14// outputs a mask where both are true.
15//
16// If a function is not given it is always assumed to be true. If invert is
17// true the output mask is inverted.
18fn bitonic_mask<T: PolarsNumericType>(
19    ca: &ChunkedArray<T>,
20    f_a: Option<CmpOp>,
21    f_d: Option<CmpOp>,
22    rhs: &T::Native,
23    invert: bool,
24) -> BooleanChunked {
25    fn apply<T: PolarsNumericType>(op: CmpOp, x: T::Native, rhs: &T::Native) -> bool {
26        match op {
27            CmpOp::Lt => x.tot_lt(rhs),
28            CmpOp::Le => x.tot_le(rhs),
29            CmpOp::Gt => x.tot_gt(rhs),
30            CmpOp::Ge => x.tot_ge(rhs),
31        }
32    }
33    let mut output_order: Option<IsSorted> = None;
34    let mut last_value: Option<bool> = None;
35    let mut logical_extend = |len: usize, val: bool| {
36        if len != 0 {
37            if let Some(last_value) = last_value {
38                output_order = match (last_value, val, output_order) {
39                    (false, true, None) => Some(IsSorted::Ascending),
40                    (false, true, _) => Some(IsSorted::Not),
41                    (true, false, None) => Some(IsSorted::Descending),
42                    (true, false, _) => Some(IsSorted::Not),
43                    _ => output_order,
44                };
45            }
46            last_value = Some(val);
47        }
48    };
49
50    let chunks = ca.downcast_iter().map(|arr| {
51        let values = arr.values();
52        let true_range_start = if let Some(f_a) = f_a {
53            values.partition_point(|x| !apply::<T>(f_a, *x, rhs))
54        } else {
55            0
56        };
57        let true_range_end = if let Some(f_d) = f_d {
58            true_range_start
59                + values[true_range_start..].partition_point(|x| apply::<T>(f_d, *x, rhs))
60        } else {
61            values.len()
62        };
63        let mut mask = BitmapBuilder::with_capacity(arr.len());
64        mask.extend_constant(true_range_start, invert);
65        mask.extend_constant(true_range_end - true_range_start, !invert);
66        mask.extend_constant(arr.len() - true_range_end, invert);
67        logical_extend(true_range_start, invert);
68        logical_extend(true_range_end - true_range_start, !invert);
69        logical_extend(arr.len() - true_range_end, invert);
70        BooleanArray::from_data_default(mask.freeze(), None)
71    });
72
73    let mut ca = BooleanChunked::from_chunk_iter(ca.name().clone(), chunks);
74    ca.set_sorted_flag(output_order.unwrap_or(IsSorted::Ascending));
75    ca
76}
77
78impl<T, Rhs> ChunkCompareEq<Rhs> for ChunkedArray<T>
79where
80    T: PolarsNumericType,
81    Rhs: ToPrimitive,
82    T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
83{
84    type Item = BooleanChunked;
85
86    fn equal(&self, rhs: Rhs) -> BooleanChunked {
87        let rhs: T::Native = NumCast::from(rhs).unwrap();
88        let fa = Some(CmpOp::Ge);
89        let fd = Some(CmpOp::Le);
90        match (self.is_sorted_flag(), self.null_count()) {
91            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
92            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
93            _ => arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(&rhs).into()),
94        }
95    }
96
97    fn equal_missing(&self, rhs: Rhs) -> BooleanChunked {
98        if self.null_count() == 0 {
99            self.equal(rhs)
100        } else {
101            let rhs: T::Native = NumCast::from(rhs).unwrap();
102            arity::unary_mut_with_options(self, |arr| {
103                arr.tot_eq_missing_kernel_broadcast(&rhs).into()
104            })
105        }
106    }
107
108    fn not_equal(&self, rhs: Rhs) -> BooleanChunked {
109        let rhs: T::Native = NumCast::from(rhs).unwrap();
110        let fa = Some(CmpOp::Ge);
111        let fd = Some(CmpOp::Le);
112        match (self.is_sorted_flag(), self.null_count()) {
113            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, true),
114            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, true),
115            _ => arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(&rhs).into()),
116        }
117    }
118
119    fn not_equal_missing(&self, rhs: Rhs) -> BooleanChunked {
120        if self.null_count() == 0 {
121            self.not_equal(rhs)
122        } else {
123            let rhs: T::Native = NumCast::from(rhs).unwrap();
124            arity::unary_mut_with_options(self, |arr| {
125                arr.tot_ne_missing_kernel_broadcast(&rhs).into()
126            })
127        }
128    }
129}
130
131impl<T, Rhs> ChunkCompareIneq<Rhs> for ChunkedArray<T>
132where
133    T: PolarsNumericType,
134    Rhs: ToPrimitive,
135    T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
136{
137    type Item = BooleanChunked;
138
139    fn gt(&self, rhs: Rhs) -> BooleanChunked {
140        let rhs: T::Native = NumCast::from(rhs).unwrap();
141        let fa = Some(CmpOp::Gt);
142        let fd = None;
143        match (self.is_sorted_flag(), self.null_count()) {
144            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
145            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
146            _ => arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(&rhs).into()),
147        }
148    }
149
150    fn gt_eq(&self, rhs: Rhs) -> BooleanChunked {
151        let rhs: T::Native = NumCast::from(rhs).unwrap();
152        let fa = Some(CmpOp::Ge);
153        let fd = None;
154        match (self.is_sorted_flag(), self.null_count()) {
155            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
156            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
157            _ => arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(&rhs).into()),
158        }
159    }
160
161    fn lt(&self, rhs: Rhs) -> BooleanChunked {
162        let rhs: T::Native = NumCast::from(rhs).unwrap();
163        let fa = None;
164        let fd = Some(CmpOp::Lt);
165        match (self.is_sorted_flag(), self.null_count()) {
166            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
167            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
168            _ => arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(&rhs).into()),
169        }
170    }
171
172    fn lt_eq(&self, rhs: Rhs) -> BooleanChunked {
173        let rhs: T::Native = NumCast::from(rhs).unwrap();
174        let fa = None;
175        let fd = Some(CmpOp::Le);
176        match (self.is_sorted_flag(), self.null_count()) {
177            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
178            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
179            _ => arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(&rhs).into()),
180        }
181    }
182}
183
184impl ChunkCompareEq<&[u8]> for BinaryChunked {
185    type Item = BooleanChunked;
186
187    fn equal(&self, rhs: &[u8]) -> BooleanChunked {
188        arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(rhs).into())
189    }
190
191    fn equal_missing(&self, rhs: &[u8]) -> BooleanChunked {
192        arity::unary_mut_with_options(self, |arr| arr.tot_eq_missing_kernel_broadcast(rhs).into())
193    }
194
195    fn not_equal(&self, rhs: &[u8]) -> BooleanChunked {
196        arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(rhs).into())
197    }
198
199    fn not_equal_missing(&self, rhs: &[u8]) -> BooleanChunked {
200        arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into())
201    }
202}
203
204impl ChunkCompareIneq<&[u8]> for BinaryChunked {
205    type Item = BooleanChunked;
206
207    fn gt(&self, rhs: &[u8]) -> BooleanChunked {
208        arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into())
209    }
210
211    fn gt_eq(&self, rhs: &[u8]) -> BooleanChunked {
212        arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(rhs).into())
213    }
214
215    fn lt(&self, rhs: &[u8]) -> BooleanChunked {
216        arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(rhs).into())
217    }
218
219    fn lt_eq(&self, rhs: &[u8]) -> BooleanChunked {
220        arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(rhs).into())
221    }
222}
223
224impl ChunkCompareEq<&str> for StringChunked {
225    type Item = BooleanChunked;
226
227    fn equal(&self, rhs: &str) -> BooleanChunked {
228        arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(rhs).into())
229    }
230
231    fn equal_missing(&self, rhs: &str) -> BooleanChunked {
232        arity::unary_mut_with_options(self, |arr| arr.tot_eq_missing_kernel_broadcast(rhs).into())
233    }
234
235    fn not_equal(&self, rhs: &str) -> BooleanChunked {
236        arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(rhs).into())
237    }
238
239    fn not_equal_missing(&self, rhs: &str) -> BooleanChunked {
240        arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into())
241    }
242}
243
244impl ChunkCompareIneq<&str> for StringChunked {
245    type Item = BooleanChunked;
246
247    fn gt(&self, rhs: &str) -> BooleanChunked {
248        arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into())
249    }
250
251    fn gt_eq(&self, rhs: &str) -> BooleanChunked {
252        arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(rhs).into())
253    }
254
255    fn lt(&self, rhs: &str) -> BooleanChunked {
256        arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(rhs).into())
257    }
258
259    fn lt_eq(&self, rhs: &str) -> BooleanChunked {
260        arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(rhs).into())
261    }
262}
263
#[cfg(test)]
mod test {
    use super::*;

    /// Builds the expected boolean `Series` for a six-element comparison.
    fn expected(values: [bool; 6]) -> Series {
        Series::new(PlSmallStr::EMPTY, values)
    }

    /// Asserts two comparison results hold the same values.
    fn assert_mask_eq(left: BooleanChunked, right: BooleanChunked) {
        assert_eq!(left.into_series(), right.into_series());
    }

    #[test]
    fn test_binary_search_cmp() {
        let mut s = Series::new(PlSmallStr::EMPTY, &[1, 1, 2, 2, 4, 8]);
        s.set_sorted_flag(IsSorted::Ascending);
        let out = s.gt(10).unwrap();
        assert!(!out.any());

        let out = s.gt(0).unwrap();
        assert!(out.all());

        let out = s.gt(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([false, false, false, false, true, true])
        );
        let out = s.gt(3).unwrap();
        assert_eq!(
            out.into_series(),
            expected([false, false, false, false, true, true])
        );

        let out = s.gt_eq(10).unwrap();
        assert!(!out.any());
        let out = s.gt_eq(0).unwrap();
        assert!(out.all());

        let out = s.gt_eq(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([false, false, true, true, true, true])
        );
        let out = s.gt_eq(3).unwrap();
        assert_eq!(
            out.into_series(),
            expected([false, false, false, false, true, true])
        );

        let out = s.lt(10).unwrap();
        assert!(out.all());
        let out = s.lt(0).unwrap();
        assert!(!out.any());

        let out = s.lt(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([true, true, false, false, false, false])
        );
        let out = s.lt(3).unwrap();
        assert_eq!(
            out.into_series(),
            expected([true, true, true, true, false, false])
        );

        let out = s.lt_eq(10).unwrap();
        assert!(out.all());
        let out = s.lt_eq(0).unwrap();
        assert!(!out.any());

        let out = s.lt_eq(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([true, true, true, true, false, false])
        );
        // A value between elements must also be handled by `lt_eq`
        // (previously this case re-tested `lt` by mistake).
        let out = s.lt_eq(3).unwrap();
        assert_eq!(
            out.into_series(),
            expected([true, true, true, true, false, false])
        );
    }

    #[test]
    fn test_binary_search_eq() {
        let mut s = Series::new(PlSmallStr::EMPTY, &[1, 1, 2, 2, 4, 8]);
        s.set_sorted_flag(IsSorted::Ascending);

        let out = s.equal(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([false, false, true, true, false, false])
        );
        let out = s.not_equal(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([true, true, false, false, true, true])
        );
        let out = s.equal(1).unwrap();
        assert_eq!(
            out.into_series(),
            expected([true, true, false, false, false, false])
        );
        let out = s.equal(8).unwrap();
        assert_eq!(
            out.into_series(),
            expected([false, false, false, false, false, true])
        );

        let out = s.equal(3).unwrap();
        assert!(!out.any());
        let out = s.not_equal(3).unwrap();
        assert!(out.all());
    }

    #[test]
    fn test_binary_search_cmp_descending() {
        let mut s = Series::new(PlSmallStr::EMPTY, &[8, 4, 2, 2, 1, 1]);
        s.set_sorted_flag(IsSorted::Descending);

        let out = s.gt(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([true, true, false, false, false, false])
        );
        let out = s.gt_eq(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([true, true, true, true, false, false])
        );
        let out = s.lt(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([false, false, false, false, true, true])
        );
        let out = s.lt_eq(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([false, false, true, true, true, true])
        );
        let out = s.equal(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([false, false, true, true, false, false])
        );
        let out = s.not_equal(2).unwrap();
        assert_eq!(
            out.into_series(),
            expected([true, true, false, false, true, true])
        );

        let out = s.gt(10).unwrap();
        assert!(!out.any());
        let out = s.lt(10).unwrap();
        assert!(out.all());
        let out = s.equal(3).unwrap();
        assert!(!out.any());
    }

    #[test]
    fn test_binary_search_matches_kernel() {
        // The sorted fast path must agree with the generic kernel for every rhs,
        // including values below, between, equal to and above the data.
        for (values, order) in [
            ([1i32, 1, 2, 2, 4, 8], IsSorted::Ascending),
            ([8i32, 4, 2, 2, 1, 1], IsSorted::Descending),
        ] {
            let mut sorted = Series::new(PlSmallStr::EMPTY, &values);
            sorted.set_sorted_flag(order);
            let mut unsorted = Series::new(PlSmallStr::EMPTY, &values);
            unsorted.set_sorted_flag(IsSorted::Not);

            for rhs in -1i32..10 {
                assert_mask_eq(sorted.gt(rhs).unwrap(), unsorted.gt(rhs).unwrap());
                assert_mask_eq(sorted.gt_eq(rhs).unwrap(), unsorted.gt_eq(rhs).unwrap());
                assert_mask_eq(sorted.lt(rhs).unwrap(), unsorted.lt(rhs).unwrap());
                assert_mask_eq(sorted.lt_eq(rhs).unwrap(), unsorted.lt_eq(rhs).unwrap());
                assert_mask_eq(sorted.equal(rhs).unwrap(), unsorted.equal(rhs).unwrap());
                assert_mask_eq(
                    sorted.not_equal(rhs).unwrap(),
                    unsorted.not_equal(rhs).unwrap(),
                );
            }
        }
    }

    #[test]
    fn test_binary_search_sorted_flag() {
        let mut s = Series::new(PlSmallStr::EMPTY, &[1, 1, 2, 2, 4, 8]);
        s.set_sorted_flag(IsSorted::Ascending);

        // [F, F, F, F, T, T]
        let out = s.gt(2).unwrap();
        assert!(matches!(out.is_sorted_flag(), IsSorted::Ascending));
        // [T, T, F, F, F, F]
        let out = s.lt(2).unwrap();
        assert!(matches!(out.is_sorted_flag(), IsSorted::Descending));
        // [F, F, T, T, F, F]
        let out = s.equal(2).unwrap();
        assert!(matches!(out.is_sorted_flag(), IsSorted::Not));
        // All false: a constant mask is reported as ascending.
        let out = s.equal(3).unwrap();
        assert!(matches!(out.is_sorted_flag(), IsSorted::Ascending));
    }
}