polars_core/series/
comparison.rs

1//! Comparison operations on Series.
2
3use polars_error::feature_gated;
4
5use crate::prelude::*;
6use crate::series::arithmetic::coerce_lhs_rhs;
7use crate::series::nulls::replace_non_null;
8
// Shared implementation of the `ChunkCompareEq<&Series>` methods for `Series`.
//
// `$self` / `$rhs` are the two operands and `$method` is the comparison method to
// dispatch to (`equal`, `equal_missing`, `not_equal`, `not_equal_missing`). The
// expansion evaluates to the `PolarsResult` returned by those trait methods and may
// `return` early (via `?`, `return` or `polars_bail!`) from the enclosing function.
//
// Dispatch order:
// 1. `validate_types` rejects dtype pairs that can never be compared.
// 2. Lengths must be equal, unless one side has length 1 (broadcast).
// 3. Categorical/enum pairs, categorical/enum vs. string, and extension dtypes are
//    compared directly, without going through `coerce_lhs_rhs`.
// 4. Everything else is coerced with `coerce_lhs_rhs`, cast to its physical
//    representation and compared through the matching typed `ChunkedArray`; the
//    resulting mask is renamed to the (coerced) left operand's name.
macro_rules! impl_eq_compare {
    ($self:expr, $rhs:expr, $method:ident) => {{
        use DataType::*;
        let (lhs, rhs) = ($self, $rhs);
        validate_types(lhs.dtype(), rhs.dtype())?;

        polars_ensure!(
            lhs.len() == rhs.len() ||

            // Broadcast
            lhs.len() == 1 ||
            rhs.len() == 1,
            ShapeMismatch: "could not compare between two series of different length ({} != {})",
            lhs.len(),
            rhs.len()
        );

        match (lhs.dtype(), rhs.dtype()) {
            #[cfg(feature = "dtype-categorical")]
            (Categorical(lcats, _), Categorical(rcats, _)) => {
                // Two categoricals are only comparable when their categories match.
                ensure_same_categories(lcats, rcats)?;
                return with_match_categorical_physical_type!(lcats.physical(), |$C| {
                    lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap())
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (Enum(lfcats, _), Enum(rfcats, _)) => {
                // Two enums are only comparable when their frozen categories match.
                ensure_same_frozen_categories(lfcats, rfcats)?;
                return with_match_categorical_physical_type!(lfcats.physical(), |$C| {
                    lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap())
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (Categorical(_, _) | Enum(_, _), String) => {
                return with_match_categorical_physical_type!(lhs.dtype().cat_physical().unwrap(), |$C| {
                    Ok(lhs.cat::<$C>().unwrap().$method(rhs.str().unwrap()))
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (String, Categorical(_, _) | Enum(_, _)) => {
                // (In)equality is symmetric, so the categorical side can drive the comparison.
                return with_match_categorical_physical_type!(rhs.dtype().cat_physical().unwrap(), |$C| {
                    Ok(rhs.cat::<$C>().unwrap().$method(lhs.str().unwrap()))
                })
            },

            // Identical extension types: compare their storage series.
            #[cfg(feature = "dtype-extension")]
            (le @ Extension(_, _), re @ Extension(_, _)) if le == re => {
                let lhs = lhs.ext().unwrap();
                let rhs = rhs.ext().unwrap();
                return lhs.storage().$method(rhs.storage());
            },

            // Extension vs. a series whose dtype equals the extension's storage dtype.
            #[cfg(feature = "dtype-extension")]
            (Extension(_, storage), rdt) if **storage == *rdt => {
                let lhs = lhs.ext().unwrap();
                return lhs.storage().$method(rhs);
            },

            #[cfg(feature = "dtype-extension")]
            (ldt, Extension(_, storage)) if *ldt == **storage => {
                let rhs = rhs.ext().unwrap();
                return lhs.$method(rhs.storage());
            },
            _ => (),
        };

        let (lhs, rhs) = coerce_lhs_rhs(lhs, rhs)
            .map_err(|_| polars_err!(
                    SchemaMismatch: "could not evaluate comparison between series '{}' of dtype: {:?} and series '{}' of dtype: {:?}",
                    lhs.name(), lhs.dtype(), rhs.name(), rhs.dtype()
            ))?;
        let lhs = lhs.to_physical_repr();
        let rhs = rhs.to_physical_repr();
        let mut out = match lhs.dtype() {
            Null => lhs.null().unwrap().$method(rhs.null().unwrap()),
            Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()),
            String => lhs.str().unwrap().$method(rhs.str().unwrap()),
            Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()),
            BinaryOffset => lhs.binary_offset().unwrap().$method(rhs.binary_offset().unwrap()),
            UInt8 => feature_gated!("dtype-u8", lhs.u8().unwrap().$method(rhs.u8().unwrap())),
            UInt16 => feature_gated!("dtype-u16", lhs.u16().unwrap().$method(rhs.u16().unwrap())),
            UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()),
            UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()),
            UInt128 => feature_gated!("dtype-u128", lhs.u128().unwrap().$method(rhs.u128().unwrap())),
            Int8 => feature_gated!("dtype-i8", lhs.i8().unwrap().$method(rhs.i8().unwrap())),
            Int16 => feature_gated!("dtype-i16", lhs.i16().unwrap().$method(rhs.i16().unwrap())),
            Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()),
            Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()),
            Int128 => feature_gated!("dtype-i128", lhs.i128().unwrap().$method(rhs.i128().unwrap())),
            Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()),
            Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()),
            List(_) => lhs.list().unwrap().$method(rhs.list().unwrap()),
            #[cfg(feature = "dtype-array")]
            Array(_, _) => lhs.array().unwrap().$method(rhs.array().unwrap()),
            #[cfg(feature = "dtype-struct")]
            Struct(_) => lhs.struct_().unwrap().$method(rhs.struct_().unwrap()),

            dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}'; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()),
        };
        out.rename(lhs.name().clone());
        PolarsResult::Ok(out)
    }};
}
112
// Returns early with an `InvalidOperation` error stating that the inequality `$op`
// is not supported between the two given series. `$left` / `$right` must expose
// `name()` and `dtype()`; `$op` is the operator literal shown in the message.
// `impl_ineq_compare` uses it for physical dtypes it cannot order (lists, arrays,
// structs).
macro_rules! bail_invalid_ineq {
    ($left:expr, $right:expr, $op:literal) => {{
        polars_bail!(
            InvalidOperation: "cannot perform '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}",
            $op,
            $left.name(),
            $left.dtype(),
            $right.name(),
            $right.dtype(),
        )
    }};
}
123
// Shared implementation of the `ChunkCompareIneq<&Series>` methods for `Series`.
//
// `$method` is the ordering comparison to apply (`gt`, `gt_eq`, `lt`, `lt_eq`),
// `$op` its operator literal for error messages, and `$rev_method` the same
// comparison with the operands swapped (e.g. `lt` for `gt`). `$rev_method` is used
// when only the right operand is categorical/enum. The expansion evaluates to the
// `PolarsResult` returned by the trait methods and may `return` early (via `?`,
// `return` or `polars_bail!`) from the enclosing function.
//
// Dispatch order:
// 1. `validate_types` rejects dtype pairs that can never be compared.
// 2. Lengths must be equal, unless one side has length 1 (broadcast).
// 3. Categorical/enum pairs, categorical/enum vs. string, and extension dtypes are
//    compared directly, without going through `coerce_lhs_rhs`.
// 4. Everything else is coerced with `coerce_lhs_rhs`, cast to its physical
//    representation and compared through the matching typed `ChunkedArray`.
//    List, array and struct physical dtypes are rejected with `bail_invalid_ineq!`.
//    The resulting mask is renamed to the (coerced) left operand's name.
macro_rules! impl_ineq_compare {
    ($self:expr, $rhs:expr, $method:ident, $op:literal, $rev_method:ident) => {{
        use DataType::*;
        let (lhs, rhs) = ($self, $rhs);
        validate_types(lhs.dtype(), rhs.dtype())?;

        polars_ensure!(
            lhs.len() == rhs.len() ||

            // Broadcast
            lhs.len() == 1 ||
            rhs.len() == 1,
            ShapeMismatch:
                "could not perform '{}' comparison between series '{}' of length: {} and series '{}' of length: {}, because they have different lengths",
            $op,
            lhs.name(), lhs.len(),
            rhs.name(), rhs.len()
        );

        match (lhs.dtype(), rhs.dtype()) {
            #[cfg(feature = "dtype-categorical")]
            (Categorical(lcats, _), Categorical(rcats, _)) => {
                // Two categoricals are only comparable when their categories match.
                ensure_same_categories(lcats, rcats)?;
                return with_match_categorical_physical_type!(lcats.physical(), |$C| {
                    lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap())
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (Enum(lfcats, _), Enum(rfcats, _)) => {
                // Two enums are only comparable when their frozen categories match.
                ensure_same_frozen_categories(lfcats, rfcats)?;
                return with_match_categorical_physical_type!(lfcats.physical(), |$C| {
                    lhs.cat::<$C>().unwrap().$method(rhs.cat::<$C>().unwrap())
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (Categorical(_, _) | Enum(_, _), String) => {
                return with_match_categorical_physical_type!(lhs.dtype().cat_physical().unwrap(), |$C| {
                    lhs.cat::<$C>().unwrap().$method(rhs.str().unwrap())
                })
            },
            #[cfg(feature = "dtype-categorical")]
            (String, Categorical(_, _) | Enum(_, _)) => {
                return with_match_categorical_physical_type!(rhs.dtype().cat_physical().unwrap(), |$C| {
                    // We use the reverse method as string <-> enum comparisons are only implemented one-way.
                    rhs.cat::<$C>().unwrap().$rev_method(lhs.str().unwrap())
                })
            },
            // Identical extension types: compare their storage series.
            #[cfg(feature = "dtype-extension")]
            (le @ Extension(_, _), re @ Extension(_, _)) if le == re => {
                let lhs = lhs.ext().unwrap();
                let rhs = rhs.ext().unwrap();
                return lhs.storage().$method(rhs.storage());
            },

            // Extension vs. a series whose dtype equals the extension's storage dtype.
            #[cfg(feature = "dtype-extension")]
            (Extension(_, storage), rdt) if **storage == *rdt => {
                let lhs = lhs.ext().unwrap();
                return lhs.storage().$method(rhs);
            },

            #[cfg(feature = "dtype-extension")]
            (ldt, Extension(_, storage)) if *ldt == **storage => {
                let rhs = rhs.ext().unwrap();
                return lhs.$method(rhs.storage());
            },
            _ => (),
        };

        let (lhs, rhs) = coerce_lhs_rhs(lhs, rhs).map_err(|_|
            polars_err!(
                SchemaMismatch: "could not evaluate '{}' comparison between series '{}' of dtype: {:?} and series '{}' of dtype: {:?}",
                $op,
                lhs.name(), lhs.dtype(),
                rhs.name(), rhs.dtype()
            )
        )?;
        let lhs = lhs.to_physical_repr();
        let rhs = rhs.to_physical_repr();
        let mut out = match lhs.dtype() {
            Null => lhs.null().unwrap().$method(rhs.null().unwrap()),
            Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()),
            String => lhs.str().unwrap().$method(rhs.str().unwrap()),
            Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()),
            BinaryOffset => lhs.binary_offset().unwrap().$method(rhs.binary_offset().unwrap()),
            UInt8 => feature_gated!("dtype-u8", lhs.u8().unwrap().$method(rhs.u8().unwrap())),
            UInt16 => feature_gated!("dtype-u16", lhs.u16().unwrap().$method(rhs.u16().unwrap())),
            UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()),
            UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()),
            UInt128 => feature_gated!("dtype-u128", lhs.u128().unwrap().$method(rhs.u128().unwrap())),
            Int8 => feature_gated!("dtype-i8", lhs.i8().unwrap().$method(rhs.i8().unwrap())),
            Int16 => feature_gated!("dtype-i16", lhs.i16().unwrap().$method(rhs.i16().unwrap())),
            Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()),
            Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()),
            Int128 => feature_gated!("dtype-i128", lhs.i128().unwrap().$method(rhs.i128().unwrap())),
            Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()),
            Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()),
            // Nested dtypes have no ordering comparison here.
            List(_) => bail_invalid_ineq!(lhs, rhs, $op),
            #[cfg(feature = "dtype-array")]
            Array(_, _) => bail_invalid_ineq!(lhs, rhs, $op),
            #[cfg(feature = "dtype-struct")]
            Struct(_) => bail_invalid_ineq!(lhs, rhs, $op),

            dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}'; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()),
        };
        out.rename(lhs.name().clone());
        PolarsResult::Ok(out)
    }};
}
232
233fn validate_types(left: &DataType, right: &DataType) -> PolarsResult<()> {
234    use DataType::*;
235
236    match (left, right) {
237        (String, dt) | (dt, String) if dt.is_primitive_numeric() => {
238            polars_bail!(ComputeError: "cannot compare string with numeric type ({})", dt)
239        },
240        #[cfg(feature = "dtype-categorical")]
241        (Categorical(_, _) | Enum(_, _), dt) | (dt, Categorical(_, _) | Enum(_, _))
242            if !(dt.is_categorical() | dt.is_string() | dt.is_enum()) =>
243        {
244            polars_bail!(ComputeError: "cannot compare categorical with {}", dt);
245        },
246        _ => (),
247    };
248    Ok(())
249}
250
251impl ChunkCompareEq<&Series> for Series {
252    type Item = PolarsResult<BooleanChunked>;
253
254    /// Create a boolean mask by checking for equality.
255    fn equal(&self, rhs: &Series) -> Self::Item {
256        impl_eq_compare!(self, rhs, equal)
257    }
258
259    /// Create a boolean mask by checking for equality.
260    fn equal_missing(&self, rhs: &Series) -> Self::Item {
261        impl_eq_compare!(self, rhs, equal_missing)
262    }
263
264    /// Create a boolean mask by checking for inequality.
265    fn not_equal(&self, rhs: &Series) -> Self::Item {
266        impl_eq_compare!(self, rhs, not_equal)
267    }
268
269    /// Create a boolean mask by checking for inequality.
270    fn not_equal_missing(&self, rhs: &Series) -> Self::Item {
271        impl_eq_compare!(self, rhs, not_equal_missing)
272    }
273}
274
275impl ChunkCompareIneq<&Series> for Series {
276    type Item = PolarsResult<BooleanChunked>;
277
278    /// Create a boolean mask by checking if self > rhs.
279    fn gt(&self, rhs: &Series) -> Self::Item {
280        impl_ineq_compare!(self, rhs, gt, ">", lt)
281    }
282
283    /// Create a boolean mask by checking if self >= rhs.
284    fn gt_eq(&self, rhs: &Series) -> Self::Item {
285        impl_ineq_compare!(self, rhs, gt_eq, ">=", lt_eq)
286    }
287
288    /// Create a boolean mask by checking if self < rhs.
289    fn lt(&self, rhs: &Series) -> Self::Item {
290        impl_ineq_compare!(self, rhs, lt, "<", gt)
291    }
292
293    /// Create a boolean mask by checking if self <= rhs.
294    fn lt_eq(&self, rhs: &Series) -> Self::Item {
295        impl_ineq_compare!(self, rhs, lt_eq, "<=", gt_eq)
296    }
297}
298
299impl<Rhs> ChunkCompareEq<Rhs> for Series
300where
301    Rhs: NumericNative,
302{
303    type Item = PolarsResult<BooleanChunked>;
304
305    fn equal(&self, rhs: Rhs) -> Self::Item {
306        validate_types(self.dtype(), &DataType::Int8)?;
307        let s = self.to_physical_repr();
308        Ok(apply_method_physical_numeric!(&s, equal, rhs))
309    }
310
311    fn equal_missing(&self, rhs: Rhs) -> Self::Item {
312        validate_types(self.dtype(), &DataType::Int8)?;
313        let s = self.to_physical_repr();
314        Ok(apply_method_physical_numeric!(&s, equal_missing, rhs))
315    }
316
317    fn not_equal(&self, rhs: Rhs) -> Self::Item {
318        validate_types(self.dtype(), &DataType::Int8)?;
319        let s = self.to_physical_repr();
320        Ok(apply_method_physical_numeric!(&s, not_equal, rhs))
321    }
322
323    fn not_equal_missing(&self, rhs: Rhs) -> Self::Item {
324        validate_types(self.dtype(), &DataType::Int8)?;
325        let s = self.to_physical_repr();
326        Ok(apply_method_physical_numeric!(&s, not_equal_missing, rhs))
327    }
328}
329
330impl<Rhs> ChunkCompareIneq<Rhs> for Series
331where
332    Rhs: NumericNative,
333{
334    type Item = PolarsResult<BooleanChunked>;
335
336    fn gt(&self, rhs: Rhs) -> Self::Item {
337        validate_types(self.dtype(), &DataType::Int8)?;
338        let s = self.to_physical_repr();
339        Ok(apply_method_physical_numeric!(&s, gt, rhs))
340    }
341
342    fn gt_eq(&self, rhs: Rhs) -> Self::Item {
343        validate_types(self.dtype(), &DataType::Int8)?;
344        let s = self.to_physical_repr();
345        Ok(apply_method_physical_numeric!(&s, gt_eq, rhs))
346    }
347
348    fn lt(&self, rhs: Rhs) -> Self::Item {
349        validate_types(self.dtype(), &DataType::Int8)?;
350        let s = self.to_physical_repr();
351        Ok(apply_method_physical_numeric!(&s, lt, rhs))
352    }
353
354    fn lt_eq(&self, rhs: Rhs) -> Self::Item {
355        validate_types(self.dtype(), &DataType::Int8)?;
356        let s = self.to_physical_repr();
357        Ok(apply_method_physical_numeric!(&s, lt_eq, rhs))
358    }
359}
360
361impl ChunkCompareEq<&str> for Series {
362    type Item = PolarsResult<BooleanChunked>;
363
364    fn equal(&self, rhs: &str) -> PolarsResult<BooleanChunked> {
365        validate_types(self.dtype(), &DataType::String)?;
366        match self.dtype() {
367            DataType::String => Ok(self.str().unwrap().equal(rhs)),
368            #[cfg(feature = "dtype-categorical")]
369            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
370                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
371                    self.cat::<$C>().unwrap().equal(rhs)
372                }),
373            ),
374            #[cfg(feature = "dtype-extension")]
375            DataType::Extension(_, _) => self.ext().unwrap().storage().equal(rhs),
376            _ => Ok(BooleanChunked::full(self.name().clone(), false, self.len())),
377        }
378    }
379
380    fn equal_missing(&self, rhs: &str) -> Self::Item {
381        validate_types(self.dtype(), &DataType::String)?;
382        match self.dtype() {
383            DataType::String => Ok(self.str().unwrap().equal_missing(rhs)),
384            #[cfg(feature = "dtype-categorical")]
385            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
386                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
387                    self.cat::<$C>().unwrap().equal_missing(rhs)
388                }),
389            ),
390            #[cfg(feature = "dtype-extension")]
391            DataType::Extension(_, _) => self.ext().unwrap().storage().equal_missing(rhs),
392            _ => Ok(replace_non_null(
393                self.name().clone(),
394                self.0.chunks(),
395                false,
396            )),
397        }
398    }
399
400    fn not_equal(&self, rhs: &str) -> PolarsResult<BooleanChunked> {
401        validate_types(self.dtype(), &DataType::String)?;
402        match self.dtype() {
403            DataType::String => Ok(self.str().unwrap().not_equal(rhs)),
404            #[cfg(feature = "dtype-categorical")]
405            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
406                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
407                    self.cat::<$C>().unwrap().not_equal(rhs)
408                }),
409            ),
410            #[cfg(feature = "dtype-extension")]
411            DataType::Extension(_, _) => self.ext().unwrap().storage().not_equal(rhs),
412            _ => Ok(BooleanChunked::full(self.name().clone(), true, self.len())),
413        }
414    }
415
416    fn not_equal_missing(&self, rhs: &str) -> Self::Item {
417        validate_types(self.dtype(), &DataType::String)?;
418        match self.dtype() {
419            DataType::String => Ok(self.str().unwrap().not_equal_missing(rhs)),
420            #[cfg(feature = "dtype-categorical")]
421            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
422                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
423                    self.cat::<$C>().unwrap().not_equal_missing(rhs)
424                }),
425            ),
426            #[cfg(feature = "dtype-extension")]
427            DataType::Extension(_, _) => self.ext().unwrap().storage().not_equal_missing(rhs),
428            _ => Ok(replace_non_null(self.name().clone(), self.0.chunks(), true)),
429        }
430    }
431}
432
433impl ChunkCompareIneq<&str> for Series {
434    type Item = PolarsResult<BooleanChunked>;
435
436    fn gt(&self, rhs: &str) -> Self::Item {
437        validate_types(self.dtype(), &DataType::String)?;
438        match self.dtype() {
439            DataType::String => Ok(self.str().unwrap().gt(rhs)),
440            #[cfg(feature = "dtype-categorical")]
441            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
442                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
443                    self.cat::<$C>().unwrap().gt(rhs)
444                }),
445            ),
446            #[cfg(feature = "dtype-extension")]
447            DataType::Extension(_, _) => self.ext().unwrap().storage().gt(rhs),
448            _ => polars_bail!(
449                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
450            ),
451        }
452    }
453
454    fn gt_eq(&self, rhs: &str) -> Self::Item {
455        validate_types(self.dtype(), &DataType::String)?;
456        match self.dtype() {
457            DataType::String => Ok(self.str().unwrap().gt_eq(rhs)),
458            #[cfg(feature = "dtype-categorical")]
459            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
460                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
461                    self.cat::<$C>().unwrap().gt_eq(rhs)
462                }),
463            ),
464            #[cfg(feature = "dtype-extension")]
465            DataType::Extension(_, _) => self.ext().unwrap().storage().gt_eq(rhs),
466            _ => polars_bail!(
467                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
468            ),
469        }
470    }
471
472    fn lt(&self, rhs: &str) -> Self::Item {
473        validate_types(self.dtype(), &DataType::String)?;
474        match self.dtype() {
475            DataType::String => Ok(self.str().unwrap().lt(rhs)),
476            #[cfg(feature = "dtype-categorical")]
477            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
478                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
479                    self.cat::<$C>().unwrap().lt(rhs)
480                }),
481            ),
482            #[cfg(feature = "dtype-extension")]
483            DataType::Extension(_, _) => self.ext().unwrap().storage().lt(rhs),
484            _ => polars_bail!(
485                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
486            ),
487        }
488    }
489
490    fn lt_eq(&self, rhs: &str) -> Self::Item {
491        validate_types(self.dtype(), &DataType::String)?;
492        match self.dtype() {
493            DataType::String => Ok(self.str().unwrap().lt_eq(rhs)),
494            #[cfg(feature = "dtype-categorical")]
495            DataType::Categorical(_, _) | DataType::Enum(_, _) => Ok(
496                with_match_categorical_physical_type!(self.dtype().cat_physical().unwrap(), |$C| {
497                    self.cat::<$C>().unwrap().lt_eq(rhs)
498                }),
499            ),
500            #[cfg(feature = "dtype-extension")]
501            DataType::Extension(_, _) => self.ext().unwrap().storage().lt_eq(rhs),
502            _ => polars_bail!(
503                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
504            ),
505        }
506    }
507}