polars_core/chunked_array/comparison/
scalar.rs

1use super::*;
2
/// Ordering comparison applied between an array element and a scalar
/// right-hand side (`element <op> rhs`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CmpOp {
    /// `element < rhs`
    Lt,
    /// `element <= rhs`
    Le,
    /// `element > rhs`
    Gt,
    /// `element >= rhs`
    Ge,
}
10
11// Given two monotonic functions f_a and f_d where f_a is ascending
12// (f_a(x[0]) <= f_a(x[1]) <= .. <= f_a(x[n-1])) and f_d is descending
13// (f_d(x[0]) >= f_d(x[1]) >= .. >= f_d(x[n-1])),
14// outputs a mask where both are true.
15//
16// If a function is not given it is always assumed to be true. If invert is
17// true the output mask is inverted.
18fn bitonic_mask<T: PolarsNumericType>(
19    ca: &ChunkedArray<T>,
20    f_a: Option<CmpOp>,
21    f_d: Option<CmpOp>,
22    rhs: &T::Native,
23    invert: bool,
24) -> BooleanChunked {
25    fn apply<T: PolarsNumericType>(op: CmpOp, x: T::Native, rhs: &T::Native) -> bool {
26        match op {
27            CmpOp::Lt => x.tot_lt(rhs),
28            CmpOp::Le => x.tot_le(rhs),
29            CmpOp::Gt => x.tot_gt(rhs),
30            CmpOp::Ge => x.tot_ge(rhs),
31        }
32    }
33    let mut output_order: Option<IsSorted> = None;
34    let mut last_value: Option<bool> = None;
35    let mut logical_extend = |len: usize, val: bool| {
36        if len != 0 {
37            if let Some(last_value) = last_value {
38                output_order = match (last_value, val, output_order) {
39                    (false, true, None) => Some(IsSorted::Ascending),
40                    (false, true, _) => Some(IsSorted::Not),
41                    (true, false, None) => Some(IsSorted::Descending),
42                    (true, false, _) => Some(IsSorted::Not),
43                    _ => output_order,
44                };
45            }
46            last_value = Some(val);
47        }
48    };
49
50    let chunks = ca.downcast_iter().map(|arr| {
51        let values = arr.values();
52        let true_range_start = if let Some(f_a) = f_a {
53            values.partition_point(|x| !apply::<T>(f_a, *x, rhs))
54        } else {
55            0
56        };
57        let true_range_end = if let Some(f_d) = f_d {
58            true_range_start
59                + values[true_range_start..].partition_point(|x| apply::<T>(f_d, *x, rhs))
60        } else {
61            values.len()
62        };
63        let mut mask = BitmapBuilder::with_capacity(arr.len());
64        mask.extend_constant(true_range_start, invert);
65        mask.extend_constant(true_range_end - true_range_start, !invert);
66        mask.extend_constant(arr.len() - true_range_end, invert);
67        logical_extend(true_range_start, invert);
68        logical_extend(true_range_end - true_range_start, !invert);
69        logical_extend(arr.len() - true_range_end, invert);
70        BooleanArray::from_data_default(mask.freeze(), None)
71    });
72
73    let mut ca = BooleanChunked::from_chunk_iter(ca.name().clone(), chunks);
74    ca.set_sorted_flag(output_order.unwrap_or(IsSorted::Ascending));
75    ca
76}
77
78impl<T, Rhs> ChunkCompareEq<Rhs> for ChunkedArray<T>
79where
80    T: PolarsNumericType,
81    Rhs: ToPrimitive,
82    T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
83{
84    type Item = BooleanChunked;
85
86    fn equal(&self, rhs: Rhs) -> BooleanChunked {
87        let rhs: T::Native = NumCast::from(rhs).unwrap();
88        let fa = Some(CmpOp::Ge);
89        let fd = Some(CmpOp::Le);
90        match (self.is_sorted_flag(), self.null_count()) {
91            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
92            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
93            _ => arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(&rhs).into()),
94        }
95    }
96
97    fn equal_missing(&self, rhs: Rhs) -> BooleanChunked {
98        if self.null_count() == 0 {
99            self.equal(rhs)
100        } else {
101            let rhs: T::Native = NumCast::from(rhs).unwrap();
102            arity::unary_mut_with_options(self, |arr| {
103                arr.tot_eq_missing_kernel_broadcast(&rhs).into()
104            })
105        }
106    }
107
108    fn not_equal(&self, rhs: Rhs) -> BooleanChunked {
109        let rhs: T::Native = NumCast::from(rhs).unwrap();
110        let fa = Some(CmpOp::Ge);
111        let fd = Some(CmpOp::Le);
112        match (self.is_sorted_flag(), self.null_count()) {
113            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, true),
114            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, true),
115            _ => arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(&rhs).into()),
116        }
117    }
118
119    fn not_equal_missing(&self, rhs: Rhs) -> BooleanChunked {
120        if self.null_count() == 0 {
121            self.not_equal(rhs)
122        } else {
123            let rhs: T::Native = NumCast::from(rhs).unwrap();
124            arity::unary_mut_with_options(self, |arr| {
125                arr.tot_ne_missing_kernel_broadcast(&rhs).into()
126            })
127        }
128    }
129}
130
131impl<T, Rhs> ChunkCompareIneq<Rhs> for ChunkedArray<T>
132where
133    T: PolarsNumericType,
134    Rhs: ToPrimitive,
135    T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
136{
137    type Item = BooleanChunked;
138
139    fn gt(&self, rhs: Rhs) -> BooleanChunked {
140        let rhs: T::Native = NumCast::from(rhs).unwrap();
141        let fa = Some(CmpOp::Gt);
142        let fd = None;
143        match (self.is_sorted_flag(), self.null_count()) {
144            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
145            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
146            _ => arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(&rhs).into()),
147        }
148    }
149
150    fn gt_eq(&self, rhs: Rhs) -> BooleanChunked {
151        let rhs: T::Native = NumCast::from(rhs).unwrap();
152        let fa = Some(CmpOp::Ge);
153        let fd = None;
154        match (self.is_sorted_flag(), self.null_count()) {
155            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
156            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
157            _ => arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(&rhs).into()),
158        }
159    }
160
161    fn lt(&self, rhs: Rhs) -> BooleanChunked {
162        let rhs: T::Native = NumCast::from(rhs).unwrap();
163        let fa = None;
164        let fd = Some(CmpOp::Lt);
165        match (self.is_sorted_flag(), self.null_count()) {
166            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
167            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
168            _ => arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(&rhs).into()),
169        }
170    }
171
172    fn lt_eq(&self, rhs: Rhs) -> BooleanChunked {
173        let rhs: T::Native = NumCast::from(rhs).unwrap();
174        let fa = None;
175        let fd = Some(CmpOp::Le);
176        match (self.is_sorted_flag(), self.null_count()) {
177            (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
178            (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
179            _ => arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(&rhs).into()),
180        }
181    }
182}
183
184macro_rules! binary_eq_ineq_impl {
185    ($($ca:ident),+) => {
186        $(
187        impl ChunkCompareEq<&[u8]> for $ca {
188            type Item = BooleanChunked;
189
190            fn equal(&self, rhs: &[u8]) -> BooleanChunked {
191                arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(rhs).into())
192            }
193
194            fn equal_missing(&self, rhs: &[u8]) -> BooleanChunked {
195                arity::unary_mut_with_options(self, |arr| arr.tot_eq_missing_kernel_broadcast(rhs).into())
196            }
197
198            fn not_equal(&self, rhs: &[u8]) -> BooleanChunked {
199                arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(rhs).into())
200            }
201
202            fn not_equal_missing(&self, rhs: &[u8]) -> BooleanChunked {
203                arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into())
204            }
205        }
206
207        impl ChunkCompareIneq<&[u8]> for $ca {
208            type Item = BooleanChunked;
209
210            fn gt(&self, rhs: &[u8]) -> BooleanChunked {
211                arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into())
212            }
213
214            fn gt_eq(&self, rhs: &[u8]) -> BooleanChunked {
215                arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(rhs).into())
216            }
217
218            fn lt(&self, rhs: &[u8]) -> BooleanChunked {
219                arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(rhs).into())
220            }
221
222            fn lt_eq(&self, rhs: &[u8]) -> BooleanChunked {
223                arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(rhs).into())
224            }
225        }
226        )+
227    };
228}
229
230binary_eq_ineq_impl!(BinaryChunked, BinaryOffsetChunked);
231
232impl ChunkCompareEq<&str> for StringChunked {
233    type Item = BooleanChunked;
234
235    fn equal(&self, rhs: &str) -> BooleanChunked {
236        arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(rhs).into())
237    }
238
239    fn equal_missing(&self, rhs: &str) -> BooleanChunked {
240        arity::unary_mut_with_options(self, |arr| arr.tot_eq_missing_kernel_broadcast(rhs).into())
241    }
242
243    fn not_equal(&self, rhs: &str) -> BooleanChunked {
244        arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(rhs).into())
245    }
246
247    fn not_equal_missing(&self, rhs: &str) -> BooleanChunked {
248        arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into())
249    }
250}
251
252impl ChunkCompareIneq<&str> for StringChunked {
253    type Item = BooleanChunked;
254
255    fn gt(&self, rhs: &str) -> BooleanChunked {
256        arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into())
257    }
258
259    fn gt_eq(&self, rhs: &str) -> BooleanChunked {
260        arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(rhs).into())
261    }
262
263    fn lt(&self, rhs: &str) -> BooleanChunked {
264        arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(rhs).into())
265    }
266
267    fn lt_eq(&self, rhs: &str) -> BooleanChunked {
268        arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(rhs).into())
269    }
270}
271
#[cfg(test)]
mod test {
    use super::*;

