polars_core/series/
comparison.rs

//! Comparison operations on Series.
2
3use polars_error::feature_gated;
4
5use crate::prelude::*;
6use crate::series::arithmetic::coerce_lhs_rhs;
7use crate::series::nulls::replace_non_null;
8
/// Implements an (in)equality comparison between two series.
///
/// Expands to a block expression that ends in `PolarsResult::Ok(..)`. Because the
/// expansion uses `?` and `return`, it must be invoked inside a function whose return
/// type is a compatible `PolarsResult`.
///
/// * `$self`, `$rhs` - the left and right operands.
/// * `$method` - the comparison method to dispatch to on the typed arrays
///   (the callers in this file pass `equal`, `equal_missing`, `not_equal` and
///   `not_equal_missing`).
///
/// The expansion fails with:
/// * whatever `validate_types` rejects,
/// * `ShapeMismatch` when the lengths differ and neither operand has length 1,
/// * `SchemaMismatch` when `coerce_lhs_rhs` fails,
/// * `InvalidOperation` when the coerced physical dtype has no comparison arm.
macro_rules! impl_eq_compare {
    ($self:expr, $rhs:expr, $method:ident) => {{
        use DataType::*;
        let (lhs, rhs) = ($self, $rhs);
        // Reject dtype pairs that can never be compared before doing any other work.
        validate_types(lhs.dtype(), rhs.dtype())?;

        polars_ensure!(
            lhs.len() == rhs.len() ||

            // Broadcast
            lhs.len() == 1 ||
            rhs.len() == 1,
            ShapeMismatch: "could not compare between two series of different length ({} != {})",
            lhs.len(),
            rhs.len()
        );

        // Dtype pairs that are compared directly, without the generic coercion below.
        match (lhs.dtype(), rhs.dtype()) {
            #[cfg(feature = "dtype-categorical")]
            (Categorical(lcats, _), Categorical(rcats, _)) => {
                ensure_same_categories(lcats, rcats)?;
                return with_match_categorical_physical_type!(lcats.physical(), |$C| {
                    lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap())
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (Enum(lfcats, _), Enum(rfcats, _)) => {
                ensure_same_frozen_categories(lfcats, rfcats)?;
                return with_match_categorical_physical_type!(lfcats.physical(), |$C| {
                    lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap())
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (Categorical(_, _) | Enum(_, _), String) => {
                return with_match_categorical_physical_type!(lhs.dtype().cat_physical().unwrap(), |$C| {
                    Ok(lhs.cat::<$C>().unwrap().$method(rhs.str().unwrap()))
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (String, Categorical(_, _) | Enum(_, _)) => {
                // The operands are swapped so the categorical side drives the comparison;
                // the same `$method` is applied with the string column as its argument.
                return with_match_categorical_physical_type!(rhs.dtype().cat_physical().unwrap(), |$C| {
                    Ok(rhs.cat::<$C>().unwrap().$method(lhs.str().unwrap()))
                })
            },

            // Two identical extension dtypes: compare their storage series.
            #[cfg(feature = "dtype-extension")]
            (le @ Extension(_, _), re @ Extension(_, _)) if le == re => {
                let lhs = lhs.ext().unwrap();
                let rhs = rhs.ext().unwrap();
                return lhs.storage().$method(rhs.storage());
            },

            // Extension on the left whose storage dtype equals the right dtype.
            #[cfg(feature = "dtype-extension")]
            (Extension(_, storage), rdt) if **storage == *rdt => {
                let lhs = lhs.ext().unwrap();
                return lhs.storage().$method(rhs);
            },

            // Extension on the right whose storage dtype equals the left dtype.
            #[cfg(feature = "dtype-extension")]
            (ldt, Extension(_, storage)) if *ldt == **storage => {
                let rhs = rhs.ext().unwrap();
                return lhs.$method(rhs.storage());
            },
            _ => (),
        };

        // Generic path: coerce both operands, then compare on the physical representation.
        // The closure still sees the un-coerced operands, so the error reports the
        // original names and dtypes.
        let (lhs, rhs) = coerce_lhs_rhs(lhs, rhs)
            .map_err(|_| polars_err!(
                    SchemaMismatch: "could not evaluate comparison between series '{}' of dtype: {:?} and series '{}' of dtype: {:?}",
                    lhs.name(), lhs.dtype(), rhs.name(), rhs.dtype()
            ))?;
        let lhs = lhs.to_physical_repr();
        let rhs = rhs.to_physical_repr();
        let mut out = match lhs.dtype() {
            Null => lhs.null().unwrap().$method(rhs.null().unwrap()),
            Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()),
            String => lhs.str().unwrap().$method(rhs.str().unwrap()),
            Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()),
            BinaryOffset => lhs.binary_offset().unwrap().$method(rhs.binary_offset().unwrap()),
            UInt8 => feature_gated!("dtype-u8", lhs.u8().unwrap().$method(rhs.u8().unwrap())),
            UInt16 => feature_gated!("dtype-u16", lhs.u16().unwrap().$method(rhs.u16().unwrap())),
            UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()),
            UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()),
            UInt128 => feature_gated!("dtype-u128", lhs.u128().unwrap().$method(rhs.u128().unwrap())),
            Int8 => feature_gated!("dtype-i8", lhs.i8().unwrap().$method(rhs.i8().unwrap())),
            Int16 => feature_gated!("dtype-i16", lhs.i16().unwrap().$method(rhs.i16().unwrap())),
            Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()),
            Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()),
            Int128 => feature_gated!("dtype-i128", lhs.i128().unwrap().$method(rhs.i128().unwrap())),
            Float16 => feature_gated!("dtype-f16", lhs.f16().unwrap().$method(rhs.f16().unwrap())),
            Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()),
            Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()),
            List(_) => lhs.list().unwrap().$method(rhs.list().unwrap()),
            #[cfg(feature = "dtype-array")]
            Array(_, _) => lhs.array().unwrap().$method(rhs.array().unwrap()),
            #[cfg(feature = "dtype-struct")]
            Struct(_) => lhs.struct_().unwrap().$method(rhs.struct_().unwrap()),

            dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}'; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()),
        };
        // The resulting mask takes the name of the left operand.
        out.rename(lhs.name().clone());
        PolarsResult::Ok(out)
    }};
}
113
/// Bails with an `InvalidOperation` error stating that the ordering comparison
/// `$operator` cannot be performed between the two given series. The message
/// includes the name and dtype of both operands.
macro_rules! bail_invalid_ineq {
    ($left:expr, $right:expr, $operator:literal) => {
        polars_bail!(
            InvalidOperation: "cannot perform '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}",
            $operator,
            $left.name(),
            $left.dtype(),
            $right.name(),
            $right.dtype(),
        )
    };
}
124
/// Implements an ordering comparison (`>`, `>=`, `<`, `<=`) between two series.
///
/// Expands to a block expression that ends in `PolarsResult::Ok(..)`. Because the
/// expansion uses `?` and `return`, it must be invoked inside a function whose return
/// type is a compatible `PolarsResult`.
///
/// * `$left`, `$right` - the left and right operands.
/// * `$method` - the comparison method to dispatch to on the typed arrays.
/// * `$op` - the operator as a string literal; only used in error messages.
/// * `$flipped` - the mirrored comparison method, used when the string operand is on
///   the left and the categorical/enum operand is on the right.
macro_rules! impl_ineq_compare {
    ($left:expr, $right:expr, $method:ident, $op:literal, $flipped:ident) => {{
        use DataType::*;
        let (left, right) = ($left, $right);
        validate_types(left.dtype(), right.dtype())?;

        // Lengths must match, unless one side has a single value (broadcast).
        let lengths_ok = left.len() == right.len() || left.len() == 1 || right.len() == 1;
        polars_ensure!(
            lengths_ok,
            ShapeMismatch:
                "could not perform '{}' comparison between series '{}' of length: {} and series '{}' of length: {}, because they have different lengths",
            $op,
            left.name(), left.len(),
            right.name(), right.len()
        );

        // Dtype pairs that are compared directly, without the generic coercion below.
        match (left.dtype(), right.dtype()) {
            #[cfg(feature = "dtype-categorical")]
            (Categorical(left_cats, _), Categorical(right_cats, _)) => {
                ensure_same_categories(left_cats, right_cats)?;
                return with_match_categorical_physical_type!(left_cats.physical(), |$C| {
                    left.cat::<$C>().unwrap().$method(right.cat::<$C>().unwrap())
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (Enum(left_cats, _), Enum(right_cats, _)) => {
                ensure_same_frozen_categories(left_cats, right_cats)?;
                return with_match_categorical_physical_type!(left_cats.physical(), |$C| {
                    left.cat::<$C>().unwrap().$method(right.cat::<$C>().unwrap())
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (Categorical(_, _) | Enum(_, _), String) => {
                return with_match_categorical_physical_type!(left.dtype().cat_physical().unwrap(), |$C| {
                    left.cat::<$C>().unwrap().$method(right.str().unwrap())
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (String, Categorical(_, _) | Enum(_, _)) => {
                return with_match_categorical_physical_type!(right.dtype().cat_physical().unwrap(), |$C| {
                    // String <-> enum comparisons are only implemented one-way, so the
                    // operands are swapped and the mirrored method is applied.
                    right.cat::<$C>().unwrap().$flipped(left.str().unwrap())
                })
            },
            // Two identical extension dtypes: compare their storage series.
            #[cfg(feature = "dtype-extension")]
            (left_dt @ Extension(_, _), right_dt @ Extension(_, _)) if left_dt == right_dt => {
                let left_ext = left.ext().unwrap();
                let right_ext = right.ext().unwrap();
                return left_ext.storage().$method(right_ext.storage());
            },

            // Extension on the left whose storage dtype equals the right dtype.
            #[cfg(feature = "dtype-extension")]
            (Extension(_, storage), right_dt) if **storage == *right_dt => {
                let left_ext = left.ext().unwrap();
                return left_ext.storage().$method(right);
            },

            // Extension on the right whose storage dtype equals the left dtype.
            #[cfg(feature = "dtype-extension")]
            (left_dt, Extension(_, storage)) if *left_dt == **storage => {
                let right_ext = right.ext().unwrap();
                return left.$method(right_ext.storage());
            },
            _ => (),
        };

        // Generic path: coerce both operands. On failure the error reports the names
        // and dtypes of the operands as they were before coercion.
        let (left_coerced, right_coerced) = coerce_lhs_rhs(left, right).map_err(|_| {
            polars_err!(
                SchemaMismatch: "could not evaluate '{}' comparison between series '{}' of dtype: {:?} and series '{}' of dtype: {:?}",
                $op,
                left.name(), left.dtype(),
                right.name(), right.dtype()
            )
        })?;
        let left_phys = left_coerced.to_physical_repr();
        let right_phys = right_coerced.to_physical_repr();
        let mut mask = match left_phys.dtype() {
            Null => left_phys.null().unwrap().$method(right_phys.null().unwrap()),
            Boolean => left_phys.bool().unwrap().$method(right_phys.bool().unwrap()),
            String => left_phys.str().unwrap().$method(right_phys.str().unwrap()),
            Binary => left_phys.binary().unwrap().$method(right_phys.binary().unwrap()),
            BinaryOffset => left_phys.binary_offset().unwrap().$method(right_phys.binary_offset().unwrap()),
            UInt8 => feature_gated!("dtype-u8", left_phys.u8().unwrap().$method(right_phys.u8().unwrap())),
            UInt16 => feature_gated!("dtype-u16", left_phys.u16().unwrap().$method(right_phys.u16().unwrap())),
            UInt32 => left_phys.u32().unwrap().$method(right_phys.u32().unwrap()),
            UInt64 => left_phys.u64().unwrap().$method(right_phys.u64().unwrap()),
            UInt128 => feature_gated!("dtype-u128", left_phys.u128().unwrap().$method(right_phys.u128().unwrap())),
            Int8 => feature_gated!("dtype-i8", left_phys.i8().unwrap().$method(right_phys.i8().unwrap())),
            Int16 => feature_gated!("dtype-i16", left_phys.i16().unwrap().$method(right_phys.i16().unwrap())),
            Int32 => left_phys.i32().unwrap().$method(right_phys.i32().unwrap()),
            Int64 => left_phys.i64().unwrap().$method(right_phys.i64().unwrap()),
            Int128 => feature_gated!("dtype-i128", left_phys.i128().unwrap().$method(right_phys.i128().unwrap())),
            Float16 => feature_gated!("dtype-f16", left_phys.f16().unwrap().$method(right_phys.f16().unwrap())),
            Float32 => left_phys.f32().unwrap().$method(right_phys.f32().unwrap()),
            Float64 => left_phys.f64().unwrap().$method(right_phys.f64().unwrap()),
            // Nested dtypes have no ordering comparison here.
            List(_) => bail_invalid_ineq!(left_phys, right_phys, $op),
            #[cfg(feature = "dtype-array")]
            Array(_, _) => bail_invalid_ineq!(left_phys, right_phys, $op),
            #[cfg(feature = "dtype-struct")]
            Struct(_) => bail_invalid_ineq!(left_phys, right_phys, $op),

            dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}'; operand names: '{}', '{}'", dt, left_phys.name(), right_phys.name()),
        };
        // The resulting mask takes the name of the left operand.
        mask.rename(left_phys.name().clone());
        PolarsResult::Ok(mask)
    }};
}
234
235fn validate_types(left: &DataType, right: &DataType) -> PolarsResult<()> {
236    use DataType::*;
237
238    match (left, right) {
239        (String, dt) | (dt, String) if dt.is_primitive_numeric() => {
240            polars_bail!(ComputeError: "cannot compare string with numeric type ({})", dt)
241        },
242        #[cfg(feature = "dtype-categorical")]
243        (Categorical(_, _) | Enum(_, _), dt) | (dt, Categorical(_, _) | Enum(_, _))
244            if !(dt.is_categorical() | dt.is_string() | dt.is_enum()) =>
245        {
246            polars_bail!(ComputeError: "cannot compare categorical with {}", dt)
247        },
248        #[cfg(feature = "dtype-duration")]
249        (Date, Duration(_)) | (Duration(_), Date) => {
250            polars_bail!(ComputeError: "cannot compare date with duration")
251        },
252        _ => (),
253    };
254    Ok(())
255}
256
/// (In)equality comparisons between two series.
///
/// Every method delegates to `impl_eq_compare!`, which validates the dtypes and
/// lengths of both operands before dispatching on the dtype.
impl ChunkCompareEq<&Series> for Series {
    type Item = PolarsResult<BooleanChunked>;

    /// Create a boolean mask by checking for equality.
    fn equal(&self, rhs: &Series) -> Self::Item {
        impl_eq_compare!(self, rhs, equal)
    }

    /// Create a boolean mask by checking for equality, dispatching to the
    /// `equal_missing` variant of the comparison.
    // NOTE(review): presumably the `_missing` variants treat nulls as comparable
    // values instead of propagating them; confirm against `ChunkCompareEq`.
    fn equal_missing(&self, rhs: &Series) -> Self::Item {
        impl_eq_compare!(self, rhs, equal_missing)
    }

    /// Create a boolean mask by checking for inequality.
    fn not_equal(&self, rhs: &Series) -> Self::Item {
        impl_eq_compare!(self, rhs, not_equal)
    }

    /// Create a boolean mask by checking for inequality, dispatching to the
    /// `not_equal_missing` variant of the comparison.
    // NOTE(review): see the note on `equal_missing` regarding null handling.
    fn not_equal_missing(&self, rhs: &Series) -> Self::Item {
        impl_eq_compare!(self, rhs, not_equal_missing)
    }
}
280
/// Ordering comparisons between two series.
///
/// Every method delegates to `impl_ineq_compare!`. The string literal is the operator
/// shown in error messages, and the last argument is the mirrored method that the macro
/// uses when it has to swap the operands (string on the left, categorical/enum on the
/// right).
impl ChunkCompareIneq<&Series> for Series {
    type Item = PolarsResult<BooleanChunked>;

    /// Create a boolean mask by checking if self > rhs.
    fn gt(&self, rhs: &Series) -> Self::Item {
        impl_ineq_compare!(self, rhs, gt, ">", lt)
    }

    /// Create a boolean mask by checking if self >= rhs.
    fn gt_eq(&self, rhs: &Series) -> Self::Item {
        impl_ineq_compare!(self, rhs, gt_eq, ">=", lt_eq)
    }

    /// Create a boolean mask by checking if self < rhs.
    fn lt(&self, rhs: &Series) -> Self::Item {
        impl_ineq_compare!(self, rhs, lt, "<", gt)
    }

    /// Create a boolean mask by checking if self <= rhs.
    fn lt_eq(&self, rhs: &Series) -> Self::Item {
        impl_ineq_compare!(self, rhs, lt_eq, "<=", gt_eq)
    }
}
304
305impl<Rhs> ChunkCompareEq<Rhs> for Series
306where
307    Rhs: NumericNative,
308{
309    type Item = PolarsResult<BooleanChunked>;
310
311    fn equal(&self, rhs: Rhs) -> Self::Item {
312        validate_types(self.dtype(), &DataType::Int8)?;
313        let s = self.to_physical_repr();
314        Ok(apply_method_physical_numeric!(&s, equal, rhs))
315    }
316
317    fn equal_missing(&self, rhs: Rhs) -> Self::Item {
318        validate_types(self.dtype(), &DataType::Int8)?;
319        let s = self.to_physical_repr();
320        Ok(apply_method_physical_numeric!(&s, equal_missing, rhs))
321    }
322
323    fn not_equal(&self, rhs: Rhs) -> Self::Item {
324        validate_types(self.dtype(), &DataType::Int8)?;
325        let s = self.to_physical_repr();
326        Ok(apply_method_physical_numeric!(&s, not_equal, rhs))
327    }
328
329    fn not_equal_missing(&self, rhs: Rhs) -> Self::Item {
330        validate_types(self.dtype(), &DataType::Int8)?;
331        let s = self.to_physical_repr();
332        Ok(apply_method_physical_numeric!(&s, not_equal_missing, rhs))
333    }
334}
335
336impl<Rhs> ChunkCompareIneq<Rhs> for Series
337where
338    Rhs: NumericNative,
339{
340    type Item = PolarsResult<BooleanChunked>;
341
342    fn gt(&self, rhs: Rhs) -> Self::Item {
343        validate_types(self.dtype(), &DataType::Int8)?;
344        let s = self.to_physical_repr();
345        Ok(apply_method_physical_numeric!(&s, gt, rhs))
346    }
347
348    fn gt_eq(&self, rhs: Rhs) -> Self::Item {
349        validate_types(self.dtype(), &DataType::Int8)?;
350        let s = self.to_physical_repr();
351        Ok(apply_method_physical_numeric!(&s, gt_eq, rhs))
352    }
353
354    fn lt(&self, rhs: Rhs) -> Self::Item {
355        validate_types(self.dtype(), &DataType::Int8)?;
356        let s = self.to_physical_repr();
357        Ok(apply_method_physical_numeric!(&s, lt, rhs))
358    }
359
360    fn lt_eq(&self, rhs: Rhs) -> Self::Item {
361        validate_types(self.dtype(), &DataType::Int8)?;
362        let s = self.to_physical_repr();
363        Ok(apply_method_physical_numeric!(&s, lt_eq, rhs))
364    }
365}
366
367impl ChunkCompareEq<&str> for Series {
368    type Item = PolarsResult<BooleanChunked>;
369
370    fn equal(&self, rhs: &str) -> PolarsResult<BooleanChunked> {
371        validate_types(self.dtype(), &DataType::String)?;
372        match self.dtype() {
373            DataType::String => Ok(self.str().unwrap().equal(rhs)),
374            #[cfg(feature = "dtype-categorical")]
375            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
376                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
377                    self.cat::<$C>().unwrap().equal(rhs)
378                }),
379            ),
380            #[cfg(feature = "dtype-extension")]
381            DataType::Extension(_, _) => self.ext().unwrap().storage().equal(rhs),
382            _ => Ok(BooleanChunked::full(self.name().clone(), false, self.len())),
383        }
384    }
385
386    fn equal_missing(&self, rhs: &str) -> Self::Item {
387        validate_types(self.dtype(), &DataType::String)?;
388        match self.dtype() {
389            DataType::String => Ok(self.str().unwrap().equal_missing(rhs)),
390            #[cfg(feature = "dtype-categorical")]
391            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
392                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
393                    self.cat::<$C>().unwrap().equal_missing(rhs)
394                }),
395            ),
396            #[cfg(feature = "dtype-extension")]
397            DataType::Extension(_, _) => self.ext().unwrap().storage().equal_missing(rhs),
398            _ => Ok(replace_non_null(
399                self.name().clone(),
400                self.0.chunks(),
401                false,
402            )),
403        }
404    }
405
406    fn not_equal(&self, rhs: &str) -> PolarsResult<BooleanChunked> {
407        validate_types(self.dtype(), &DataType::String)?;
408        match self.dtype() {
409            DataType::String => Ok(self.str().unwrap().not_equal(rhs)),
410            #[cfg(feature = "dtype-categorical")]
411            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
412                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
413                    self.cat::<$C>().unwrap().not_equal(rhs)
414                }),
415            ),
416            #[cfg(feature = "dtype-extension")]
417            DataType::Extension(_, _) => self.ext().unwrap().storage().not_equal(rhs),
418            _ => Ok(BooleanChunked::full(self.name().clone(), true, self.len())),
419        }
420    }
421
422    fn not_equal_missing(&self, rhs: &str) -> Self::Item {
423        validate_types(self.dtype(), &DataType::String)?;
424        match self.dtype() {
425            DataType::String => Ok(self.str().unwrap().not_equal_missing(rhs)),
426            #[cfg(feature = "dtype-categorical")]
427            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
428                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
429                    self.cat::<$C>().unwrap().not_equal_missing(rhs)
430                }),
431            ),
432            #[cfg(feature = "dtype-extension")]
433            DataType::Extension(_, _) => self.ext().unwrap().storage().not_equal_missing(rhs),
434            _ => Ok(replace_non_null(self.name().clone(), self.0.chunks(), true)),
435        }
436    }
437}
438
439impl ChunkCompareIneq<&str> for Series {
440    type Item = PolarsResult<BooleanChunked>;
441
442    fn gt(&self, rhs: &str) -> Self::Item {
443        validate_types(self.dtype(), &DataType::String)?;
444        match self.dtype() {
445            DataType::String => Ok(self.str().unwrap().gt(rhs)),
446            #[cfg(feature = "dtype-categorical")]
447            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
448                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
449                    self.cat::<$C>().unwrap().gt(rhs)
450                }),
451            ),
452            #[cfg(feature = "dtype-extension")]
453            DataType::Extension(_, _) => self.ext().unwrap().storage().gt(rhs),
454            _ => polars_bail!(
455                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
456            ),
457        }
458    }
459
460    fn gt_eq(&self, rhs: &str) -> Self::Item {
461        validate_types(self.dtype(), &DataType::String)?;
462        match self.dtype() {
463            DataType::String => Ok(self.str().unwrap().gt_eq(rhs)),
464            #[cfg(feature = "dtype-categorical")]
465            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
466                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
467                    self.cat::<$C>().unwrap().gt_eq(rhs)
468                }),
469            ),
470            #[cfg(feature = "dtype-extension")]
471            DataType::Extension(_, _) => self.ext().unwrap().storage().gt_eq(rhs),
472            _ => polars_bail!(
473                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
474            ),
475        }
476    }
477
478    fn lt(&self, rhs: &str) -> Self::Item {
479        validate_types(self.dtype(), &DataType::String)?;
480        match self.dtype() {
481            DataType::String => Ok(self.str().unwrap().lt(rhs)),
482            #[cfg(feature = "dtype-categorical")]
483            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
484                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
485                    self.cat::<$C>().unwrap().lt(rhs)
486                }),
487            ),
488            #[cfg(feature = "dtype-extension")]
489            DataType::Extension(_, _) => self.ext().unwrap().storage().lt(rhs),
490            _ => polars_bail!(
491                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
492            ),
493        }
494    }
495
496    fn lt_eq(&self, rhs: &str) -> Self::Item {
497        validate_types(self.dtype(), &DataType::String)?;
498        match self.dtype() {
499            DataType::String => Ok(self.str().unwrap().lt_eq(rhs)),
500            #[cfg(feature = "dtype-categorical")]
501            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
502                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
503                    self.cat::<$C>().unwrap().lt_eq(rhs)
504                }),
505            ),
506            #[cfg(feature = "dtype-extension")]
507            DataType::Extension(_, _) => self.ext().unwrap().storage().lt_eq(rhs),
508            _ => polars_bail!(
509                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
510            ),
511        }
512    }
513}