// polars_core/chunked_array/comparison/categorical.rs

1use arrow::bitmap::Bitmap;
2use arrow::legacy::utils::FromTrustedLenIterator;
3use polars_compute::comparisons::TotalOrdKernel;
4
5use crate::chunked_array::cast::CastOptions;
6use crate::prelude::nulls::replace_non_null;
7use crate::prelude::*;
8
/// Shared driver for the `==` / `!=` family between two categorical arrays.
///
/// Both sides must come from the same rev-map source, otherwise a string-cache
/// mismatch error is returned. When `rhs` is a single non-null value whose code
/// cannot be resolved through `lhs`'s rev map (`get_optional` returns `None`),
/// the result is produced by `missing_function(lhs)` instead of a physical
/// comparison. In every other case the physical `u32` codes are compared with
/// `compare_function`.
#[cfg(feature = "dtype-categorical")]
fn cat_equality_helper<'a, Compare, Missing>(
    lhs: &'a CategoricalChunked,
    rhs: &'a CategoricalChunked,
    missing_function: Missing,
    compare_function: Compare,
) -> PolarsResult<BooleanChunked>
where
    Compare: Fn(&'a UInt32Chunked, &'a UInt32Chunked) -> BooleanChunked,
    Missing: Fn(&'a CategoricalChunked) -> BooleanChunked,
{
    let lhs_rev_map = lhs.get_rev_map();
    polars_ensure!(lhs_rev_map.same_src(rhs.get_rev_map()), string_cache_mismatch);

    let lhs_codes = lhs.physical();
    let rhs_codes = rhs.physical();

    // Fast path for globals: a single non-null rhs code that the lhs rev map
    // cannot resolve is handled by `missing_function` without touching lhs codes.
    if rhs_codes.len() == 1 {
        if let Some(code) = rhs_codes.get(0) {
            if lhs_rev_map.get_optional(code).is_none() {
                return Ok(missing_function(lhs));
            }
        }
    }

    Ok(compare_function(lhs_codes, rhs_codes))
}
33
34fn cat_compare_helper<'a, Compare, CompareString>(
35    lhs: &'a CategoricalChunked,
36    rhs: &'a CategoricalChunked,
37    compare_function: Compare,
38    compare_str_function: CompareString,
39) -> PolarsResult<BooleanChunked>
40where
41    Compare: Fn(&'a UInt32Chunked, &'a UInt32Chunked) -> BooleanChunked,
42    CompareString: Fn(&str, &str) -> bool,
43{
44    let rev_map_l = lhs.get_rev_map();
45    let rev_map_r = rhs.get_rev_map();
46    polars_ensure!(rev_map_l.same_src(rev_map_r), ComputeError: "can only compare categoricals of the same type with the same categories");
47
48    if lhs.is_enum() || !lhs.uses_lexical_ordering() {
49        Ok(compare_function(lhs.physical(), rhs.physical()))
50    } else {
51        match (lhs.len(), rhs.len()) {
52            (lhs_len, 1) => {
53                // SAFETY: physical is in range of revmap
54                let v = unsafe {
55                    rhs.physical()
56                        .get(0)
57                        .map(|phys| rev_map_r.get_unchecked(phys))
58                };
59                let Some(v) = v else {
60                    return Ok(BooleanChunked::full_null(lhs.name().clone(), lhs_len));
61                };
62
63                Ok(lhs
64                    .iter_str()
65                    .map(|opt_s| opt_s.map(|s| compare_str_function(s, v)))
66                    .collect_ca_trusted(lhs.name().clone()))
67            },
68            (1, rhs_len) => {
69                // SAFETY: physical is in range of revmap
70                let v = unsafe {
71                    lhs.physical()
72                        .get(0)
73                        .map(|phys| rev_map_l.get_unchecked(phys))
74                };
75                let Some(v) = v else {
76                    return Ok(BooleanChunked::full_null(lhs.name().clone(), rhs_len));
77                };
78                Ok(rhs
79                    .iter_str()
80                    .map(|opt_s| opt_s.map(|s| compare_str_function(v, s)))
81                    .collect_ca_trusted(lhs.name().clone()))
82            },
83            (lhs_len, rhs_len) if lhs_len == rhs_len => Ok(lhs
84                .iter_str()
85                .zip(rhs.iter_str())
86                .map(|(l, r)| match (l, r) {
87                    (None, _) => None,
88                    (_, None) => None,
89                    (Some(l), Some(r)) => Some(compare_str_function(l, r)),
90                })
91                .collect_ca_trusted(lhs.name().clone())),
92            (lhs_len, rhs_len) => {
93                polars_bail!(ComputeError: "Columns are of unequal length: {} vs {}",lhs_len,rhs_len)
94            },
95        }
96    }
97}
98
99impl ChunkCompareEq<&CategoricalChunked> for CategoricalChunked {
100    type Item = PolarsResult<BooleanChunked>;
101
102    fn equal(&self, rhs: &CategoricalChunked) -> Self::Item {
103        cat_equality_helper(
104            self,
105            rhs,
106            |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false),
107            UInt32Chunked::equal,
108        )
109    }
110
111    fn equal_missing(&self, rhs: &CategoricalChunked) -> Self::Item {
112        cat_equality_helper(
113            self,
114            rhs,
115            |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()),
116            UInt32Chunked::equal_missing,
117        )
118    }
119
120    fn not_equal(&self, rhs: &CategoricalChunked) -> Self::Item {
121        cat_equality_helper(
122            self,
123            rhs,
124            |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true),
125            UInt32Chunked::not_equal,
126        )
127    }
128
129    fn not_equal_missing(&self, rhs: &CategoricalChunked) -> Self::Item {
130        cat_equality_helper(
131            self,
132            rhs,
133            |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()),
134            UInt32Chunked::not_equal_missing,
135        )
136    }
137}
138
139impl ChunkCompareIneq<&CategoricalChunked> for CategoricalChunked {
140    type Item = PolarsResult<BooleanChunked>;
141
142    fn gt(&self, rhs: &CategoricalChunked) -> Self::Item {
143        cat_compare_helper(self, rhs, UInt32Chunked::gt, |l, r| l > r)
144    }
145
146    fn gt_eq(&self, rhs: &CategoricalChunked) -> Self::Item {
147        cat_compare_helper(self, rhs, UInt32Chunked::gt_eq, |l, r| l >= r)
148    }
149
150    fn lt(&self, rhs: &CategoricalChunked) -> Self::Item {
151        cat_compare_helper(self, rhs, UInt32Chunked::lt, |l, r| l < r)
152    }
153
154    fn lt_eq(&self, rhs: &CategoricalChunked) -> Self::Item {
155        cat_compare_helper(self, rhs, UInt32Chunked::lt_eq, |l, r| l <= r)
156    }
157}
158
159fn cat_str_equality_helper<'a, Missing, CompareNone, CompareCat, ComparePhys, CompareString>(
160    lhs: &'a CategoricalChunked,
161    rhs: &'a StringChunked,
162    missing_function: Missing,
163    compare_to_none: CompareNone,
164    cat_compare_function: CompareCat,
165    phys_compare_function: ComparePhys,
166    str_compare_function: CompareString,
167) -> PolarsResult<BooleanChunked>
168where
169    Missing: Fn(&CategoricalChunked) -> BooleanChunked,
170    CompareNone: Fn(&CategoricalChunked) -> BooleanChunked,
171    ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
172    CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult<BooleanChunked>,
173    CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked,
174{
175    if lhs.is_enum() {
176        let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?;
177        cat_compare_function(lhs, rhs_cat.categorical().unwrap())
178    } else if rhs.len() == 1 {
179        match rhs.get(0) {
180            None => Ok(compare_to_none(lhs)),
181            Some(s) => {
182                cat_single_str_equality_helper(lhs, s, missing_function, phys_compare_function)
183            },
184        }
185    } else {
186        let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?;
187        Ok(str_compare_function(lhs_string.str().unwrap(), rhs))
188    }
189}
190
191fn cat_str_compare_helper<'a, CompareCat, ComparePhys, CompareStringSingle, CompareString>(
192    lhs: &'a CategoricalChunked,
193    rhs: &'a StringChunked,
194    cat_compare_function: CompareCat,
195    phys_compare_function: ComparePhys,
196    str_single_compare_function: CompareStringSingle,
197    str_compare_function: CompareString,
198) -> PolarsResult<BooleanChunked>
199where
200    CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap,
201    ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
202    CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult<BooleanChunked>,
203    CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked,
204{
205    if lhs.is_enum() {
206        let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?;
207        cat_compare_function(lhs, rhs_cat.categorical().unwrap())
208    } else if rhs.len() == 1 {
209        match rhs.get(0) {
210            None => Ok(BooleanChunked::full_null(lhs.name().clone(), lhs.len())),
211            Some(s) => cat_single_str_compare_helper(
212                lhs,
213                s,
214                phys_compare_function,
215                str_single_compare_function,
216            ),
217        }
218    } else {
219        let lhs_string = lhs.cast_with_options(&DataType::String, CastOptions::NonStrict)?;
220        Ok(str_compare_function(lhs_string.str().unwrap(), rhs))
221    }
222}
223
224impl ChunkCompareEq<&StringChunked> for CategoricalChunked {
225    type Item = PolarsResult<BooleanChunked>;
226
227    fn equal(&self, rhs: &StringChunked) -> Self::Item {
228        cat_str_equality_helper(
229            self,
230            rhs,
231            |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false),
232            |lhs| BooleanChunked::full_null(lhs.name().clone(), lhs.len()),
233            |s1, s2| CategoricalChunked::equal(s1, s2),
234            UInt32Chunked::equal,
235            StringChunked::equal,
236        )
237    }
238    fn equal_missing(&self, rhs: &StringChunked) -> Self::Item {
239        cat_str_equality_helper(
240            self,
241            rhs,
242            |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()),
243            |lhs| lhs.physical().is_null(),
244            |s1, s2| CategoricalChunked::equal_missing(s1, s2),
245            UInt32Chunked::equal_missing,
246            StringChunked::equal_missing,
247        )
248    }
249
250    fn not_equal(&self, rhs: &StringChunked) -> Self::Item {
251        cat_str_equality_helper(
252            self,
253            rhs,
254            |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true),
255            |lhs| BooleanChunked::full_null(lhs.name().clone(), lhs.len()),
256            |s1, s2| CategoricalChunked::not_equal(s1, s2),
257            UInt32Chunked::not_equal,
258            StringChunked::not_equal,
259        )
260    }
261    fn not_equal_missing(&self, rhs: &StringChunked) -> Self::Item {
262        cat_str_equality_helper(
263            self,
264            rhs,
265            |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()),
266            |lhs| !lhs.physical().is_null(),
267            |s1, s2| CategoricalChunked::not_equal_missing(s1, s2),
268            UInt32Chunked::not_equal_missing,
269            StringChunked::not_equal_missing,
270        )
271    }
272}
273
274impl ChunkCompareIneq<&StringChunked> for CategoricalChunked {
275    type Item = PolarsResult<BooleanChunked>;
276
277    fn gt(&self, rhs: &StringChunked) -> Self::Item {
278        cat_str_compare_helper(
279            self,
280            rhs,
281            |s1, s2| CategoricalChunked::gt(s1, s2),
282            UInt32Chunked::gt,
283            Utf8ViewArray::tot_gt_kernel_broadcast,
284            StringChunked::gt,
285        )
286    }
287
288    fn gt_eq(&self, rhs: &StringChunked) -> Self::Item {
289        cat_str_compare_helper(
290            self,
291            rhs,
292            |s1, s2| CategoricalChunked::gt_eq(s1, s2),
293            UInt32Chunked::gt_eq,
294            Utf8ViewArray::tot_ge_kernel_broadcast,
295            StringChunked::gt_eq,
296        )
297    }
298
299    fn lt(&self, rhs: &StringChunked) -> Self::Item {
300        cat_str_compare_helper(
301            self,
302            rhs,
303            |s1, s2| CategoricalChunked::lt(s1, s2),
304            UInt32Chunked::lt,
305            Utf8ViewArray::tot_lt_kernel_broadcast,
306            StringChunked::lt,
307        )
308    }
309
310    fn lt_eq(&self, rhs: &StringChunked) -> Self::Item {
311        cat_str_compare_helper(
312            self,
313            rhs,
314            |s1, s2| CategoricalChunked::lt_eq(s1, s2),
315            UInt32Chunked::lt_eq,
316            Utf8ViewArray::tot_le_kernel_broadcast,
317            StringChunked::lt_eq,
318        )
319    }
320}
321
322fn cat_single_str_equality_helper<'a, ComparePhys, Missing>(
323    lhs: &'a CategoricalChunked,
324    rhs: &'a str,
325    missing_function: Missing,
326    phys_compare_function: ComparePhys,
327) -> PolarsResult<BooleanChunked>
328where
329    ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
330    Missing: Fn(&CategoricalChunked) -> BooleanChunked,
331{
332    let rev_map = lhs.get_rev_map();
333    let idx = rev_map.find(rhs);
334    if lhs.is_enum() {
335        let Some(idx) = idx else {
336            polars_bail!(
337                not_in_enum,
338                value = rhs,
339                categories = rev_map.get_categories()
340            )
341        };
342        Ok(phys_compare_function(lhs.physical(), idx))
343    } else {
344        match rev_map.find(rhs) {
345            None => Ok(missing_function(lhs)),
346            Some(idx) => Ok(phys_compare_function(lhs.physical(), idx)),
347        }
348    }
349}
350
351fn cat_single_str_compare_helper<'a, ComparePhys, CompareStringSingle>(
352    lhs: &'a CategoricalChunked,
353    rhs: &'a str,
354    phys_compare_function: ComparePhys,
355    str_single_compare_function: CompareStringSingle,
356) -> PolarsResult<BooleanChunked>
357where
358    CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap,
359    ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
360{
361    let rev_map = lhs.get_rev_map();
362    if lhs.is_enum() {
363        match rev_map.find(rhs) {
364            None => {
365                polars_bail!(
366                    not_in_enum,
367                    value = rhs,
368                    categories = rev_map.get_categories()
369                )
370            },
371            Some(idx) => Ok(phys_compare_function(lhs.physical(), idx)),
372        }
373    } else {
374        // Apply comparison on categories map and then do a lookup
375        let bitmap = str_single_compare_function(lhs.get_rev_map().get_categories(), rhs);
376
377        let mask = match lhs.get_rev_map().as_ref() {
378            RevMapping::Local(_, _) => {
379                BooleanChunked::from_iter_trusted_length(lhs.physical().into_iter().map(
380                    |opt_idx| {
381                        // SAFETY: indexing into bitmap with same length as original array
382                        opt_idx.map(|idx| unsafe { bitmap.get_bit_unchecked(idx as usize) })
383                    },
384                ))
385            },
386            RevMapping::Global(idx_map, _, _) => {
387                BooleanChunked::from_iter_trusted_length(lhs.physical().into_iter().map(
388                    |opt_idx| {
389                        // SAFETY: indexing into bitmap with same length as original array
390                        opt_idx.map(|idx| unsafe {
391                            let idx = *idx_map.get(&idx).unwrap();
392                            bitmap.get_bit_unchecked(idx as usize)
393                        })
394                    },
395                ))
396            },
397        };
398
399        Ok(mask.with_name(lhs.name().clone()))
400    }
401}
402
403impl ChunkCompareEq<&str> for CategoricalChunked {
404    type Item = PolarsResult<BooleanChunked>;
405
406    fn equal(&self, rhs: &str) -> Self::Item {
407        cat_single_str_equality_helper(
408            self,
409            rhs,
410            |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false),
411            UInt32Chunked::equal,
412        )
413    }
414
415    fn equal_missing(&self, rhs: &str) -> Self::Item {
416        cat_single_str_equality_helper(
417            self,
418            rhs,
419            |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()),
420            UInt32Chunked::equal_missing,
421        )
422    }
423
424    fn not_equal(&self, rhs: &str) -> Self::Item {
425        cat_single_str_equality_helper(
426            self,
427            rhs,
428            |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true),
429            UInt32Chunked::not_equal,
430        )
431    }
432
433    fn not_equal_missing(&self, rhs: &str) -> Self::Item {
434        cat_single_str_equality_helper(
435            self,
436            rhs,
437            |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()),
438            UInt32Chunked::equal_missing,
439        )
440    }
441}
442
443impl ChunkCompareIneq<&str> for CategoricalChunked {
444    type Item = PolarsResult<BooleanChunked>;
445
446    fn gt(&self, rhs: &str) -> Self::Item {
447        cat_single_str_compare_helper(
448            self,
449            rhs,
450            UInt32Chunked::gt,
451            Utf8ViewArray::tot_gt_kernel_broadcast,
452        )
453    }
454
455    fn gt_eq(&self, rhs: &str) -> Self::Item {
456        cat_single_str_compare_helper(
457            self,
458            rhs,
459            UInt32Chunked::gt_eq,
460            Utf8ViewArray::tot_ge_kernel_broadcast,
461        )
462    }
463
464    fn lt(&self, rhs: &str) -> Self::Item {
465        cat_single_str_compare_helper(
466            self,
467            rhs,
468            UInt32Chunked::lt,
469            Utf8ViewArray::tot_lt_kernel_broadcast,
470        )
471    }
472
473    fn lt_eq(&self, rhs: &str) -> Self::Item {
474        cat_single_str_compare_helper(
475            self,
476            rhs,
477            UInt32Chunked::lt_eq,
478            Utf8ViewArray::tot_le_kernel_broadcast,
479        )
480    }
481}