1use super::*;
2
/// An ordering comparison of an element `x` against a scalar `rhs`.
///
/// Evaluated inside `bitonic_mask` through the `tot_*` total-order helpers.
#[derive(Copy, Clone)]
enum CmpOp {
    /// `x < rhs`
    Lt,
    /// `x <= rhs`
    Le,
    /// `x > rhs`
    Gt,
    /// `x >= rhs`
    Ge,
}
10
11fn bitonic_mask<T: PolarsNumericType>(
19 ca: &ChunkedArray<T>,
20 f_a: Option<CmpOp>,
21 f_d: Option<CmpOp>,
22 rhs: &T::Native,
23 invert: bool,
24) -> BooleanChunked {
25 fn apply<T: PolarsNumericType>(op: CmpOp, x: T::Native, rhs: &T::Native) -> bool {
26 match op {
27 CmpOp::Lt => x.tot_lt(rhs),
28 CmpOp::Le => x.tot_le(rhs),
29 CmpOp::Gt => x.tot_gt(rhs),
30 CmpOp::Ge => x.tot_ge(rhs),
31 }
32 }
33 let mut output_order: Option<IsSorted> = None;
34 let mut last_value: Option<bool> = None;
35 let mut logical_extend = |len: usize, val: bool| {
36 if len != 0 {
37 if let Some(last_value) = last_value {
38 output_order = match (last_value, val, output_order) {
39 (false, true, None) => Some(IsSorted::Ascending),
40 (false, true, _) => Some(IsSorted::Not),
41 (true, false, None) => Some(IsSorted::Descending),
42 (true, false, _) => Some(IsSorted::Not),
43 _ => output_order,
44 };
45 }
46 last_value = Some(val);
47 }
48 };
49
50 let chunks = ca.downcast_iter().map(|arr| {
51 let values = arr.values();
52 let true_range_start = if let Some(f_a) = f_a {
53 values.partition_point(|x| !apply::<T>(f_a, *x, rhs))
54 } else {
55 0
56 };
57 let true_range_end = if let Some(f_d) = f_d {
58 true_range_start
59 + values[true_range_start..].partition_point(|x| apply::<T>(f_d, *x, rhs))
60 } else {
61 values.len()
62 };
63 let mut mask = BitmapBuilder::with_capacity(arr.len());
64 mask.extend_constant(true_range_start, invert);
65 mask.extend_constant(true_range_end - true_range_start, !invert);
66 mask.extend_constant(arr.len() - true_range_end, invert);
67 logical_extend(true_range_start, invert);
68 logical_extend(true_range_end - true_range_start, !invert);
69 logical_extend(arr.len() - true_range_end, invert);
70 BooleanArray::from_data_default(mask.freeze(), None)
71 });
72
73 let mut ca = BooleanChunked::from_chunk_iter(ca.name().clone(), chunks);
74 ca.set_sorted_flag(output_order.unwrap_or(IsSorted::Ascending));
75 ca
76}
77
78impl<T, Rhs> ChunkCompareEq<Rhs> for ChunkedArray<T>
79where
80 T: PolarsNumericType,
81 Rhs: ToPrimitive,
82 T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
83{
84 type Item = BooleanChunked;
85
86 fn equal(&self, rhs: Rhs) -> BooleanChunked {
87 let rhs: T::Native = NumCast::from(rhs).unwrap();
88 let fa = Some(CmpOp::Ge);
89 let fd = Some(CmpOp::Le);
90 match (self.is_sorted_flag(), self.null_count()) {
91 (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
92 (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
93 _ => arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(&rhs).into()),
94 }
95 }
96
97 fn equal_missing(&self, rhs: Rhs) -> BooleanChunked {
98 if self.null_count() == 0 {
99 self.equal(rhs)
100 } else {
101 let rhs: T::Native = NumCast::from(rhs).unwrap();
102 arity::unary_mut_with_options(self, |arr| {
103 arr.tot_eq_missing_kernel_broadcast(&rhs).into()
104 })
105 }
106 }
107
108 fn not_equal(&self, rhs: Rhs) -> BooleanChunked {
109 let rhs: T::Native = NumCast::from(rhs).unwrap();
110 let fa = Some(CmpOp::Ge);
111 let fd = Some(CmpOp::Le);
112 match (self.is_sorted_flag(), self.null_count()) {
113 (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, true),
114 (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, true),
115 _ => arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(&rhs).into()),
116 }
117 }
118
119 fn not_equal_missing(&self, rhs: Rhs) -> BooleanChunked {
120 if self.null_count() == 0 {
121 self.not_equal(rhs)
122 } else {
123 let rhs: T::Native = NumCast::from(rhs).unwrap();
124 arity::unary_mut_with_options(self, |arr| {
125 arr.tot_ne_missing_kernel_broadcast(&rhs).into()
126 })
127 }
128 }
129}
130
131impl<T, Rhs> ChunkCompareIneq<Rhs> for ChunkedArray<T>
132where
133 T: PolarsNumericType,
134 Rhs: ToPrimitive,
135 T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
136{
137 type Item = BooleanChunked;
138
139 fn gt(&self, rhs: Rhs) -> BooleanChunked {
140 let rhs: T::Native = NumCast::from(rhs).unwrap();
141 let fa = Some(CmpOp::Gt);
142 let fd = None;
143 match (self.is_sorted_flag(), self.null_count()) {
144 (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
145 (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
146 _ => arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(&rhs).into()),
147 }
148 }
149
150 fn gt_eq(&self, rhs: Rhs) -> BooleanChunked {
151 let rhs: T::Native = NumCast::from(rhs).unwrap();
152 let fa = Some(CmpOp::Ge);
153 let fd = None;
154 match (self.is_sorted_flag(), self.null_count()) {
155 (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
156 (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
157 _ => arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(&rhs).into()),
158 }
159 }
160
161 fn lt(&self, rhs: Rhs) -> BooleanChunked {
162 let rhs: T::Native = NumCast::from(rhs).unwrap();
163 let fa = None;
164 let fd = Some(CmpOp::Lt);
165 match (self.is_sorted_flag(), self.null_count()) {
166 (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
167 (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
168 _ => arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(&rhs).into()),
169 }
170 }
171
172 fn lt_eq(&self, rhs: Rhs) -> BooleanChunked {
173 let rhs: T::Native = NumCast::from(rhs).unwrap();
174 let fa = None;
175 let fd = Some(CmpOp::Le);
176 match (self.is_sorted_flag(), self.null_count()) {
177 (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
178 (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
179 _ => arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(&rhs).into()),
180 }
181 }
182}
183
184macro_rules! binary_eq_ineq_impl {
185 ($($ca:ident),+) => {
186 $(
187 impl ChunkCompareEq<&[u8]> for $ca {
188 type Item = BooleanChunked;
189
190 fn equal(&self, rhs: &[u8]) -> BooleanChunked {
191 arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(rhs).into())
192 }
193
194 fn equal_missing(&self, rhs: &[u8]) -> BooleanChunked {
195 arity::unary_mut_with_options(self, |arr| arr.tot_eq_missing_kernel_broadcast(rhs).into())
196 }
197
198 fn not_equal(&self, rhs: &[u8]) -> BooleanChunked {
199 arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(rhs).into())
200 }
201
202 fn not_equal_missing(&self, rhs: &[u8]) -> BooleanChunked {
203 arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into())
204 }
205 }
206
207 impl ChunkCompareIneq<&[u8]> for $ca {
208 type Item = BooleanChunked;
209
210 fn gt(&self, rhs: &[u8]) -> BooleanChunked {
211 arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into())
212 }
213
214 fn gt_eq(&self, rhs: &[u8]) -> BooleanChunked {
215 arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(rhs).into())
216 }
217
218 fn lt(&self, rhs: &[u8]) -> BooleanChunked {
219 arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(rhs).into())
220 }
221
222 fn lt_eq(&self, rhs: &[u8]) -> BooleanChunked {
223 arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(rhs).into())
224 }
225 }
226 )+
227 };
228}
229
230binary_eq_ineq_impl!(BinaryChunked, BinaryOffsetChunked);
231
232impl ChunkCompareEq<&str> for StringChunked {
233 type Item = BooleanChunked;
234
235 fn equal(&self, rhs: &str) -> BooleanChunked {
236 arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(rhs).into())
237 }
238
239 fn equal_missing(&self, rhs: &str) -> BooleanChunked {
240 arity::unary_mut_with_options(self, |arr| arr.tot_eq_missing_kernel_broadcast(rhs).into())
241 }
242
243 fn not_equal(&self, rhs: &str) -> BooleanChunked {
244 arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(rhs).into())
245 }
246
247 fn not_equal_missing(&self, rhs: &str) -> BooleanChunked {
248 arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into())
249 }
250}
251
252impl ChunkCompareIneq<&str> for StringChunked {
253 type Item = BooleanChunked;
254
255 fn gt(&self, rhs: &str) -> BooleanChunked {
256 arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into())
257 }
258
259 fn gt_eq(&self, rhs: &str) -> BooleanChunked {
260 arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(rhs).into())
261 }
262
263 fn lt(&self, rhs: &str) -> BooleanChunked {
264 arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(rhs).into())
265 }
266
267 fn lt_eq(&self, rhs: &str) -> BooleanChunked {
268 arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(rhs).into())
269 }
270}
271
#[cfg(test)]
mod test {
    use super::*;

