1use arrow::bitmap::Bitmap;
2use arrow::legacy::utils::FromTrustedLenIterator;
3use polars_compute::comparisons::TotalOrdKernel;
4
5use crate::chunked_array::cast::CastOptions;
6use crate::prelude::nulls::replace_non_null;
7use crate::prelude::*;
8
/// Shared driver for equality-style comparisons between two categorical arrays.
///
/// Both sides must share the same categorical source (`same_src`); otherwise a
/// string-cache-mismatch error is returned. Suppose `rhs` holds exactly one non-null
/// value, and its physical code cannot be resolved through `lhs`'s reverse mapping.
/// In that case the result is `missing_function(lhs)`. In every other case the
/// physical `u32` codes of both sides are handed to `compare_function`.
#[cfg(feature = "dtype-categorical")]
fn cat_equality_helper<'a, Compare, Missing>(
    lhs: &'a CategoricalChunked,
    rhs: &'a CategoricalChunked,
    missing_function: Missing,
    compare_function: Compare,
) -> PolarsResult<BooleanChunked>
where
    Compare: Fn(&'a UInt32Chunked, &'a UInt32Chunked) -> BooleanChunked,
    Missing: Fn(&'a CategoricalChunked) -> BooleanChunked,
{
    let lhs_map = lhs.get_rev_map();
    polars_ensure!(lhs_map.same_src(rhs.get_rev_map()), string_cache_mismatch);
    let rhs_codes = rhs.physical();

    // Only a broadcast scalar takes the "unknown category" shortcut. The `unwrap` is
    // guarded by the `null_count() == 0` check, which short-circuits before it.
    let scalar_is_unknown = rhs_codes.len() == 1
        && rhs_codes.null_count() == 0
        && lhs_map.get_optional(rhs_codes.get(0).unwrap()).is_none();

    if scalar_is_unknown {
        Ok(missing_function(lhs))
    } else {
        Ok(compare_function(lhs.physical(), rhs_codes))
    }
}
33
34fn cat_compare_helper<'a, Compare, CompareString>(
35 lhs: &'a CategoricalChunked,
36 rhs: &'a CategoricalChunked,
37 compare_function: Compare,
38 compare_str_function: CompareString,
39) -> PolarsResult<BooleanChunked>
40where
41 Compare: Fn(&'a UInt32Chunked, &'a UInt32Chunked) -> BooleanChunked,
42 CompareString: Fn(&str, &str) -> bool,
43{
44 let rev_map_l = lhs.get_rev_map();
45 let rev_map_r = rhs.get_rev_map();
46 polars_ensure!(rev_map_l.same_src(rev_map_r), ComputeError: "can only compare categoricals of the same type with the same categories");
47
48 if lhs.is_enum() || !lhs.uses_lexical_ordering() {
49 Ok(compare_function(lhs.physical(), rhs.physical()))
50 } else {
51 match (lhs.len(), rhs.len()) {
52 (lhs_len, 1) => {
53 let v = unsafe {
55 rhs.physical()
56 .get(0)
57 .map(|phys| rev_map_r.get_unchecked(phys))
58 };
59 let Some(v) = v else {
60 return Ok(BooleanChunked::full_null(lhs.name().clone(), lhs_len));
61 };
62
63 Ok(lhs
64 .iter_str()
65 .map(|opt_s| opt_s.map(|s| compare_str_function(s, v)))
66 .collect_ca_trusted(lhs.name().clone()))
67 },
68 (1, rhs_len) => {
69 let v = unsafe {
71 lhs.physical()
72 .get(0)
73 .map(|phys| rev_map_l.get_unchecked(phys))
74 };
75 let Some(v) = v else {
76 return Ok(BooleanChunked::full_null(lhs.name().clone(), rhs_len));
77 };
78 Ok(rhs
79 .iter_str()
80 .map(|opt_s| opt_s.map(|s| compare_str_function(v, s)))
81 .collect_ca_trusted(lhs.name().clone()))
82 },
83 (lhs_len, rhs_len) if lhs_len == rhs_len => Ok(lhs
84 .iter_str()
85 .zip(rhs.iter_str())
86 .map(|(l, r)| match (l, r) {
87 (None, _) => None,
88 (_, None) => None,
89 (Some(l), Some(r)) => Some(compare_str_function(l, r)),
90 })
91 .collect_ca_trusted(lhs.name().clone())),
92 (lhs_len, rhs_len) => {
93 polars_bail!(ComputeError: "Columns are of unequal length: {} vs {}",lhs_len,rhs_len)
94 },
95 }
96 }
97}
98
99impl ChunkCompareEq<&CategoricalChunked> for CategoricalChunked {
100 type Item = PolarsResult<BooleanChunked>;
101
102 fn equal(&self, rhs: &CategoricalChunked) -> Self::Item {
103 cat_equality_helper(
104 self,
105 rhs,
106 |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false),
107 UInt32Chunked::equal,
108 )
109 }
110
111 fn equal_missing(&self, rhs: &CategoricalChunked) -> Self::Item {
112 cat_equality_helper(
113 self,
114 rhs,
115 |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()),
116 UInt32Chunked::equal_missing,
117 )
118 }
119
120 fn not_equal(&self, rhs: &CategoricalChunked) -> Self::Item {
121 cat_equality_helper(
122 self,
123 rhs,
124 |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true),
125 UInt32Chunked::not_equal,
126 )
127 }
128
129 fn not_equal_missing(&self, rhs: &CategoricalChunked) -> Self::Item {
130 cat_equality_helper(
131 self,
132 rhs,
133 |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()),
134 UInt32Chunked::not_equal_missing,
135 )
136 }
137}
138
139impl ChunkCompareIneq<&CategoricalChunked> for CategoricalChunked {
140 type Item = PolarsResult<BooleanChunked>;
141
142 fn gt(&self, rhs: &CategoricalChunked) -> Self::Item {
143 cat_compare_helper(self, rhs, UInt32Chunked::gt, |l, r| l > r)
144 }
145
146 fn gt_eq(&self, rhs: &CategoricalChunked) -> Self::Item {
147 cat_compare_helper(self, rhs, UInt32Chunked::gt_eq, |l, r| l >= r)
148 }
149
150 fn lt(&self, rhs: &CategoricalChunked) -> Self::Item {
151 cat_compare_helper(self, rhs, UInt32Chunked::lt, |l, r| l < r)
152 }
153
154 fn lt_eq(&self, rhs: &CategoricalChunked) -> Self::Item {
155 cat_compare_helper(self, rhs, UInt32Chunked::lt_eq, |l, r| l <= r)
156 }
157}
158
159fn cat_str_equality_helper<'a, Missing, CompareNone, CompareCat, ComparePhys, CompareString>(
160 lhs: &'a CategoricalChunked,
161 rhs: &'a StringChunked,
162 missing_function: Missing,
163 compare_to_none: CompareNone,
164 cat_compare_function: CompareCat,
165 phys_compare_function: ComparePhys,
166 str_compare_function: CompareString,
167) -> PolarsResult<BooleanChunked>
168where
169 Missing: Fn(&CategoricalChunked) -> BooleanChunked,
170 CompareNone: Fn(&CategoricalChunked) -> BooleanChunked,
171 ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
172 CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult<BooleanChunked>,
173 CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked,
174{
175 if lhs.is_enum() {
176 let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?;
177 cat_compare_function(lhs, rhs_cat.categorical().unwrap())
178 } else if rhs.len() == 1 {
179 match rhs.get(0) {
180 None => Ok(compare_to_none(lhs)),
181 Some(s) => {
182 cat_single_str_equality_helper(lhs, s, missing_function, phys_compare_function)
183 },
184 }
185 } else {
186 let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?;
187 Ok(str_compare_function(lhs_string.str().unwrap(), rhs))
188 }
189}
190
191fn cat_str_compare_helper<'a, CompareCat, ComparePhys, CompareStringSingle, CompareString>(
192 lhs: &'a CategoricalChunked,
193 rhs: &'a StringChunked,
194 cat_compare_function: CompareCat,
195 phys_compare_function: ComparePhys,
196 str_single_compare_function: CompareStringSingle,
197 str_compare_function: CompareString,
198) -> PolarsResult<BooleanChunked>
199where
200 CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap,
201 ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
202 CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult<BooleanChunked>,
203 CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked,
204{
205 if lhs.is_enum() {
206 let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?;
207 cat_compare_function(lhs, rhs_cat.categorical().unwrap())
208 } else if rhs.len() == 1 {
209 match rhs.get(0) {
210 None => Ok(BooleanChunked::full_null(lhs.name().clone(), lhs.len())),
211 Some(s) => cat_single_str_compare_helper(
212 lhs,
213 s,
214 phys_compare_function,
215 str_single_compare_function,
216 ),
217 }
218 } else {
219 let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?;
220 Ok(str_compare_function(lhs_string.str().unwrap(), rhs))
221 }
222}
223
224impl ChunkCompareEq<&StringChunked> for CategoricalChunked {
225 type Item = PolarsResult<BooleanChunked>;
226
227 fn equal(&self, rhs: &StringChunked) -> Self::Item {
228 cat_str_equality_helper(
229 self,
230 rhs,
231 |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false),
232 |lhs| BooleanChunked::full_null(lhs.name().clone(), lhs.len()),
233 |s1, s2| CategoricalChunked::equal(s1, s2),
234 UInt32Chunked::equal,
235 StringChunked::equal,
236 )
237 }
238 fn equal_missing(&self, rhs: &StringChunked) -> Self::Item {
239 cat_str_equality_helper(
240 self,
241 rhs,
242 |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()),
243 |lhs| lhs.physical().is_null(),
244 |s1, s2| CategoricalChunked::equal_missing(s1, s2),
245 UInt32Chunked::equal_missing,
246 StringChunked::equal_missing,
247 )
248 }
249
250 fn not_equal(&self, rhs: &StringChunked) -> Self::Item {
251 cat_str_equality_helper(
252 self,
253 rhs,
254 |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true),
255 |lhs| BooleanChunked::full_null(lhs.name().clone(), lhs.len()),
256 |s1, s2| CategoricalChunked::not_equal(s1, s2),
257 UInt32Chunked::not_equal,
258 StringChunked::not_equal,
259 )
260 }
261 fn not_equal_missing(&self, rhs: &StringChunked) -> Self::Item {
262 cat_str_equality_helper(
263 self,
264 rhs,
265 |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()),
266 |lhs| !lhs.physical().is_null(),
267 |s1, s2| CategoricalChunked::not_equal_missing(s1, s2),
268 UInt32Chunked::not_equal_missing,
269 StringChunked::not_equal_missing,
270 )
271 }
272}
273
274impl ChunkCompareIneq<&StringChunked> for CategoricalChunked {
275 type Item = PolarsResult<BooleanChunked>;
276
277 fn gt(&self, rhs: &StringChunked) -> Self::Item {
278 cat_str_compare_helper(
279 self,
280 rhs,
281 |s1, s2| CategoricalChunked::gt(s1, s2),
282 UInt32Chunked::gt,
283 Utf8ViewArray::tot_gt_kernel_broadcast,
284 StringChunked::gt,
285 )
286 }
287
288 fn gt_eq(&self, rhs: &StringChunked) -> Self::Item {
289 cat_str_compare_helper(
290 self,
291 rhs,
292 |s1, s2| CategoricalChunked::gt_eq(s1, s2),
293 UInt32Chunked::gt_eq,
294 Utf8ViewArray::tot_ge_kernel_broadcast,
295 StringChunked::gt_eq,
296 )
297 }
298
299 fn lt(&self, rhs: &StringChunked) -> Self::Item {
300 cat_str_compare_helper(
301 self,
302 rhs,
303 |s1, s2| CategoricalChunked::lt(s1, s2),
304 UInt32Chunked::lt,
305 Utf8ViewArray::tot_lt_kernel_broadcast,
306 StringChunked::lt,
307 )
308 }
309
310 fn lt_eq(&self, rhs: &StringChunked) -> Self::Item {
311 cat_str_compare_helper(
312 self,
313 rhs,
314 |s1, s2| CategoricalChunked::lt_eq(s1, s2),
315 UInt32Chunked::lt_eq,
316 Utf8ViewArray::tot_le_kernel_broadcast,
317 StringChunked::lt_eq,
318 )
319 }
320}
321
322fn cat_single_str_equality_helper<'a, ComparePhys, Missing>(
323 lhs: &'a CategoricalChunked,
324 rhs: &'a str,
325 missing_function: Missing,
326 phys_compare_function: ComparePhys,
327) -> PolarsResult<BooleanChunked>
328where
329 ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
330 Missing: Fn(&CategoricalChunked) -> BooleanChunked,
331{
332 let rev_map = lhs.get_rev_map();
333 let idx = rev_map.find(rhs);
334 if lhs.is_enum() {
335 let Some(idx) = idx else {
336 polars_bail!(
337 not_in_enum,
338 value = rhs,
339 categories = rev_map.get_categories()
340 )
341 };
342 Ok(phys_compare_function(lhs.physical(), idx))
343 } else {
344 match rev_map.find(rhs) {
345 None => Ok(missing_function(lhs)),
346 Some(idx) => Ok(phys_compare_function(lhs.physical(), idx)),
347 }
348 }
349}
350
351fn cat_single_str_compare_helper<'a, ComparePhys, CompareStringSingle>(
352 lhs: &'a CategoricalChunked,
353 rhs: &'a str,
354 phys_compare_function: ComparePhys,
355 str_single_compare_function: CompareStringSingle,
356) -> PolarsResult<BooleanChunked>
357where
358 CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap,
359 ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
360{
361 let rev_map = lhs.get_rev_map();
362 if lhs.is_enum() {
363 match rev_map.find(rhs) {
364 None => {
365 polars_bail!(
366 not_in_enum,
367 value = rhs,
368 categories = rev_map.get_categories()
369 )
370 },
371 Some(idx) => Ok(phys_compare_function(lhs.physical(), idx)),
372 }
373 } else {
374 let bitmap = str_single_compare_function(lhs.get_rev_map().get_categories(), rhs);
376
377 let mask = match lhs.get_rev_map().as_ref() {
378 RevMapping::Local(_, _) => {
379 BooleanChunked::from_iter_trusted_length(lhs.physical().into_iter().map(
380 |opt_idx| {
381 opt_idx.map(|idx| unsafe { bitmap.get_bit_unchecked(idx as usize) })
383 },
384 ))
385 },
386 RevMapping::Global(idx_map, _, _) => {
387 BooleanChunked::from_iter_trusted_length(lhs.physical().into_iter().map(
388 |opt_idx| {
389 opt_idx.map(|idx| unsafe {
391 let idx = *idx_map.get(&idx).unwrap();
392 bitmap.get_bit_unchecked(idx as usize)
393 })
394 },
395 ))
396 },
397 };
398
399 Ok(mask.with_name(lhs.name().clone()))
400 }
401}
402
403impl ChunkCompareEq<&str> for CategoricalChunked {
404 type Item = PolarsResult<BooleanChunked>;
405
406 fn equal(&self, rhs: &str) -> Self::Item {
407 cat_single_str_equality_helper(
408 self,
409 rhs,
410 |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false),
411 UInt32Chunked::equal,
412 )
413 }
414
415 fn equal_missing(&self, rhs: &str) -> Self::Item {
416 cat_single_str_equality_helper(
417 self,
418 rhs,
419 |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()),
420 UInt32Chunked::equal_missing,
421 )
422 }
423
424 fn not_equal(&self, rhs: &str) -> Self::Item {
425 cat_single_str_equality_helper(
426 self,
427 rhs,
428 |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true),
429 UInt32Chunked::not_equal,
430 )
431 }
432
433 fn not_equal_missing(&self, rhs: &str) -> Self::Item {
434 cat_single_str_equality_helper(
435 self,
436 rhs,
437 |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()),
438 UInt32Chunked::equal_missing,
439 )
440 }
441}
442
443impl ChunkCompareIneq<&str> for CategoricalChunked {
444 type Item = PolarsResult<BooleanChunked>;
445
446 fn gt(&self, rhs: &str) -> Self::Item {
447 cat_single_str_compare_helper(
448 self,
449 rhs,
450 UInt32Chunked::gt,
451 Utf8ViewArray::tot_gt_kernel_broadcast,
452 )
453 }
454
455 fn gt_eq(&self, rhs: &str) -> Self::Item {
456 cat_single_str_compare_helper(
457 self,
458 rhs,
459 UInt32Chunked::gt_eq,
460 Utf8ViewArray::tot_ge_kernel_broadcast,
461 )
462 }
463
464 fn lt(&self, rhs: &str) -> Self::Item {
465 cat_single_str_compare_helper(
466 self,
467 rhs,
468 UInt32Chunked::lt,
469 Utf8ViewArray::tot_lt_kernel_broadcast,
470 )
471 }
472
473 fn lt_eq(&self, rhs: &str) -> Self::Item {
474 cat_single_str_compare_helper(
475 self,
476 rhs,
477 UInt32Chunked::lt_eq,
478 Utf8ViewArray::tot_le_kernel_broadcast,
479 )
480 }
481}