Skip to main content

polars_utils/
sort.rs

1use std::cmp::Ordering;
2use std::mem::MaybeUninit;
3use std::ops::Deref;
4
5use num_traits::FromPrimitive;
6
7use crate::nulls::IsNull;
8use crate::total_ord::TotalOrd;
9
/// Reinterprets a mutable slice of `MaybeUninit<T>` as a mutable slice of `T`.
///
/// The returned slice has the same length and the same borrow as the input.
///
/// # Safety
///
/// Every element of `slice` must be initialized; otherwise reading through the
/// returned slice is undefined behavior.
unsafe fn assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    let len = slice.len();
    let data = slice.as_mut_ptr().cast::<T>();
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, `data`/`len`
    // describe exactly the memory of the incoming exclusive borrow, and the caller
    // guarantees that every element is initialized.
    unsafe { std::slice::from_raw_parts_mut(data, len) }
}
13
14pub fn arg_sort_ascending<'a, T: TotalOrd + Copy + 'a, Idx, I: IntoIterator<Item = T>>(
15    v: I,
16    scratch: &'a mut Vec<u8>,
17    n: usize,
18) -> &'a mut [Idx]
19where
20    Idx: FromPrimitive + Copy,
21{
22    // Needed to be able to write back to back in the same buffer.
23    debug_assert_eq!(align_of::<T>(), align_of::<(T, Idx)>());
24    let size = size_of::<(T, Idx)>();
25    let upper_bound = size * n + size;
26    scratch.reserve(upper_bound);
27    let scratch_slice = unsafe {
28        let cap_slice = scratch.spare_capacity_mut();
29        let (_, scratch_slice, _) = cap_slice.align_to_mut::<MaybeUninit<(T, Idx)>>();
30        &mut scratch_slice[..n]
31    };
32
33    for ((i, v), dst) in v.into_iter().enumerate().zip(scratch_slice.iter_mut()) {
34        *dst = MaybeUninit::new((v, Idx::from_usize(i).unwrap()));
35    }
36    debug_assert_eq!(n, scratch_slice.len());
37
38    let scratch_slice = unsafe { assume_init_mut(scratch_slice) };
39    scratch_slice.sort_by(|key1, key2| key1.0.tot_cmp(&key2.0));
40
41    // now we write the indexes in the same array.
42    // So from <T, Idxsize> to <IdxSize>
43    unsafe {
44        let src = scratch_slice.as_ptr();
45
46        let (_, scratch_slice_aligned_to_idx, _) = scratch_slice.align_to_mut::<Idx>();
47
48        let dst = scratch_slice_aligned_to_idx.as_mut_ptr();
49
50        for i in 0..n {
51            dst.add(i).write((*src.add(i)).1);
52        }
53
54        &mut scratch_slice_aligned_to_idx[..n]
55    }
56}
57
/// An optional value whose sort order is fixed at the type level.
///
/// In this type's `PartialOrd` impl, `DESCENDING` reverses the comparison between two
/// non-null values, and `NULLS_LAST` puts `None` after (`true`) or before (`false`)
/// every non-null value.
#[derive(Clone, Hash, PartialEq, Eq)]
#[repr(transparent)]
pub struct ReorderWithNulls<T, const DESCENDING: bool, const NULLS_LAST: bool>(pub Option<T>);

impl<T, const DESCENDING: bool, const NULLS_LAST: bool>
    ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
    /// Borrows the wrapped value through [`Deref`], keeping both ordering parameters.
    ///
    /// For example, a `ReorderWithNulls<String, ..>` yields a `ReorderWithNulls<&str, ..>`;
    /// a `None` stays `None`.
    pub fn as_deref(&self) -> ReorderWithNulls<&<T as Deref>::Target, DESCENDING, NULLS_LAST>
    where
        T: Deref,
    {
        let borrowed = self.0.as_ref().map(Deref::deref);
        ReorderWithNulls(borrowed)
    }
}
73
74impl<T: PartialOrd, const DESCENDING: bool, const NULLS_LAST: bool> PartialOrd
75    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
76{
77    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
78        match (&self.0, &other.0) {
79            (None, None) => Some(Ordering::Equal),
80            (None, Some(_)) => {
81                if NULLS_LAST {
82                    Some(Ordering::Greater)
83                } else {
84                    Some(Ordering::Less)
85                }
86            },
87            (Some(_), None) => {
88                if NULLS_LAST {
89                    Some(Ordering::Less)
90                } else {
91                    Some(Ordering::Greater)
92                }
93            },
94            (Some(l), Some(r)) => {
95                if DESCENDING {
96                    r.partial_cmp(l)
97                } else {
98                    l.partial_cmp(r)
99                }
100            },
101        }
102    }
103}
104
105impl<T: Ord, const DESCENDING: bool, const NULLS_LAST: bool> Ord
106    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
107{
108    fn cmp(&self, other: &Self) -> Ordering {
109        reorder_cmp(&self.0, &other.0, DESCENDING, NULLS_LAST)
110    }
111}
112
113/// Compare two values with support for sort direction and nulls position.
114///
115/// # Panics
116///
117/// Panics if `T::partial_cmp(lhs, rhs)` returns `None`.
118#[inline]
119pub fn reorder_cmp<T: PartialOrd + IsNull>(
120    lhs: &T,
121    rhs: &T,
122    descending: bool,
123    nulls_last: bool,
124) -> Ordering {
125    match PartialOrd::partial_cmp(lhs, rhs).expect("expected total ordering") {
126        Ordering::Equal => Ordering::Equal,
127        _ if lhs.is_null() && nulls_last => Ordering::Greater,
128        _ if rhs.is_null() && nulls_last => Ordering::Less,
129        _ if lhs.is_null() => Ordering::Less,
130        _ if rhs.is_null() => Ordering::Greater,
131        ord if descending => ord.reverse(),
132        ord => ord,
133    }
134}