polars_utils/
total_ord.rs

1use std::cmp::Ordering;
2use std::hash::{BuildHasher, Hash, Hasher};
3
4use bytemuck::TransparentWrapper;
5
6use crate::hashing::{BytesHash, DirtyHash};
7use crate::nulls::IsNull;
8
/// Maps an `f32` to a canonical representative of its equivalence class:
/// negative zero becomes positive zero and every NaN (any sign, any payload)
/// becomes one fixed quiet NaN. All other values are returned unchanged.
#[inline]
pub fn canonical_f32(x: f32) -> f32 {
    // Positive quiet NaN with an all-zero payload.
    const CANONICAL_NAN_BITS: u32 = 0x7fc0_0000;

    if x.is_nan() {
        return f32::from_bits(CANONICAL_NAN_BITS);
    }
    // Adding +0.0 turns -0.0 into +0.0 and leaves every other non-NaN value as is.
    x + 0.0
}
21
/// Maps an `f64` to a canonical representative of its equivalence class:
/// negative zero becomes positive zero and every NaN (any sign, any payload)
/// becomes one fixed quiet NaN. All other values are returned unchanged.
#[inline]
pub fn canonical_f64(x: f64) -> f64 {
    // Positive quiet NaN with an all-zero payload.
    const CANONICAL_NAN_BITS: u64 = 0x7ff8_0000_0000_0000;

    if x.is_nan() {
        return f64::from_bits(CANONICAL_NAN_BITS);
    }
    // Adding +0.0 turns -0.0 into +0.0 and leaves every other non-NaN value as is.
    x + 0.0
}
34
/// Total-equality counterpart of [`Eq`].
///
/// Bounding generic code on `TotalEq` instead of `Eq` lets floats, which only
/// implement [`PartialEq`], take part, because each implementor decides how
/// special values such as NaN compare.
pub trait TotalEq {
    /// Returns `true` when `self` and `other` are equal under the total relation.
    fn tot_eq(&self, other: &Self) -> bool;

    /// Logical negation of [`TotalEq::tot_eq`].
    #[inline]
    fn tot_ne(&self, other: &Self) -> bool {
        let equal = self.tot_eq(other);
        !equal
    }
}
45
46/// Alternative trait for Ord. By consistently using this we can still be
47/// generic w.r.t Ord while getting a total ordering for floats.
48pub trait TotalOrd: TotalEq {
49    fn tot_cmp(&self, other: &Self) -> Ordering;
50
51    #[inline]
52    fn tot_lt(&self, other: &Self) -> bool {
53        self.tot_cmp(other) == Ordering::Less
54    }
55
56    #[inline]
57    fn tot_gt(&self, other: &Self) -> bool {
58        self.tot_cmp(other) == Ordering::Greater
59    }
60
61    #[inline]
62    fn tot_le(&self, other: &Self) -> bool {
63        self.tot_cmp(other) != Ordering::Greater
64    }
65
66    #[inline]
67    fn tot_ge(&self, other: &Self) -> bool {
68        self.tot_cmp(other) != Ordering::Less
69    }
70}
71
/// Hashing counterpart of [`Hash`].
///
/// Bounding generic code on `TotalHash` instead of `Hash` lets floats take
/// part. Implementations should hash values that are `tot_eq` identically.
pub trait TotalHash {
    /// Feeds `self` into `state`.
    fn tot_hash<H>(&self, state: &mut H)
    where
        H: Hasher;

