polars_core/series/
comparison.rs

1//! Comparison operations on Series.
2
3use polars_error::feature_gated;
4
5use crate::prelude::*;
6use crate::series::arithmetic::coerce_lhs_rhs;
7use crate::series::nulls::replace_non_null;
8
// Shared implementation of the Series-vs-Series equality comparisons (`equal`,
// `equal_missing`, `not_equal`, `not_equal_missing`).
//
// Invoke as `impl_eq_compare!(lhs, rhs, method)`, where `lhs`/`rhs` are the two
// `&Series` operands and `method` is the comparison method called on the typed
// chunked arrays. The expansion evaluates to a `PolarsResult<BooleanChunked>` that
// carries the left operand's name. It contains `?` and `return`, so it has to be
// expanded inside a function with that return type.
macro_rules! impl_eq_compare {
    ($self:expr, $rhs:expr, $method:ident) => {{
        use DataType::*;
        let (lhs, rhs) = ($self, $rhs);
        // Reject dtype combinations that `validate_types` refuses outright.
        validate_types(lhs.dtype(), rhs.dtype())?;

        polars_ensure!(
            lhs.len() == rhs.len() ||

            // Broadcast
            lhs.len() == 1 ||
            rhs.len() == 1,
            ShapeMismatch: "could not compare between two series of different length ({} != {})",
            lhs.len(),
            rhs.len()
        );

        // Categorical/Enum operands are compared through the categorical array
        // directly and return early, bypassing the dtype coercion below.
        #[cfg(feature = "dtype-categorical")]
        match (lhs.dtype(), rhs.dtype()) {
            (Categorical(_, _) | Enum(_, _), Categorical(_, _) | Enum(_, _)) => {
                return Ok(lhs
                    .categorical()
                    .unwrap()
                    .$method(rhs.categorical().unwrap())?
                    .with_name(lhs.name().clone()));
            },
            (Categorical(_, _) | Enum(_, _), String) => {
                return Ok(lhs
                    .categorical()
                    .unwrap()
                    .$method(rhs.str().unwrap())?
                    .with_name(lhs.name().clone()));
            },
            (String, Categorical(_, _) | Enum(_, _)) => {
                // (In)equality is symmetric, so the operands can be swapped to reuse
                // the categorical-vs-string comparison; the result keeps lhs' name.
                return Ok(rhs
                    .categorical()
                    .unwrap()
                    .$method(lhs.str().unwrap())?
                    .with_name(lhs.name().clone()));
            },
            _ => (),
        };

        // Bring both operands to a common dtype; the underlying error is replaced by
        // a schema error that names both series and their dtypes.
        let (lhs, rhs) = coerce_lhs_rhs(lhs, rhs)
            .map_err(|_| polars_err!(
                    SchemaMismatch: "could not evaluate comparison between series '{}' of dtype: {} and series '{}' of dtype: {}",
                    lhs.name(), lhs.dtype(), rhs.name(), rhs.dtype()
            ))?;
        // Dispatch on the physical representation of the coerced operands.
        let lhs = lhs.to_physical_repr();
        let rhs = rhs.to_physical_repr();
        let mut out = match lhs.dtype() {
            Null => lhs.null().unwrap().$method(rhs.null().unwrap()),
            Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()),
            String => lhs.str().unwrap().$method(rhs.str().unwrap()),
            Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()),
            UInt8 => feature_gated!("dtype-u8", lhs.u8().unwrap().$method(rhs.u8().unwrap())),
            UInt16 => feature_gated!("dtype-u16", lhs.u16().unwrap().$method(rhs.u16().unwrap())),
            UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()),
            UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()),
            Int8 => feature_gated!("dtype-i8", lhs.i8().unwrap().$method(rhs.i8().unwrap())),
            Int16 => feature_gated!("dtype-i16", lhs.i16().unwrap().$method(rhs.i16().unwrap())),
            Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()),
            Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()),
            Int128 => feature_gated!("dtype-i128", lhs.i128().unwrap().$method(rhs.i128().unwrap())),
            Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()),
            Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()),
            List(_) => lhs.list().unwrap().$method(rhs.list().unwrap()),
            #[cfg(feature = "dtype-array")]
            Array(_, _) => lhs.array().unwrap().$method(rhs.array().unwrap()),
            #[cfg(feature = "dtype-struct")]
            Struct(_) => lhs.struct_().unwrap().$method(rhs.struct_().unwrap()),

            dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}'; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()),
        };
        out.rename(lhs.name().clone());
        PolarsResult::Ok(out)
    }};
}
87
// Bails out of the enclosing function with an `InvalidOperation` error for an
// ordering comparison that the operand dtypes do not support.
//
// `$left` / `$right` must expose `name()` and `dtype()`; `$operator` is the operator
// text (e.g. ">") that is interpolated into the message.
macro_rules! bail_invalid_ineq {
    ($left:expr, $right:expr, $operator:literal) => {
        polars_bail!(
            InvalidOperation: "cannot perform '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}",
            $operator,
            $left.name(),
            $left.dtype(),
            $right.name(),
            $right.dtype(),
        )
    };
}
98
// Shared implementation of the Series-vs-Series ordering comparisons (`gt`, `gt_eq`,
// `lt`, `lt_eq`).
//
// Invoke as `impl_ineq_compare!(lhs, rhs, method, "op")`, where `lhs`/`rhs` are the
// two `&Series` operands, `method` is one of `gt`, `gt_eq`, `lt`, `lt_eq`, and `"op"`
// is the operator text used in error messages. The expansion evaluates to a
// `PolarsResult<BooleanChunked>` that carries the left operand's name. It contains
// `?` and `return`, so it has to be expanded inside a function with that return type.
//
// The four public rules only look up the mirrored method (`gt` <-> `lt`,
// `gt_eq` <-> `lt_eq`) and forward to the internal `@impl` rule, which needs it for
// the one case where the operands are swapped.
macro_rules! impl_ineq_compare {
    // Internal rule: `$rev_method` is `$method` with its operands exchanged.
    (@impl $self:expr, $rhs:expr, $method:ident, $rev_method:ident, $op:literal) => {{
        use DataType::*;
        let (lhs, rhs) = ($self, $rhs);
        // Reject dtype combinations that `validate_types` refuses outright.
        validate_types(lhs.dtype(), rhs.dtype())?;

        polars_ensure!(
            lhs.len() == rhs.len() ||

            // Broadcast
            lhs.len() == 1 ||
            rhs.len() == 1,
            ShapeMismatch:
                "could not perform '{}' comparison between series '{}' of length: {} and series '{}' of length: {}, because they have different lengths",
            $op,
            lhs.name(), lhs.len(),
            rhs.name(), rhs.len()
        );

        // Categorical/Enum operands are compared through the categorical array
        // directly and return early, bypassing the dtype coercion below.
        #[cfg(feature = "dtype-categorical")]
        match (lhs.dtype(), rhs.dtype()) {
            (Categorical(_, _) | Enum(_, _), Categorical(_, _) | Enum(_, _)) => {
                return Ok(lhs
                    .categorical()
                    .unwrap()
                    .$method(rhs.categorical().unwrap())?
                    .with_name(lhs.name().clone()));
            },
            (Categorical(_, _) | Enum(_, _), String) => {
                return Ok(lhs
                    .categorical()
                    .unwrap()
                    .$method(rhs.str().unwrap())?
                    .with_name(lhs.name().clone()));
            },
            (String, Categorical(_, _) | Enum(_, _)) => {
                // The comparison is invoked on the categorical side, so the operands
                // are swapped here. An ordering comparison is not symmetric:
                // `lhs < rhs` is `rhs > lhs`, hence the mirrored method.
                return Ok(rhs
                    .categorical()
                    .unwrap()
                    .$rev_method(lhs.str().unwrap())?
                    .with_name(lhs.name().clone()));
            },
            _ => (),
        };

        // Bring both operands to a common dtype; the underlying error is replaced by
        // a schema error that names the operator, both series and their dtypes.
        let (lhs, rhs) = coerce_lhs_rhs(lhs, rhs).map_err(|_|
            polars_err!(
                SchemaMismatch: "could not evaluate '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}",
                $op,
                lhs.name(), lhs.dtype(),
                rhs.name(), rhs.dtype()
            )
        )?;
        // Dispatch on the physical representation of the coerced operands.
        let lhs = lhs.to_physical_repr();
        let rhs = rhs.to_physical_repr();
        let mut out = match lhs.dtype() {
            Null => lhs.null().unwrap().$method(rhs.null().unwrap()),
            Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()),
            String => lhs.str().unwrap().$method(rhs.str().unwrap()),
            Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()),
            UInt8 => feature_gated!("dtype-u8", lhs.u8().unwrap().$method(rhs.u8().unwrap())),
            UInt16 => feature_gated!("dtype-u16", lhs.u16().unwrap().$method(rhs.u16().unwrap())),
            UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()),
            UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()),
            Int8 => feature_gated!("dtype-i8", lhs.i8().unwrap().$method(rhs.i8().unwrap())),
            Int16 => feature_gated!("dtype-i16", lhs.i16().unwrap().$method(rhs.i16().unwrap())),
            Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()),
            Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()),
            Int128 => feature_gated!("dtype-i128", lhs.i128().unwrap().$method(rhs.i128().unwrap())),
            Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()),
            Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()),
            // Nested dtypes have no ordering comparison here.
            List(_) => bail_invalid_ineq!(lhs, rhs, $op),
            #[cfg(feature = "dtype-array")]
            Array(_, _) => bail_invalid_ineq!(lhs, rhs, $op),
            #[cfg(feature = "dtype-struct")]
            Struct(_) => bail_invalid_ineq!(lhs, rhs, $op),

            dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}'; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()),
        };
        out.rename(lhs.name().clone());
        PolarsResult::Ok(out)
    }};

    // Public rules: pair each comparison with its mirrored counterpart.
    ($self:expr, $rhs:expr, gt, $op:literal) => {
        impl_ineq_compare!(@impl $self, $rhs, gt, lt, $op)
    };
    ($self:expr, $rhs:expr, gt_eq, $op:literal) => {
        impl_ineq_compare!(@impl $self, $rhs, gt_eq, lt_eq, $op)
    };
    ($self:expr, $rhs:expr, lt, $op:literal) => {
        impl_ineq_compare!(@impl $self, $rhs, lt, gt, $op)
    };
    ($self:expr, $rhs:expr, lt_eq, $op:literal) => {
        impl_ineq_compare!(@impl $self, $rhs, lt_eq, gt_eq, $op)
    };
}
182
183fn validate_types(left: &DataType, right: &DataType) -> PolarsResult<()> {
184    use DataType::*;
185
186    match (left, right) {
187        (String, dt) | (dt, String) if dt.is_primitive_numeric() => {
188            polars_bail!(ComputeError: "cannot compare string with numeric type ({})", dt)
189        },
190        #[cfg(feature = "dtype-categorical")]
191        (Categorical(_, _) | Enum(_, _), dt) | (dt, Categorical(_, _) | Enum(_, _))
192            if !(dt.is_categorical() | dt.is_string() | dt.is_enum()) =>
193        {
194            polars_bail!(ComputeError: "cannot compare categorical with {}", dt);
195        },
196        _ => (),
197    };
198    Ok(())
199}
200
201impl ChunkCompareEq<&Series> for Series {
202    type Item = PolarsResult<BooleanChunked>;
203
204    /// Create a boolean mask by checking for equality.
205    fn equal(&self, rhs: &Series) -> Self::Item {
206        impl_eq_compare!(self, rhs, equal)
207    }
208
209    /// Create a boolean mask by checking for equality.
210    fn equal_missing(&self, rhs: &Series) -> Self::Item {
211        impl_eq_compare!(self, rhs, equal_missing)
212    }
213
214    /// Create a boolean mask by checking for inequality.
215    fn not_equal(&self, rhs: &Series) -> Self::Item {
216        impl_eq_compare!(self, rhs, not_equal)
217    }
218
219    /// Create a boolean mask by checking for inequality.
220    fn not_equal_missing(&self, rhs: &Series) -> Self::Item {
221        impl_eq_compare!(self, rhs, not_equal_missing)
222    }
223}
224
225impl ChunkCompareIneq<&Series> for Series {
226    type Item = PolarsResult<BooleanChunked>;
227
228    /// Create a boolean mask by checking if self > rhs.
229    fn gt(&self, rhs: &Series) -> Self::Item {
230        impl_ineq_compare!(self, rhs, gt, ">")
231    }
232
233    /// Create a boolean mask by checking if self >= rhs.
234    fn gt_eq(&self, rhs: &Series) -> Self::Item {
235        impl_ineq_compare!(self, rhs, gt_eq, ">=")
236    }
237
238    /// Create a boolean mask by checking if self < rhs.
239    fn lt(&self, rhs: &Series) -> Self::Item {
240        impl_ineq_compare!(self, rhs, lt, "<")
241    }
242
243    /// Create a boolean mask by checking if self <= rhs.
244    fn lt_eq(&self, rhs: &Series) -> Self::Item {
245        impl_ineq_compare!(self, rhs, lt_eq, "<=")
246    }
247}
248
249impl<Rhs> ChunkCompareEq<Rhs> for Series
250where
251    Rhs: NumericNative,
252{
253    type Item = PolarsResult<BooleanChunked>;
254
255    fn equal(&self, rhs: Rhs) -> Self::Item {
256        validate_types(self.dtype(), &DataType::Int8)?;
257        let s = self.to_physical_repr();
258        Ok(apply_method_physical_numeric!(&s, equal, rhs))
259    }
260
261    fn equal_missing(&self, rhs: Rhs) -> Self::Item {
262        validate_types(self.dtype(), &DataType::Int8)?;
263        let s = self.to_physical_repr();
264        Ok(apply_method_physical_numeric!(&s, equal_missing, rhs))
265    }
266
267    fn not_equal(&self, rhs: Rhs) -> Self::Item {
268        validate_types(self.dtype(), &DataType::Int8)?;
269        let s = self.to_physical_repr();
270        Ok(apply_method_physical_numeric!(&s, not_equal, rhs))
271    }
272
273    fn not_equal_missing(&self, rhs: Rhs) -> Self::Item {
274        validate_types(self.dtype(), &DataType::Int8)?;
275        let s = self.to_physical_repr();
276        Ok(apply_method_physical_numeric!(&s, not_equal_missing, rhs))
277    }
278}
279
280impl<Rhs> ChunkCompareIneq<Rhs> for Series
281where
282    Rhs: NumericNative,
283{
284    type Item = PolarsResult<BooleanChunked>;
285
286    fn gt(&self, rhs: Rhs) -> Self::Item {
287        validate_types(self.dtype(), &DataType::Int8)?;
288        let s = self.to_physical_repr();
289        Ok(apply_method_physical_numeric!(&s, gt, rhs))
290    }
291
292    fn gt_eq(&self, rhs: Rhs) -> Self::Item {
293        validate_types(self.dtype(), &DataType::Int8)?;
294        let s = self.to_physical_repr();
295        Ok(apply_method_physical_numeric!(&s, gt_eq, rhs))
296    }
297
298    fn lt(&self, rhs: Rhs) -> Self::Item {
299        validate_types(self.dtype(), &DataType::Int8)?;
300        let s = self.to_physical_repr();
301        Ok(apply_method_physical_numeric!(&s, lt, rhs))
302    }
303
304    fn lt_eq(&self, rhs: Rhs) -> Self::Item {
305        validate_types(self.dtype(), &DataType::Int8)?;
306        let s = self.to_physical_repr();
307        Ok(apply_method_physical_numeric!(&s, lt_eq, rhs))
308    }
309}
310
311impl ChunkCompareEq<&str> for Series {
312    type Item = PolarsResult<BooleanChunked>;
313
314    fn equal(&self, rhs: &str) -> PolarsResult<BooleanChunked> {
315        validate_types(self.dtype(), &DataType::String)?;
316        match self.dtype() {
317            DataType::String => Ok(self.str().unwrap().equal(rhs)),
318            #[cfg(feature = "dtype-categorical")]
319            DataType::Categorical(_, _) | DataType::Enum(_, _) => {
320                self.categorical().unwrap().equal(rhs)
321            },
322            _ => Ok(BooleanChunked::full(self.name().clone(), false, self.len())),
323        }
324    }
325
326    fn equal_missing(&self, rhs: &str) -> Self::Item {
327        validate_types(self.dtype(), &DataType::String)?;
328        match self.dtype() {
329            DataType::String => Ok(self.str().unwrap().equal_missing(rhs)),
330            #[cfg(feature = "dtype-categorical")]
331            DataType::Categorical(_, _) | DataType::Enum(_, _) => {
332                self.categorical().unwrap().equal_missing(rhs)
333            },
334            _ => Ok(replace_non_null(
335                self.name().clone(),
336                self.0.chunks(),
337                false,
338            )),
339        }
340    }
341
342    fn not_equal(&self, rhs: &str) -> PolarsResult<BooleanChunked> {
343        validate_types(self.dtype(), &DataType::String)?;
344        match self.dtype() {
345            DataType::String => Ok(self.str().unwrap().not_equal(rhs)),
346            #[cfg(feature = "dtype-categorical")]
347            DataType::Categorical(_, _) | DataType::Enum(_, _) => {
348                self.categorical().unwrap().not_equal(rhs)
349            },
350            _ => Ok(BooleanChunked::full(self.name().clone(), true, self.len())),
351        }
352    }
353
354    fn not_equal_missing(&self, rhs: &str) -> Self::Item {
355        validate_types(self.dtype(), &DataType::String)?;
356        match self.dtype() {
357            DataType::String => Ok(self.str().unwrap().not_equal_missing(rhs)),
358            #[cfg(feature = "dtype-categorical")]
359            DataType::Categorical(_, _) | DataType::Enum(_, _) => {
360                self.categorical().unwrap().not_equal_missing(rhs)
361            },
362            _ => Ok(replace_non_null(self.name().clone(), self.0.chunks(), true)),
363        }
364    }
365}
366
367impl ChunkCompareIneq<&str> for Series {
368    type Item = PolarsResult<BooleanChunked>;
369
370    fn gt(&self, rhs: &str) -> Self::Item {
371        validate_types(self.dtype(), &DataType::String)?;
372        match self.dtype() {
373            DataType::String => Ok(self.str().unwrap().gt(rhs)),
374            #[cfg(feature = "dtype-categorical")]
375            DataType::Categorical(_, _) | DataType::Enum(_, _) => {
376                self.categorical().unwrap().gt(rhs)
377            },
378            _ => polars_bail!(
379                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
380            ),
381        }
382    }
383
384    fn gt_eq(&self, rhs: &str) -> Self::Item {
385        validate_types(self.dtype(), &DataType::String)?;
386        match self.dtype() {
387            DataType::String => Ok(self.str().unwrap().gt_eq(rhs)),
388            #[cfg(feature = "dtype-categorical")]
389            DataType::Categorical(_, _) | DataType::Enum(_, _) => {
390                self.categorical().unwrap().gt_eq(rhs)
391            },
392            _ => polars_bail!(
393                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
394            ),
395        }
396    }
397
398    fn lt(&self, rhs: &str) -> Self::Item {
399        validate_types(self.dtype(), &DataType::String)?;
400        match self.dtype() {
401            DataType::String => Ok(self.str().unwrap().lt(rhs)),
402            #[cfg(feature = "dtype-categorical")]
403            DataType::Categorical(_, _) | DataType::Enum(_, _) => {
404                self.categorical().unwrap().lt(rhs)
405            },
406            _ => polars_bail!(
407                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
408            ),
409        }
410    }
411
412    fn lt_eq(&self, rhs: &str) -> Self::Item {
413        validate_types(self.dtype(), &DataType::String)?;
414        match self.dtype() {
415            DataType::String => Ok(self.str().unwrap().lt_eq(rhs)),
416            #[cfg(feature = "dtype-categorical")]
417            DataType::Categorical(_, _) | DataType::Enum(_, _) => {
418                self.categorical().unwrap().lt_eq(rhs)
419            },
420            _ => polars_bail!(
421                ComputeError: "cannot compare str value to series of type {}", self.dtype(),
422            ),
423        }
424    }
425}