polars_utils/sort.rs

1use std::cmp::Ordering;
2use std::mem::MaybeUninit;
3
4use num_traits::FromPrimitive;
5
6use crate::total_ord::TotalOrd;
7
/// Reinterprets a slice of `MaybeUninit<T>` as a slice of initialized `T`.
///
/// # Safety
/// Every element of `slice` must be fully initialized before calling this;
/// reading uninitialized memory through the returned slice is undefined
/// behavior.
unsafe fn assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    // SAFETY (caller contract): all elements are initialized, and
    // `MaybeUninit<T>` is layout-compatible with `T`.
    unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<T>(), slice.len()) }
}
11
/// Arg-sorts the values yielded by `v` ascending and returns the original
/// indices in sorted order, using `scratch` as the backing allocation.
///
/// `n` must be exactly the number of items `v` yields; the returned index
/// slice borrows from `scratch`. Indices are produced via
/// `Idx::from_usize` and unwrap-panic if an index does not fit in `Idx`.
pub fn arg_sort_ascending<'a, T: TotalOrd + Copy + 'a, Idx, I: IntoIterator<Item = T>>(
    v: I,
    scratch: &'a mut Vec<u8>,
    n: usize,
) -> &'a mut [Idx]
where
    Idx: FromPrimitive + Copy,
{
    // Needed to be able to write back to back in the same buffer.
    debug_assert_eq!(align_of::<T>(), align_of::<(T, Idx)>());
    let size = size_of::<(T, Idx)>();
    // One extra element of slack so `align_to_mut` below can still carve out
    // `n` aligned pairs even when the spare capacity starts misaligned.
    let upper_bound = size * n + size;
    scratch.reserve(upper_bound);
    let scratch_slice = unsafe {
        // Reinterpret the (uninitialized) spare bytes as uninitialized
        // `(T, Idx)` pairs; any misaligned head/tail bytes are discarded.
        let cap_slice = scratch.spare_capacity_mut();
        let (_, scratch_slice, _) = cap_slice.align_to_mut::<MaybeUninit<(T, Idx)>>();
        &mut scratch_slice[..n]
    };

    // Initialize every pair with (value, original position).
    // NOTE(review): if `v` yields fewer than `n` items, `zip` stops early and
    // the remaining pairs stay uninitialized — the `assume_init_mut` below
    // would then be UB. The debug_assert after the loop is trivially true
    // (the slice was already cut to `n`), so callers must guarantee the
    // length matches — TODO confirm all call sites do.
    for ((i, v), dst) in v.into_iter().enumerate().zip(scratch_slice.iter_mut()) {
        *dst = MaybeUninit::new((v, Idx::from_usize(i).unwrap()));
    }
    debug_assert_eq!(n, scratch_slice.len());

    // SAFETY: all `n` pairs were written by the loop above (see note there).
    let scratch_slice = unsafe { assume_init_mut(scratch_slice) };
    scratch_slice.sort_by(|key1, key2| key1.0.tot_cmp(&key2.0));

    // now we write the indexes in the same array.
    // So from <T, Idxsize> to <IdxSize>
    unsafe {
        let src = scratch_slice.as_ptr();

        let (_, scratch_slice_aligned_to_idx, _) = scratch_slice.align_to_mut::<Idx>();

        let dst = scratch_slice_aligned_to_idx.as_mut_ptr();

        // Compact the `Idx` half of each sorted pair to the front of the same
        // buffer. Sound w.r.t. overlap because for every i the write window
        // ends at (i+1)*size_of::<Idx>(), which is at or before the next read
        // at (i+1)*size_of::<(T, Idx)>() — no unread pair is clobbered.
        for i in 0..n {
            dst.add(i).write((*src.add(i)).1);
        }

        &mut scratch_slice_aligned_to_idx[..n]
    }
}
55
/// Transparent wrapper around `Option<T>` whose ordering is fixed at the type
/// level: `DESCENDING` reverses the comparison of `Some` values, and
/// `NULLS_LAST` sorts `None` after every `Some` (otherwise before).
#[derive(PartialEq, Eq, Clone, Hash)]
#[repr(transparent)]
pub struct ReorderWithNulls<T, const DESCENDING: bool, const NULLS_LAST: bool>(pub Option<T>);
59
60impl<T: PartialOrd, const DESCENDING: bool, const NULLS_LAST: bool> PartialOrd
61    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
62{
63    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
64        match (&self.0, &other.0) {
65            (None, None) => Some(Ordering::Equal),
66            (None, Some(_)) => {
67                if NULLS_LAST {
68                    Some(Ordering::Greater)
69                } else {
70                    Some(Ordering::Less)
71                }
72            },
73            (Some(_), None) => {
74                if NULLS_LAST {
75                    Some(Ordering::Less)
76                } else {
77                    Some(Ordering::Greater)
78                }
79            },
80            (Some(l), Some(r)) => {
81                if DESCENDING {
82                    r.partial_cmp(l)
83                } else {
84                    l.partial_cmp(r)
85                }
86            },
87        }
88    }
89}
90
91impl<T: Ord, const DESCENDING: bool, const NULLS_LAST: bool> Ord
92    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
93{
94    fn cmp(&self, other: &Self) -> Ordering {
95        match (&self.0, &other.0) {
96            (None, None) => Ordering::Equal,
97            (None, Some(_)) => {
98                if NULLS_LAST {
99                    Ordering::Greater
100                } else {
101                    Ordering::Less
102                }
103            },
104            (Some(_), None) => {
105                if NULLS_LAST {
106                    Ordering::Less
107                } else {
108                    Ordering::Greater
109                }
110            },
111            (Some(l), Some(r)) => {
112                if DESCENDING {
113                    r.cmp(l)
114                } else {
115                    l.cmp(r)
116                }
117            },
118        }
119    }
120}