    /// Feeds every element of `data` into `state`, in order.
    fn tot_hash_slice<H>(data: &[Self], state: &mut H)
    where
        H: Hasher,
        Self: Sized,
    {
        data.iter().for_each(|item| item.tot_hash(state));
    }
}
89
90pub trait BuildHasherTotalExt: BuildHasher {
91    fn tot_hash_one<T>(&self, x: T) -> u64
92    where
93        T: TotalHash,
94        Self: Sized,
95        <Self as BuildHasher>::Hasher: Hasher,
96    {
97        let mut hasher = self.build_hasher();
98        x.tot_hash(&mut hasher);
99        hasher.finish()
100    }
101}
102
103impl<T: BuildHasher> BuildHasherTotalExt for T {}
104
105#[derive(Debug)]
106#[cfg_attr(
107    feature = "serde",
108    derive(serde::Serialize, serde::Deserialize),
109    serde(transparent)
110)]
111#[cfg_attr(
112    feature = "dsl-schema",
113    derive(schemars::JsonSchema),
114    schemars(transparent)
115)]
116#[repr(transparent)]
117pub struct TotalOrdWrap<T>(pub T);
118unsafe impl<T> TransparentWrapper<T> for TotalOrdWrap<T> {}
119
120impl<T: TotalOrd> PartialOrd for TotalOrdWrap<T> {
121    #[inline(always)]
122    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
123        Some(self.cmp(other))
124    }
125
126    #[inline(always)]
127    fn lt(&self, other: &Self) -> bool {
128        self.0.tot_lt(&other.0)
129    }
130
131    #[inline(always)]
132    fn le(&self, other: &Self) -> bool {
133        self.0.tot_le(&other.0)
134    }
135
136    #[inline(always)]
137    fn gt(&self, other: &Self) -> bool {
138        self.0.tot_gt(&other.0)
139    }
140
141    #[inline(always)]
142    fn ge(&self, other: &Self) -> bool {
143        self.0.tot_ge(&other.0)
144    }
145}
146
147impl<T: TotalOrd> Ord for TotalOrdWrap<T> {
148    #[inline(always)]
149    fn cmp(&self, other: &Self) -> Ordering {
150        self.0.tot_cmp(&other.0)
151    }
152}
153
154impl<T: TotalEq> PartialEq for TotalOrdWrap<T> {
155    #[inline(always)]
156    fn eq(&self, other: &Self) -> bool {
157        self.0.tot_eq(&other.0)
158    }
159
160    #[inline(always)]
161    #[allow(clippy::partialeq_ne_impl)]
162    fn ne(&self, other: &Self) -> bool {
163        self.0.tot_ne(&other.0)
164    }
165}
166
167impl<T: TotalEq> Eq for TotalOrdWrap<T> {}
168
169impl<T: TotalHash> Hash for TotalOrdWrap<T> {
170    #[inline(always)]
171    fn hash<H: Hasher>(&self, state: &mut H) {
172        self.0.tot_hash(state);
173    }
174}
175
176impl<T: Clone> Clone for TotalOrdWrap<T> {
177    #[inline]
178    fn clone(&self) -> Self {
179        Self(self.0.clone())
180    }
181}
182
183impl<T: Copy> Copy for TotalOrdWrap<T> {}
184
185impl<T: IsNull> IsNull for TotalOrdWrap<T> {
186    const HAS_NULLS: bool = T::HAS_NULLS;
187    type Inner = T::Inner;
188
189    #[inline(always)]
190    fn is_null(&self) -> bool {
191        self.0.is_null()
192    }
193
194    #[inline(always)]
195    fn unwrap_inner(self) -> Self::Inner {
196        self.0.unwrap_inner()
197    }
198}
199
200impl DirtyHash for f32 {
201    #[inline(always)]
202    fn dirty_hash(&self) -> u64 {
203        canonical_f32(*self).to_bits().dirty_hash()
204    }
205}
206
207impl DirtyHash for f64 {
208    #[inline(always)]
209    fn dirty_hash(&self) -> u64 {
210        canonical_f64(*self).to_bits().dirty_hash()
211    }
212}
213
214impl<T: DirtyHash> DirtyHash for TotalOrdWrap<T> {
215    #[inline(always)]
216    fn dirty_hash(&self) -> u64 {
217        self.0.dirty_hash()
218    }
219}
220
221macro_rules! impl_trivial_total {
222    ($T: ty) => {
223        impl TotalEq for $T {
224            #[inline(always)]
225            fn tot_eq(&self, other: &Self) -> bool {
226                self == other
227            }
228
229            #[inline(always)]
230            fn tot_ne(&self, other: &Self) -> bool {
231                self != other
232            }
233        }
234
235        impl TotalOrd for $T {
236            #[inline(always)]
237            fn tot_cmp(&self, other: &Self) -> Ordering {
238                self.cmp(other)
239            }
240
241            #[inline(always)]
242            fn tot_lt(&self, other: &Self) -> bool {
243                self < other
244            }
245
246            #[inline(always)]
247            fn tot_gt(&self, other: &Self) -> bool {
248                self > other
249            }
250
251            #[inline(always)]
252            fn tot_le(&self, other: &Self) -> bool {
253                self <= other
254            }
255
256            #[inline(always)]
257            fn tot_ge(&self, other: &Self) -> bool {
258                self >= other
259            }
260        }
261
262        impl TotalHash for $T {
263            #[inline(always)]
264            fn tot_hash<H>(&self, state: &mut H)
265            where
266                H: Hasher,
267            {
268                self.hash(state);
269            }
270        }
271    };
272}
273
274// We can't do a blanket impl because Rust complains f32 might implement
275// Ord / Eq someday.
276impl_trivial_total!(());
277impl_trivial_total!(bool);
278impl_trivial_total!(u8);
279impl_trivial_total!(u16);
280impl_trivial_total!(u32);
281impl_trivial_total!(u64);
282impl_trivial_total!(u128);
283impl_trivial_total!(usize);
284impl_trivial_total!(i8);
285impl_trivial_total!(i16);
286impl_trivial_total!(i32);
287impl_trivial_total!(i64);
288impl_trivial_total!(i128);
289impl_trivial_total!(isize);
290impl_trivial_total!(char);
291impl_trivial_total!(&str);
292impl_trivial_total!(&[u8]);
293impl_trivial_total!(String);
294
295macro_rules! impl_float_eq_ord {
296    ($T:ty) => {
297        impl TotalEq for $T {
298            #[inline]
299            fn tot_eq(&self, other: &Self) -> bool {
300                if self.is_nan() {
301                    other.is_nan()
302                } else {
303                    self == other
304                }
305            }
306        }
307
308        impl TotalOrd for $T {
309            #[inline(always)]
310            fn tot_cmp(&self, other: &Self) -> Ordering {
311                if self.tot_lt(other) {
312                    Ordering::Less
313                } else if self.tot_gt(other) {
314                    Ordering::Greater
315                } else {
316                    Ordering::Equal
317                }
318            }
319
320            #[inline(always)]
321            fn tot_lt(&self, other: &Self) -> bool {
322                !self.tot_ge(other)
323            }
324
325            #[inline(always)]
326            fn tot_gt(&self, other: &Self) -> bool {
327                other.tot_lt(self)
328            }
329
330            #[inline(always)]
331            fn tot_le(&self, other: &Self) -> bool {
332                other.tot_ge(self)
333            }
334
335            #[inline(always)]
336            fn tot_ge(&self, other: &Self) -> bool {
337                // We consider all NaNs equal, and NaN is the largest possible
338                // value. Thus if self is NaN we always return true. Otherwise
339                // self >= other is correct. If other is not NaN it is trivially
340                // correct, and if it is we note that nothing can be greater or
341                // equal to NaN except NaN itself, which we already handled earlier.
342                self.is_nan() | (self >= other)
343            }
344        }
345    };
346}
347
348impl_float_eq_ord!(f32);
349impl_float_eq_ord!(f64);
350
351impl TotalHash for f32 {
352    #[inline(always)]
353    fn tot_hash<H>(&self, state: &mut H)
354    where
355        H: Hasher,
356    {
357        canonical_f32(*self).to_bits().hash(state)
358    }
359}
360
361impl TotalHash for f64 {
362    #[inline(always)]
363    fn tot_hash<H>(&self, state: &mut H)
364    where
365        H: Hasher,
366    {
367        canonical_f64(*self).to_bits().hash(state)
368    }
369}
370
371// Blanket implementations.
372impl<T: TotalEq> TotalEq for Option<T> {
373    #[inline(always)]
374    fn tot_eq(&self, other: &Self) -> bool {
375        match (self, other) {
376            (None, None) => true,
377            (Some(a), Some(b)) => a.tot_eq(b),
378            _ => false,
379        }
380    }
381
382    #[inline(always)]
383    fn tot_ne(&self, other: &Self) -> bool {
384        match (self, other) {
385            (None, None) => false,
386            (Some(a), Some(b)) => a.tot_ne(b),
387            _ => true,
388        }
389    }
390}
391
392impl<T: TotalOrd> TotalOrd for Option<T> {
393    #[inline(always)]
394    fn tot_cmp(&self, other: &Self) -> Ordering {
395        match (self, other) {
396            (None, None) => Ordering::Equal,
397            (None, Some(_)) => Ordering::Less,
398            (Some(_), None) => Ordering::Greater,
399            (Some(a), Some(b)) => a.tot_cmp(b),
400        }
401    }
402
403    #[inline(always)]
404    fn tot_lt(&self, other: &Self) -> bool {
405        match (self, other) {
406            (None, Some(_)) => true,
407            (Some(a), Some(b)) => a.tot_lt(b),
408            _ => false,
409        }
410    }
411
412    #[inline(always)]
413    fn tot_gt(&self, other: &Self) -> bool {
414        other.tot_lt(self)
415    }
416
417    #[inline(always)]
418    fn tot_le(&self, other: &Self) -> bool {
419        match (self, other) {
420            (Some(_), None) => false,
421            (Some(a), Some(b)) => a.tot_lt(b),
422            _ => true,
423        }
424    }
425
426    #[inline(always)]
427    fn tot_ge(&self, other: &Self) -> bool {
428        other.tot_le(self)
429    }
430}
431
432impl<T: TotalHash> TotalHash for Option<T> {
433    #[inline]
434    fn tot_hash<H>(&self, state: &mut H)
435    where
436        H: Hasher,
437    {
438        self.is_some().tot_hash(state);
439        if let Some(slf) = self {
440            slf.tot_hash(state)
441        }
442    }
443}
444
445impl<T: TotalEq + ?Sized> TotalEq for &T {
446    #[inline(always)]
447    fn tot_eq(&self, other: &Self) -> bool {
448        (*self).tot_eq(*other)
449    }
450
451    #[inline(always)]
452    fn tot_ne(&self, other: &Self) -> bool {
453        (*self).tot_ne(*other)
454    }
455}
456
457impl<T: TotalHash + ?Sized> TotalHash for &T {
458    #[inline(always)]
459    fn tot_hash<H>(&self, state: &mut H)
460    where
461        H: Hasher,
462    {
463        (*self).tot_hash(state)
464    }
465}
466
467impl<T: TotalEq, U: TotalEq> TotalEq for (T, U) {
468    #[inline]
469    fn tot_eq(&self, other: &Self) -> bool {
470        self.0.tot_eq(&other.0) && self.1.tot_eq(&other.1)
471    }
472}
473
474impl<T: TotalOrd, U: TotalOrd> TotalOrd for (T, U) {
475    #[inline]
476    fn tot_cmp(&self, other: &Self) -> Ordering {
477        self.0
478            .tot_cmp(&other.0)
479            .then_with(|| self.1.tot_cmp(&other.1))
480    }
481}
482
483impl TotalHash for BytesHash<'_> {
484    #[inline(always)]
485    fn tot_hash<H>(&self, state: &mut H)
486    where
487        H: Hasher,
488    {
489        self.hash(state)
490    }
491}
492
493impl TotalEq for BytesHash<'_> {
494    #[inline(always)]
495    fn tot_eq(&self, other: &Self) -> bool {
496        self == other
497    }
498}
499
/// Converts a value into a key that is `Hash + Eq`, and converts it back.
///
/// Types whose own `Eq`/`Hash` already suffice (integers, `String`, string
/// and byte slices, ...) map to themselves. Floats are wrapped in
/// [`TotalOrdWrap`]. This avoids creating a wrapper for types that don't need one.
pub trait ToTotalOrd {
    /// Key type produced by [`ToTotalOrd::to_total_ord`].
    type TotalOrdItem: Eq + Hash;
    /// Type recovered by [`ToTotalOrd::peel_total_ord`].
    type SourceItem;

