// polars_core/chunked_array/comparison/categorical.rs

1use crate::prelude::arity::unary_mut_values;
2use crate::prelude::*;
3
4fn str_to_cat_enum(map: &CategoricalMapping, s: &str) -> PolarsResult<CatSize> {
5    map.get_cat(s).ok_or_else(|| polars_err!(InvalidOperation: "conversion from `str` to `enum` failed for value \"{s}\""))
6}
7
8fn cat_equality_helper<T: PolarsCategoricalType, EqPhys>(
9    lhs: &CategoricalChunked<T>,
10    rhs: &CategoricalChunked<T>,
11    eq_phys: EqPhys,
12) -> PolarsResult<BooleanChunked>
13where
14    EqPhys:
15        Fn(&ChunkedArray<T::PolarsPhysical>, &ChunkedArray<T::PolarsPhysical>) -> BooleanChunked,
16{
17    lhs.dtype().matches_schema_type(rhs.dtype())?;
18    Ok(eq_phys(lhs.physical(), rhs.physical()))
19}
20
21fn cat_compare_helper<T: PolarsCategoricalType, Cmp, CmpPhys>(
22    lhs: &CategoricalChunked<T>,
23    rhs: &CategoricalChunked<T>,
24    cmp: Cmp,
25    cmp_phys: CmpPhys,
26) -> PolarsResult<BooleanChunked>
27where
28    Cmp: Fn(&str, &str) -> bool,
29    CmpPhys:
30        Fn(&ChunkedArray<T::PolarsPhysical>, &ChunkedArray<T::PolarsPhysical>) -> BooleanChunked,
31{
32    lhs.dtype().matches_schema_type(rhs.dtype())?;
33    if lhs.is_enum() {
34        return Ok(cmp_phys(lhs.physical(), rhs.physical()));
35    }
36    let mapping = lhs.get_mapping();
37    match (lhs.len(), rhs.len()) {
38        (lhs_len, 1) => {
39            let Some(cat) = rhs.physical().get(0) else {
40                return Ok(BooleanChunked::full_null(lhs.name().clone(), lhs_len));
41            };
42
43            // SAFETY: physical is in range of the mapping.
44            let v = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) };
45            Ok(lhs
46                .iter_str()
47                .map(|opt_s| opt_s.map(|s| cmp(s, v)))
48                .collect_ca_trusted(lhs.name().clone()))
49        },
50        (1, rhs_len) => {
51            let Some(cat) = lhs.physical().get(0) else {
52                return Ok(BooleanChunked::full_null(lhs.name().clone(), rhs_len));
53            };
54
55            // SAFETY: physical is in range of the mapping.
56            let v = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) };
57            Ok(rhs
58                .iter_str()
59                .map(|opt_s| opt_s.map(|s| cmp(v, s)))
60                .collect_ca_trusted(lhs.name().clone()))
61        },
62        (lhs_len, rhs_len) => {
63            assert!(lhs_len == rhs_len);
64            Ok(lhs
65                .iter_str()
66                .zip(rhs.iter_str())
67                .map(|(l, r)| match (l, r) {
68                    (None, _) => None,
69                    (_, None) => None,
70                    (Some(l), Some(r)) => Some(cmp(l, r)),
71                })
72                .collect_ca_trusted(lhs.name().clone()))
73        },
74    }
75}
76
77fn cat_str_equality_helper<T: PolarsCategoricalType, Eq, EqPhysScalar, EqStrScalar>(
78    lhs: &CategoricalChunked<T>,
79    rhs: &StringChunked,
80    eq: Eq,
81    eq_phys_scalar: EqPhysScalar,
82    eq_str_scalar: EqStrScalar,
83) -> BooleanChunked
84where
85    Eq: Fn(Option<&str>, Option<&str>) -> Option<bool>,
86    EqPhysScalar: Fn(&ChunkedArray<T::PolarsPhysical>, T::Native) -> BooleanChunked,
87    EqStrScalar: Fn(&StringChunked, &str) -> BooleanChunked,
88{
89    let mapping = lhs.get_mapping();
90    let null_eq = eq(None, None);
91    match (lhs.len(), rhs.len()) {
92        (lhs_len, 1) => {
93            let Some(s) = rhs.get(0) else {
94                return match null_eq {
95                    Some(true) => lhs.physical().is_null(),
96                    Some(false) => lhs.physical().is_not_null(),
97                    None => BooleanChunked::full_null(lhs.name().clone(), lhs_len),
98                };
99            };
100
101            let is_eq = eq(Some(""), Some("")).unwrap();
102            cat_str_scalar_equality_helper(lhs, s, is_eq, null_eq.is_some(), &eq_phys_scalar)
103        },
104        (1, rhs_len) => {
105            let Some(cat) = lhs.physical().get(0) else {
106                return match null_eq {
107                    Some(true) => rhs.is_null().with_name(lhs.name().clone()),
108                    Some(false) => rhs.is_not_null().with_name(lhs.name().clone()),
109                    None => BooleanChunked::full_null(lhs.name().clone(), rhs_len),
110                };
111            };
112
113            // SAFETY: physical is in range of the mapping.
114            let s = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) };
115            eq_str_scalar(rhs, s).with_name(lhs.name().clone())
116        },
117        (lhs_len, rhs_len) => {
118            assert!(lhs_len == rhs_len);
119            lhs.iter_str()
120                .zip(rhs.iter())
121                .map(|(l, r)| eq(l, r))
122                .collect_ca_trusted(lhs.name().clone())
123        },
124    }
125}
126
127fn cat_str_compare_helper<T: PolarsCategoricalType, Cmp, CmpStrScalar>(
128    lhs: &CategoricalChunked<T>,
129    rhs: &StringChunked,
130    cmp: Cmp,
131    cmp_str_scalar: CmpStrScalar,
132) -> BooleanChunked
133where
134    Cmp: Fn(&str, &str) -> bool,
135    CmpStrScalar: Fn(&str, &StringChunked) -> BooleanChunked,
136{
137    let mapping = lhs.get_mapping();
138    match (lhs.len(), rhs.len()) {
139        (lhs_len, 1) => {
140            let Some(s) = rhs.get(0) else {
141                return BooleanChunked::full_null(lhs.name().clone(), lhs_len);
142            };
143            cat_str_scalar_compare_helper(lhs, s, cmp)
144        },
145        (1, rhs_len) => {
146            let Some(cat) = lhs.physical().get(0) else {
147                return BooleanChunked::full_null(lhs.name().clone(), rhs_len);
148            };
149
150            // SAFETY: physical is in range of the mapping.
151            let s = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) };
152            cmp_str_scalar(s, rhs).with_name(lhs.name().clone())
153        },
154        (lhs_len, rhs_len) => {
155            assert!(lhs_len == rhs_len);
156            lhs.iter_str()
157                .zip(rhs.iter())
158                .map(|(l, r)| match (l, r) {
159                    (None, _) => None,
160                    (_, None) => None,
161                    (Some(l), Some(r)) => Some(cmp(l, r)),
162                })
163                .collect_ca_trusted(lhs.name().clone())
164        },
165    }
166}
167
168fn cat_str_phys_compare_helper<T: PolarsCategoricalType, Cmp>(
169    lhs: &CategoricalChunked<T>,
170    rhs: &StringChunked,
171    cmp: Cmp,
172) -> PolarsResult<BooleanChunked>
173where
174    Cmp: Fn(T::Native, T::Native) -> bool,
175{
176    let mapping = lhs.get_mapping();
177    match (lhs.len(), rhs.len()) {
178        (lhs_len, 1) => {
179            let Some(s) = rhs.get(0) else {
180                return Ok(BooleanChunked::full_null(lhs.name().clone(), lhs_len));
181            };
182            cat_str_scalar_phys_compare_helper(lhs, s, cmp)
183        },
184        (1, rhs_len) => {
185            let Some(cat) = lhs.physical().get(0) else {
186                return Ok(BooleanChunked::full_null(lhs.name().clone(), rhs_len));
187            };
188
189            rhs.iter()
190                .map(|opt_r| {
191                    if let Some(r) = opt_r {
192                        let r = T::Native::from_cat(str_to_cat_enum(mapping, r)?);
193                        Ok(Some(cmp(cat, r)))
194                    } else {
195                        Ok(None)
196                    }
197                })
198                .try_collect_ca_trusted(lhs.name().clone())
199        },
200        (lhs_len, rhs_len) => {
201            assert!(lhs_len == rhs_len);
202            lhs.physical()
203                .iter()
204                .zip(rhs.iter())
205                .map(|(l, r)| match (l, r) {
206                    (None, _) => Ok(None),
207                    (_, None) => Ok(None),
208                    (Some(l), Some(r)) => {
209                        let r = T::Native::from_cat(str_to_cat_enum(mapping, r)?);
210                        Ok(Some(cmp(l, r)))
211                    },
212                })
213                .try_collect_ca_trusted(lhs.name().clone())
214        },
215    }
216}
217
218fn cat_str_scalar_equality_helper<T: PolarsCategoricalType, EqPhysScalar>(
219    lhs: &CategoricalChunked<T>,
220    rhs: &str,
221    is_eq: bool,
222    missing: bool,
223    eq_phys_scalar: EqPhysScalar,
224) -> BooleanChunked
225where
226    EqPhysScalar: Fn(&ChunkedArray<T::PolarsPhysical>, T::Native) -> BooleanChunked,
227{
228    let mapping = lhs.get_mapping();
229    let Some(cat) = mapping.get_cat(rhs) else {
230        return if missing {
231            if is_eq {
232                BooleanChunked::full(lhs.name().clone(), false, lhs.len())
233            } else {
234                BooleanChunked::full(lhs.name().clone(), true, lhs.len())
235            }
236        } else {
237            unary_mut_values(lhs.physical(), |arr| {
238                BooleanArray::full(arr.len(), !is_eq, ArrowDataType::Boolean)
239            })
240        };
241    };
242
243    eq_phys_scalar(lhs.physical(), T::Native::from_cat(cat))
244}
245
246fn cat_str_scalar_compare_helper<T: PolarsCategoricalType, Cmp>(
247    lhs: &CategoricalChunked<T>,
248    rhs: &str,
249    cmp: Cmp,
250) -> BooleanChunked
251where
252    Cmp: Fn(&str, &str) -> bool,
253{
254    lhs.iter_str()
255        .map(|opt_l| opt_l.map(|l| cmp(l, rhs)))
256        .collect_ca_trusted(lhs.name().clone())
257}
258
259fn cat_str_scalar_phys_compare_helper<T: PolarsCategoricalType, Cmp>(
260    lhs: &CategoricalChunked<T>,
261    rhs: &str,
262    cmp: Cmp,
263) -> PolarsResult<BooleanChunked>
264where
265    Cmp: Fn(T::Native, T::Native) -> bool,
266{
267    let r = T::Native::from_cat(str_to_cat_enum(lhs.get_mapping(), rhs)?);
268    Ok(lhs
269        .physical()
270        .iter()
271        .map(|opt_l| opt_l.map(|l| cmp(l, r)))
272        .collect_ca_trusted(lhs.name().clone()))
273}
274
275impl<T: PolarsCategoricalType> ChunkCompareEq<&CategoricalChunked<T>> for CategoricalChunked<T>
276where
277    ChunkedArray<T::PolarsPhysical>:
278        for<'a> ChunkCompareEq<&'a ChunkedArray<T::PolarsPhysical>, Item = BooleanChunked>,
279{
280    type Item = PolarsResult<BooleanChunked>;
281
282    fn equal(&self, rhs: &Self) -> Self::Item {
283        cat_equality_helper(self, rhs, |l, r| l.equal(r))
284    }
285
286    fn equal_missing(&self, rhs: &Self) -> Self::Item {
287        cat_equality_helper(self, rhs, |l, r| l.equal_missing(r))
288    }
289
290    fn not_equal(&self, rhs: &Self) -> Self::Item {
291        cat_equality_helper(self, rhs, |l, r| l.not_equal(r))
292    }
293
294    fn not_equal_missing(&self, rhs: &Self) -> Self::Item {
295        cat_equality_helper(self, rhs, |l, r| l.not_equal_missing(r))
296    }
297}
298
299impl<T: PolarsCategoricalType> ChunkCompareIneq<&CategoricalChunked<T>> for CategoricalChunked<T>
300where
301    ChunkedArray<T::PolarsPhysical>:
302        for<'a> ChunkCompareIneq<&'a ChunkedArray<T::PolarsPhysical>, Item = BooleanChunked>,
303{
304    type Item = PolarsResult<BooleanChunked>;
305
306    fn gt(&self, rhs: &CategoricalChunked<T>) -> Self::Item {
307        cat_compare_helper(self, rhs, |l, r| l > r, |l, r| l.gt(r))
308    }
309
310    fn gt_eq(&self, rhs: &CategoricalChunked<T>) -> Self::Item {
311        cat_compare_helper(self, rhs, |l, r| l >= r, |l, r| l.gt_eq(r))
312    }
313
314    fn lt(&self, rhs: &CategoricalChunked<T>) -> Self::Item {
315        cat_compare_helper(self, rhs, |l, r| l < r, |l, r| l.lt(r))
316    }
317
318    fn lt_eq(&self, rhs: &CategoricalChunked<T>) -> Self::Item {
319        cat_compare_helper(self, rhs, |l, r| l <= r, |l, r| l.lt_eq(r))
320    }
321}
322
323impl<T: PolarsCategoricalType> ChunkCompareEq<&StringChunked> for CategoricalChunked<T>
324where
325    ChunkedArray<T::PolarsPhysical>: for<'a> ChunkCompareEq<T::Native, Item = BooleanChunked>,
326{
327    type Item = BooleanChunked;
328
329    fn equal(&self, rhs: &StringChunked) -> Self::Item {
330        cat_str_equality_helper(
331            self,
332            rhs,
333            |l, r| l.zip(r).map(|(l, r)| l == r),
334            |l, c| l.equal(c),
335            |r, c| r.equal(c),
336        )
337    }
338
339    fn equal_missing(&self, rhs: &StringChunked) -> Self::Item {
340        cat_str_equality_helper(
341            self,
342            rhs,
343            |l, r| Some(l == r),
344            |l, c| l.equal_missing(c),
345            |r, c| r.equal_missing(c),
346        )
347    }
348
349    fn not_equal(&self, rhs: &StringChunked) -> Self::Item {
350        cat_str_equality_helper(
351            self,
352            rhs,
353            |l, r| l.zip(r).map(|(l, r)| l != r),
354            |l, c| l.not_equal(c),
355            |r, c| r.not_equal(c),
356        )
357    }
358
359    fn not_equal_missing(&self, rhs: &StringChunked) -> Self::Item {
360        cat_str_equality_helper(
361            self,
362            rhs,
363            |l, r| Some(l != r),
364            |l, c| l.not_equal_missing(c),
365            |r, c| r.not_equal_missing(c),
366        )
367    }
368}
369
370impl<T: PolarsCategoricalType> ChunkCompareIneq<&StringChunked> for CategoricalChunked<T> {
371    type Item = PolarsResult<BooleanChunked>;
372
373    fn gt(&self, rhs: &StringChunked) -> Self::Item {
374        if self.is_enum() {
375            cat_str_phys_compare_helper(self, rhs, |l, r| l > r)
376        } else {
377            Ok(cat_str_compare_helper(
378                self,
379                rhs,
380                |l, r| l > r,
381                |c, r| r.lt(c),
382            ))
383        }
384    }
385
386    fn gt_eq(&self, rhs: &StringChunked) -> Self::Item {
387        if self.is_enum() {
388            cat_str_phys_compare_helper(self, rhs, |l, r| l >= r)
389        } else {
390            Ok(cat_str_compare_helper(
391                self,
392                rhs,
393                |l, r| l >= r,
394                |c, r| r.lt_eq(c),
395            ))
396        }
397    }
398
399    fn lt(&self, rhs: &StringChunked) -> Self::Item {
400        if self.is_enum() {
401            cat_str_phys_compare_helper(self, rhs, |l, r| l < r)
402        } else {
403            Ok(cat_str_compare_helper(
404                self,
405                rhs,
406                |l, r| l < r,
407                |c, r| r.gt(c),
408            ))
409        }
410    }
411
412    fn lt_eq(&self, rhs: &StringChunked) -> Self::Item {
413        if self.is_enum() {
414            cat_str_phys_compare_helper(self, rhs, |l, r| l <= r)
415        } else {
416            Ok(cat_str_compare_helper(
417                self,
418                rhs,
419                |l, r| l <= r,
420                |c, r| r.gt_eq(c),
421            ))
422        }
423    }
424}
425
426impl<T: PolarsCategoricalType> ChunkCompareEq<&str> for CategoricalChunked<T>
427where
428    ChunkedArray<T::PolarsPhysical>: for<'a> ChunkCompareEq<T::Native, Item = BooleanChunked>,
429{
430    type Item = BooleanChunked;
431
432    fn equal(&self, rhs: &str) -> Self::Item {
433        cat_str_scalar_equality_helper(self, rhs, true, false, |l, c| l.equal(c))
434    }
435
436    fn equal_missing(&self, rhs: &str) -> Self::Item {
437        cat_str_scalar_equality_helper(self, rhs, true, true, |l, c| l.equal_missing(c))
438    }
439
440    fn not_equal(&self, rhs: &str) -> Self::Item {
441        cat_str_scalar_equality_helper(self, rhs, false, false, |r, c| r.not_equal(c))
442    }
443
444    fn not_equal_missing(&self, rhs: &str) -> Self::Item {
445        cat_str_scalar_equality_helper(self, rhs, false, true, |l, c| l.not_equal_missing(c))
446    }
447}
448
449impl<T: PolarsCategoricalType> ChunkCompareIneq<&str> for CategoricalChunked<T> {
450    type Item = BooleanChunked;
451
452    fn gt(&self, rhs: &str) -> Self::Item {
453        cat_str_scalar_compare_helper(self, rhs, |l, r| l > r)
454    }
455
456    fn gt_eq(&self, rhs: &str) -> Self::Item {
457        cat_str_scalar_compare_helper(self, rhs, |l, r| l >= r)
458    }
459
460    fn lt(&self, rhs: &str) -> Self::Item {
461        cat_str_scalar_compare_helper(self, rhs, |l, r| l < r)
462    }
463
464    fn lt_eq(&self, rhs: &str) -> Self::Item {
465        cat_str_scalar_compare_helper(self, rhs, |l, r| l <= r)
466    }
467}