    /// Checks the scalar ordering comparisons on a series flagged as sorted
    /// ascending, for a scalar above all values, below all values, equal to
    /// an existing value (2) and strictly between two values (3).
    #[test]
    fn test_binary_search_cmp() {
        let mut s = Series::new(PlSmallStr::EMPTY, &[1, 1, 2, 2, 4, 8]);
        s.set_sorted_flag(IsSorted::Ascending);

        // gt
        let out = s.gt(10).unwrap();
        assert!(!out.any());

        let out = s.gt(0).unwrap();
        assert!(out.all());

        let out = s.gt(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true])
        );
        let out = s.gt(3).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true])
        );

        // gt_eq
        let out = s.gt_eq(10).unwrap();
        assert!(!out.any());
        let out = s.gt_eq(0).unwrap();
        assert!(out.all());

        let out = s.gt_eq(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [false, false, true, true, true, true])
        );
        let out = s.gt_eq(3).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true])
        );

        // lt
        let out = s.lt(10).unwrap();
        assert!(out.all());
        let out = s.lt(0).unwrap();
        assert!(!out.any());

        let out = s.lt(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [true, true, false, false, false, false])
        );
        let out = s.lt(3).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false])
        );

        // lt_eq
        let out = s.lt_eq(10).unwrap();
        assert!(out.all());
        let out = s.lt_eq(0).unwrap();
        assert!(!out.any());

        let out = s.lt_eq(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false])
        );
        // `lt_eq` with a scalar that lies strictly between two elements.
        let out = s.lt_eq(3).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false])
        );
    }
}
345}