    /// Returns the hashable, totally comparable key for `self`.
    fn to_total_ord(&self) -> Self::TotalOrdItem;

    /// Undoes [`ToTotalOrd::to_total_ord`], yielding the source value.
    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem;
}
509
510macro_rules! impl_to_total_ord_identity {
511    ($T: ty) => {
512        impl ToTotalOrd for $T {
513            type TotalOrdItem = $T;
514            type SourceItem = $T;
515
516            #[inline]
517            fn to_total_ord(&self) -> Self::TotalOrdItem {
518                self.clone()
519            }
520
521            #[inline]
522            fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
523                ord_item
524            }
525        }
526    };
527}
528
529impl_to_total_ord_identity!(bool);
530impl_to_total_ord_identity!(u8);
531impl_to_total_ord_identity!(u16);
532impl_to_total_ord_identity!(u32);
533impl_to_total_ord_identity!(u64);
534impl_to_total_ord_identity!(u128);
535impl_to_total_ord_identity!(usize);
536impl_to_total_ord_identity!(i8);
537impl_to_total_ord_identity!(i16);
538impl_to_total_ord_identity!(i32);
539impl_to_total_ord_identity!(i64);
540impl_to_total_ord_identity!(i128);
541impl_to_total_ord_identity!(isize);
542impl_to_total_ord_identity!(char);
543impl_to_total_ord_identity!(String);
544
545macro_rules! impl_to_total_ord_lifetimed_ref_identity {
546    ($T: ty) => {
547        impl<'a> ToTotalOrd for &'a $T {
548            type TotalOrdItem = &'a $T;
549            type SourceItem = &'a $T;
550
551            #[inline]
552            fn to_total_ord(&self) -> Self::TotalOrdItem {
553                *self
554            }
555
556            #[inline]
557            fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
558                ord_item
559            }
560        }
561    };
562}
563
564impl_to_total_ord_lifetimed_ref_identity!(str);
565impl_to_total_ord_lifetimed_ref_identity!([u8]);
566
567macro_rules! impl_to_total_ord_wrapped {
568    ($T: ty) => {
569        impl ToTotalOrd for $T {
570            type TotalOrdItem = TotalOrdWrap<$T>;
571            type SourceItem = $T;
572
573            #[inline]
574            fn to_total_ord(&self) -> Self::TotalOrdItem {
575                TotalOrdWrap(self.clone())
576            }
577
578            #[inline]
579            fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
580                ord_item.0
581            }
582        }
583    };
584}
585
586impl_to_total_ord_wrapped!(f32);
587impl_to_total_ord_wrapped!(f64);
588
589/// This is safe without needing to map the option value to TotalOrdWrap, since
590/// for example:
591/// `TotalOrdWrap<Option<T>>` implements `Eq + Hash`, iff:
592/// `Option<T>` implements `TotalEq + TotalHash`, iff:
593/// `T` implements `TotalEq + TotalHash`
594impl<T: Copy + TotalEq + TotalHash> ToTotalOrd for Option<T> {
595    type TotalOrdItem = TotalOrdWrap<Option<T>>;
596    type SourceItem = Option<T>;
597
598    #[inline]
599    fn to_total_ord(&self) -> Self::TotalOrdItem {
600        TotalOrdWrap(*self)
601    }
602
603    #[inline]
604    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
605        ord_item.0
606    }
607}
608
609impl<T: ToTotalOrd> ToTotalOrd for &T {
610    type TotalOrdItem = T::TotalOrdItem;
611    type SourceItem = T::SourceItem;
612
613    #[inline]
614    fn to_total_ord(&self) -> Self::TotalOrdItem {
615        (*self).to_total_ord()
616    }
617
618    #[inline]
619    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
620        T::peel_total_ord(ord_item)
621    }
622}
623
624impl<'a> ToTotalOrd for BytesHash<'a> {
625    type TotalOrdItem = BytesHash<'a>;
626    type SourceItem = BytesHash<'a>;
627
628    #[inline]
629    fn to_total_ord(&self) -> Self::TotalOrdItem {
630        *self
631    }
632
633    #[inline]
634    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
635        ord_item
636    }
637}