1use crate::prelude::arity::unary_mut_values;
2use crate::prelude::*;
3
4fn str_to_cat_enum(map: &CategoricalMapping, s: &str) -> PolarsResult<CatSize> {
5 map.get_cat(s).ok_or_else(|| polars_err!(InvalidOperation: "conversion from `str` to `enum` failed for value \"{s}\""))
6}
7
8fn cat_equality_helper<T: PolarsCategoricalType, EqPhys>(
9 lhs: &CategoricalChunked<T>,
10 rhs: &CategoricalChunked<T>,
11 eq_phys: EqPhys,
12) -> PolarsResult<BooleanChunked>
13where
14 EqPhys:
15 Fn(&ChunkedArray<T::PolarsPhysical>, &ChunkedArray<T::PolarsPhysical>) -> BooleanChunked,
16{
17 lhs.dtype().matches_schema_type(rhs.dtype())?;
18 Ok(eq_phys(lhs.physical(), rhs.physical()))
19}
20
21fn cat_compare_helper<T: PolarsCategoricalType, Cmp, CmpPhys>(
22 lhs: &CategoricalChunked<T>,
23 rhs: &CategoricalChunked<T>,
24 cmp: Cmp,
25 cmp_phys: CmpPhys,
26) -> PolarsResult<BooleanChunked>
27where
28 Cmp: Fn(&str, &str) -> bool,
29 CmpPhys:
30 Fn(&ChunkedArray<T::PolarsPhysical>, &ChunkedArray<T::PolarsPhysical>) -> BooleanChunked,
31{
32 lhs.dtype().matches_schema_type(rhs.dtype())?;
33 if lhs.is_enum() {
34 return Ok(cmp_phys(lhs.physical(), rhs.physical()));
35 }
36 let mapping = lhs.get_mapping();
37 match (lhs.len(), rhs.len()) {
38 (lhs_len, 1) => {
39 let Some(cat) = rhs.physical().get(0) else {
40 return Ok(BooleanChunked::full_null(lhs.name().clone(), lhs_len));
41 };
42
43 let v = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) };
45 Ok(lhs
46 .iter_str()
47 .map(|opt_s| opt_s.map(|s| cmp(s, v)))
48 .collect_ca_trusted(lhs.name().clone()))
49 },
50 (1, rhs_len) => {
51 let Some(cat) = lhs.physical().get(0) else {
52 return Ok(BooleanChunked::full_null(lhs.name().clone(), rhs_len));
53 };
54
55 let v = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) };
57 Ok(rhs
58 .iter_str()
59 .map(|opt_s| opt_s.map(|s| cmp(v, s)))
60 .collect_ca_trusted(lhs.name().clone()))
61 },
62 (lhs_len, rhs_len) => {
63 assert!(lhs_len == rhs_len);
64 Ok(lhs
65 .iter_str()
66 .zip(rhs.iter_str())
67 .map(|(l, r)| match (l, r) {
68 (None, _) => None,
69 (_, None) => None,
70 (Some(l), Some(r)) => Some(cmp(l, r)),
71 })
72 .collect_ca_trusted(lhs.name().clone()))
73 },
74 }
75}
76
77fn cat_str_equality_helper<T: PolarsCategoricalType, Eq, EqPhysScalar, EqStrScalar>(
78 lhs: &CategoricalChunked<T>,
79 rhs: &StringChunked,
80 eq: Eq,
81 eq_phys_scalar: EqPhysScalar,
82 eq_str_scalar: EqStrScalar,
83) -> BooleanChunked
84where
85 Eq: Fn(Option<&str>, Option<&str>) -> Option<bool>,
86 EqPhysScalar: Fn(&ChunkedArray<T::PolarsPhysical>, T::Native) -> BooleanChunked,
87 EqStrScalar: Fn(&StringChunked, &str) -> BooleanChunked,
88{
89 let mapping = lhs.get_mapping();
90 let null_eq = eq(None, None);
91 match (lhs.len(), rhs.len()) {
92 (lhs_len, 1) => {
93 let Some(s) = rhs.get(0) else {
94 return match null_eq {
95 Some(true) => lhs.physical().is_null(),
96 Some(false) => lhs.physical().is_not_null(),
97 None => BooleanChunked::full_null(lhs.name().clone(), lhs_len),
98 };
99 };
100
101 let is_eq = eq(Some(""), Some("")).unwrap();
102 cat_str_scalar_equality_helper(lhs, s, is_eq, null_eq.is_some(), &eq_phys_scalar)
103 },
104 (1, rhs_len) => {
105 let Some(cat) = lhs.physical().get(0) else {
106 return match null_eq {
107 Some(true) => rhs.is_null().with_name(lhs.name().clone()),
108 Some(false) => rhs.is_not_null().with_name(lhs.name().clone()),
109 None => BooleanChunked::full_null(lhs.name().clone(), rhs_len),
110 };
111 };
112
113 let s = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) };
115 eq_str_scalar(rhs, s).with_name(lhs.name().clone())
116 },
117 (lhs_len, rhs_len) => {
118 assert!(lhs_len == rhs_len);
119 lhs.iter_str()
120 .zip(rhs.iter())
121 .map(|(l, r)| eq(l, r))
122 .collect_ca_trusted(lhs.name().clone())
123 },
124 }
125}
126
127fn cat_str_compare_helper<T: PolarsCategoricalType, Cmp, CmpStrScalar>(
128 lhs: &CategoricalChunked<T>,
129 rhs: &StringChunked,
130 cmp: Cmp,
131 cmp_str_scalar: CmpStrScalar,
132) -> BooleanChunked
133where
134 Cmp: Fn(&str, &str) -> bool,
135 CmpStrScalar: Fn(&str, &StringChunked) -> BooleanChunked,
136{
137 let mapping = lhs.get_mapping();
138 match (lhs.len(), rhs.len()) {
139 (lhs_len, 1) => {
140 let Some(s) = rhs.get(0) else {
141 return BooleanChunked::full_null(lhs.name().clone(), lhs_len);
142 };
143 cat_str_scalar_compare_helper(lhs, s, cmp)
144 },
145 (1, rhs_len) => {
146 let Some(cat) = lhs.physical().get(0) else {
147 return BooleanChunked::full_null(lhs.name().clone(), rhs_len);
148 };
149
150 let s = unsafe { mapping.cat_to_str_unchecked(cat.as_cat()) };
152 cmp_str_scalar(s, rhs).with_name(lhs.name().clone())
153 },
154 (lhs_len, rhs_len) => {
155 assert!(lhs_len == rhs_len);
156 lhs.iter_str()
157 .zip(rhs.iter())
158 .map(|(l, r)| match (l, r) {
159 (None, _) => None,
160 (_, None) => None,
161 (Some(l), Some(r)) => Some(cmp(l, r)),
162 })
163 .collect_ca_trusted(lhs.name().clone())
164 },
165 }
166}
167
168fn cat_str_phys_compare_helper<T: PolarsCategoricalType, Cmp>(
169 lhs: &CategoricalChunked<T>,
170 rhs: &StringChunked,
171 cmp: Cmp,
172) -> PolarsResult<BooleanChunked>
173where
174 Cmp: Fn(T::Native, T::Native) -> bool,
175{
176 let mapping = lhs.get_mapping();
177 match (lhs.len(), rhs.len()) {
178 (lhs_len, 1) => {
179 let Some(s) = rhs.get(0) else {
180 return Ok(BooleanChunked::full_null(lhs.name().clone(), lhs_len));
181 };
182 cat_str_scalar_phys_compare_helper(lhs, s, cmp)
183 },
184 (1, rhs_len) => {
185 let Some(cat) = lhs.physical().get(0) else {
186 return Ok(BooleanChunked::full_null(lhs.name().clone(), rhs_len));
187 };
188
189 rhs.iter()
190 .map(|opt_r| {
191 if let Some(r) = opt_r {
192 let r = T::Native::from_cat(str_to_cat_enum(mapping, r)?);
193 Ok(Some(cmp(cat, r)))
194 } else {
195 Ok(None)
196 }
197 })
198 .try_collect_ca_trusted(lhs.name().clone())
199 },
200 (lhs_len, rhs_len) => {
201 assert!(lhs_len == rhs_len);
202 lhs.physical()
203 .iter()
204 .zip(rhs.iter())
205 .map(|(l, r)| match (l, r) {
206 (None, _) => Ok(None),
207 (_, None) => Ok(None),
208 (Some(l), Some(r)) => {
209 let r = T::Native::from_cat(str_to_cat_enum(mapping, r)?);
210 Ok(Some(cmp(l, r)))
211 },
212 })
213 .try_collect_ca_trusted(lhs.name().clone())
214 },
215 }
216}
217
218fn cat_str_scalar_equality_helper<T: PolarsCategoricalType, EqPhysScalar>(
219 lhs: &CategoricalChunked<T>,
220 rhs: &str,
221 is_eq: bool,
222 missing: bool,
223 eq_phys_scalar: EqPhysScalar,
224) -> BooleanChunked
225where
226 EqPhysScalar: Fn(&ChunkedArray<T::PolarsPhysical>, T::Native) -> BooleanChunked,
227{
228 let mapping = lhs.get_mapping();
229 let Some(cat) = mapping.get_cat(rhs) else {
230 return if missing {
231 if is_eq {
232 BooleanChunked::full(lhs.name().clone(), false, lhs.len())
233 } else {
234 BooleanChunked::full(lhs.name().clone(), true, lhs.len())
235 }
236 } else {
237 unary_mut_values(lhs.physical(), |arr| {
238 BooleanArray::full(arr.len(), !is_eq, ArrowDataType::Boolean)
239 })
240 };
241 };
242
243 eq_phys_scalar(lhs.physical(), T::Native::from_cat(cat))
244}
245
246fn cat_str_scalar_compare_helper<T: PolarsCategoricalType, Cmp>(
247 lhs: &CategoricalChunked<T>,
248 rhs: &str,
249 cmp: Cmp,
250) -> BooleanChunked
251where
252 Cmp: Fn(&str, &str) -> bool,
253{
254 lhs.iter_str()
255 .map(|opt_l| opt_l.map(|l| cmp(l, rhs)))
256 .collect_ca_trusted(lhs.name().clone())
257}
258
259fn cat_str_scalar_phys_compare_helper<T: PolarsCategoricalType, Cmp>(
260 lhs: &CategoricalChunked<T>,
261 rhs: &str,
262 cmp: Cmp,
263) -> PolarsResult<BooleanChunked>
264where
265 Cmp: Fn(T::Native, T::Native) -> bool,
266{
267 let r = T::Native::from_cat(str_to_cat_enum(lhs.get_mapping(), rhs)?);
268 Ok(lhs
269 .physical()
270 .iter()
271 .map(|opt_l| opt_l.map(|l| cmp(l, r)))
272 .collect_ca_trusted(lhs.name().clone()))
273}
274
275impl<T: PolarsCategoricalType> ChunkCompareEq<&CategoricalChunked<T>> for CategoricalChunked<T>
276where
277 ChunkedArray<T::PolarsPhysical>:
278 for<'a> ChunkCompareEq<&'a ChunkedArray<T::PolarsPhysical>, Item = BooleanChunked>,
279{
280 type Item = PolarsResult<BooleanChunked>;
281
282 fn equal(&self, rhs: &Self) -> Self::Item {
283 cat_equality_helper(self, rhs, |l, r| l.equal(r))
284 }
285
286 fn equal_missing(&self, rhs: &Self) -> Self::Item {
287 cat_equality_helper(self, rhs, |l, r| l.equal_missing(r))
288 }
289
290 fn not_equal(&self, rhs: &Self) -> Self::Item {
291 cat_equality_helper(self, rhs, |l, r| l.not_equal(r))
292 }
293
294 fn not_equal_missing(&self, rhs: &Self) -> Self::Item {
295 cat_equality_helper(self, rhs, |l, r| l.not_equal_missing(r))
296 }
297}
298
299impl<T: PolarsCategoricalType> ChunkCompareIneq<&CategoricalChunked<T>> for CategoricalChunked<T>
300where
301 ChunkedArray<T::PolarsPhysical>:
302 for<'a> ChunkCompareIneq<&'a ChunkedArray<T::PolarsPhysical>, Item = BooleanChunked>,
303{
304 type Item = PolarsResult<BooleanChunked>;
305
306 fn gt(&self, rhs: &CategoricalChunked<T>) -> Self::Item {
307 cat_compare_helper(self, rhs, |l, r| l > r, |l, r| l.gt(r))
308 }
309
310 fn gt_eq(&self, rhs: &CategoricalChunked<T>) -> Self::Item {
311 cat_compare_helper(self, rhs, |l, r| l >= r, |l, r| l.gt_eq(r))
312 }
313
314 fn lt(&self, rhs: &CategoricalChunked<T>) -> Self::Item {
315 cat_compare_helper(self, rhs, |l, r| l < r, |l, r| l.lt(r))
316 }
317
318 fn lt_eq(&self, rhs: &CategoricalChunked<T>) -> Self::Item {
319 cat_compare_helper(self, rhs, |l, r| l <= r, |l, r| l.lt_eq(r))
320 }
321}
322
323impl<T: PolarsCategoricalType> ChunkCompareEq<&StringChunked> for CategoricalChunked<T>
324where
325 ChunkedArray<T::PolarsPhysical>: for<'a> ChunkCompareEq<T::Native, Item = BooleanChunked>,
326{
327 type Item = BooleanChunked;
328
329 fn equal(&self, rhs: &StringChunked) -> Self::Item {
330 cat_str_equality_helper(
331 self,
332 rhs,
333 |l, r| l.zip(r).map(|(l, r)| l == r),
334 |l, c| l.equal(c),
335 |r, c| r.equal(c),
336 )
337 }
338
339 fn equal_missing(&self, rhs: &StringChunked) -> Self::Item {
340 cat_str_equality_helper(
341 self,
342 rhs,
343 |l, r| Some(l == r),
344 |l, c| l.equal_missing(c),
345 |r, c| r.equal_missing(c),
346 )
347 }
348
349 fn not_equal(&self, rhs: &StringChunked) -> Self::Item {
350 cat_str_equality_helper(
351 self,
352 rhs,
353 |l, r| l.zip(r).map(|(l, r)| l != r),
354 |l, c| l.not_equal(c),
355 |r, c| r.not_equal(c),
356 )
357 }
358
359 fn not_equal_missing(&self, rhs: &StringChunked) -> Self::Item {
360 cat_str_equality_helper(
361 self,
362 rhs,
363 |l, r| Some(l != r),
364 |l, c| l.not_equal_missing(c),
365 |r, c| r.not_equal_missing(c),
366 )
367 }
368}
369
370impl<T: PolarsCategoricalType> ChunkCompareIneq<&StringChunked> for CategoricalChunked<T> {
371 type Item = PolarsResult<BooleanChunked>;
372
373 fn gt(&self, rhs: &StringChunked) -> Self::Item {
374 if self.is_enum() {
375 cat_str_phys_compare_helper(self, rhs, |l, r| l > r)
376 } else {
377 Ok(cat_str_compare_helper(
378 self,
379 rhs,
380 |l, r| l > r,
381 |c, r| r.lt(c),
382 ))
383 }
384 }
385
386 fn gt_eq(&self, rhs: &StringChunked) -> Self::Item {
387 if self.is_enum() {
388 cat_str_phys_compare_helper(self, rhs, |l, r| l >= r)
389 } else {
390 Ok(cat_str_compare_helper(
391 self,
392 rhs,
393 |l, r| l >= r,
394 |c, r| r.lt_eq(c),
395 ))
396 }
397 }
398
399 fn lt(&self, rhs: &StringChunked) -> Self::Item {
400 if self.is_enum() {
401 cat_str_phys_compare_helper(self, rhs, |l, r| l < r)
402 } else {
403 Ok(cat_str_compare_helper(
404 self,
405 rhs,
406 |l, r| l < r,
407 |c, r| r.gt(c),
408 ))
409 }
410 }
411
412 fn lt_eq(&self, rhs: &StringChunked) -> Self::Item {
413 if self.is_enum() {
414 cat_str_phys_compare_helper(self, rhs, |l, r| l <= r)
415 } else {
416 Ok(cat_str_compare_helper(
417 self,
418 rhs,
419 |l, r| l <= r,
420 |c, r| r.gt_eq(c),
421 ))
422 }
423 }
424}
425
426impl<T: PolarsCategoricalType> ChunkCompareEq<&str> for CategoricalChunked<T>
427where
428 ChunkedArray<T::PolarsPhysical>: for<'a> ChunkCompareEq<T::Native, Item = BooleanChunked>,
429{
430 type Item = BooleanChunked;
431
432 fn equal(&self, rhs: &str) -> Self::Item {
433 cat_str_scalar_equality_helper(self, rhs, true, false, |l, c| l.equal(c))
434 }
435
436 fn equal_missing(&self, rhs: &str) -> Self::Item {
437 cat_str_scalar_equality_helper(self, rhs, true, true, |l, c| l.equal_missing(c))
438 }
439
440 fn not_equal(&self, rhs: &str) -> Self::Item {
441 cat_str_scalar_equality_helper(self, rhs, false, false, |r, c| r.not_equal(c))
442 }
443
444 fn not_equal_missing(&self, rhs: &str) -> Self::Item {
445 cat_str_scalar_equality_helper(self, rhs, false, true, |l, c| l.not_equal_missing(c))
446 }
447}
448
449impl<T: PolarsCategoricalType> ChunkCompareIneq<&str> for CategoricalChunked<T> {
450 type Item = BooleanChunked;
451
452 fn gt(&self, rhs: &str) -> Self::Item {
453 cat_str_scalar_compare_helper(self, rhs, |l, r| l > r)
454 }
455
456 fn gt_eq(&self, rhs: &str) -> Self::Item {
457 cat_str_scalar_compare_helper(self, rhs, |l, r| l >= r)
458 }
459
460 fn lt(&self, rhs: &str) -> Self::Item {
461 cat_str_scalar_compare_helper(self, rhs, |l, r| l < r)
462 }
463
464 fn lt_eq(&self, rhs: &str) -> Self::Item {
465 cat_str_scalar_compare_helper(self, rhs, |l, r| l <= r)
466 }
467}