    /// Ordering comparisons on an ascending-sorted series take the
    /// binary-search path.
    #[test]
    fn test_binary_search_cmp() {
        let mut s = Series::new(PlSmallStr::EMPTY, &[1, 1, 2, 2, 4, 8]);
        s.set_sorted_flag(IsSorted::Ascending);
        let out = s.gt(10).unwrap();
        assert!(!out.any());

        let out = s.gt(0).unwrap();
        assert!(out.all());

        let out = s.gt(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true])
        );
        let out = s.gt(3).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true])
        );

        let out = s.gt_eq(10).unwrap();
        assert!(!out.any());
        let out = s.gt_eq(0).unwrap();
        assert!(out.all());

        let out = s.gt_eq(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [false, false, true, true, true, true])
        );
        let out = s.gt_eq(3).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true])
        );

        let out = s.lt(10).unwrap();
        assert!(out.all());
        let out = s.lt(0).unwrap();
        assert!(!out.any());

        let out = s.lt(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [true, true, false, false, false, false])
        );
        let out = s.lt(3).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false])
        );

        let out = s.lt_eq(10).unwrap();
        assert!(out.all());
        let out = s.lt_eq(0).unwrap();
        assert!(!out.any());

        let out = s.lt_eq(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false])
        );
        // Previously this repeated `s.lt(3)`, leaving `lt_eq` with an absent
        // value untested.
        let out = s.lt_eq(3).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false])
        );

        // A single false->true transition keeps the mask ascending.
        let out = s.gt(2).unwrap();
        assert!(matches!(out.is_sorted_flag(), IsSorted::Ascending));
    }

    /// (In)equality on sorted input uses both the rising and falling predicates.
    #[test]
    fn test_binary_search_eq() {
        let mut s = Series::new(PlSmallStr::EMPTY, &[1, 1, 2, 2, 4, 8]);
        s.set_sorted_flag(IsSorted::Ascending);

        let out = s.equal(2).unwrap();
        // false -> true -> false is two transitions, so the mask is unsorted.
        assert!(matches!(out.is_sorted_flag(), IsSorted::Not));
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [false, false, true, true, false, false])
        );

        let out = s.not_equal(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [true, true, false, false, true, true])
        );

        // A value that lies between elements matches nothing.
        let out = s.equal(3).unwrap();
        assert!(!out.any());
    }

    /// On descending input the predicates swap slots in `bitonic_mask`.
    #[test]
    fn test_binary_search_descending() {
        let mut s = Series::new(PlSmallStr::EMPTY, &[8, 4, 2, 2, 1, 1]);
        s.set_sorted_flag(IsSorted::Descending);

        let out = s.gt(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [true, true, false, false, false, false])
        );

        let out = s.lt(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true])
        );

        let out = s.equal(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [false, false, true, true, false, false])
        );

        let out = s.not_equal(2).unwrap();
        assert_eq!(
            out.into_series(),
            Series::new(PlSmallStr::EMPTY, [true, true, false, false, true, true])
        );
    }
}