1use std::cmp::Ordering;
2use std::mem::MaybeUninit;
3use std::ops::Deref;
4
5use num_traits::FromPrimitive;
6
7use crate::nulls::IsNull;
8use crate::total_ord::TotalOrd;
9
/// Reinterprets a slice of `MaybeUninit<T>` as a slice of initialized `T`.
///
/// # Safety
/// Every element of `slice` must be fully initialized before calling this.
unsafe fn assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    // SAFETY: `MaybeUninit<T>` is guaranteed to have the same size, alignment
    // and ABI as `T`, and the caller promises every element is initialized,
    // so viewing the same memory as `[T]` is sound.
    unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<T>(), slice.len()) }
}
13
/// Argsort for an ascending sort: returns, for each rank, the original index
/// of the value that belongs at that rank.
///
/// Works entirely inside the spare capacity of `scratch` to avoid a fresh
/// allocation: `(value, index)` pairs are written there, sorted by value via
/// `TotalOrd`, then compacted in place down to just the indices. The returned
/// slice borrows from `scratch`.
///
/// The sort is stable (`sort_by`), so equal values keep their original
/// relative order.
///
/// # Parameters
/// - `v`: the values to rank. NOTE(review): must yield at least `n` items,
///   otherwise uninitialized pairs are read below — confirm all callers
///   guarantee this.
/// - `scratch`: byte buffer whose spare capacity is used as workspace; its
///   existing length and contents are untouched.
/// - `n`: number of values to sort.
///
/// # Panics
/// Panics if an index in `0..n` cannot be represented in `Idx`
/// (`Idx::from_usize` returns `None`).
pub fn arg_sort_ascending<'a, T: TotalOrd + Copy + 'a, Idx, I: IntoIterator<Item = T>>(
    v: I,
    scratch: &'a mut Vec<u8>,
    n: usize,
) -> &'a mut [Idx]
where
    Idx: FromPrimitive + Copy,
{
    // `(T, Idx)` contains `T` as a field, so the alignments are expected to
    // agree; the `align_to_mut` below relies on finding an aligned region.
    debug_assert_eq!(align_of::<T>(), align_of::<(T, Idx)>());
    // One extra pair of slack so `n` pairs still fit after `align_to_mut`
    // skips any unaligned prefix of the byte buffer.
    let size = size_of::<(T, Idx)>();
    let upper_bound = size * n + size;
    scratch.reserve(upper_bound);
    // SAFETY: `reserve` guaranteed room for `n + 1` pairs, so after dropping
    // the unaligned prefix at least `n` aligned `MaybeUninit` pairs remain,
    // making the `[..n]` slice in-bounds.
    let scratch_slice = unsafe {
        let cap_slice = scratch.spare_capacity_mut();
        let (_, scratch_slice, _) = cap_slice.align_to_mut::<MaybeUninit<(T, Idx)>>();
        &mut scratch_slice[..n]
    };

    // Tag every value with its original position.
    for ((i, v), dst) in v.into_iter().enumerate().zip(scratch_slice.iter_mut()) {
        *dst = MaybeUninit::new((v, Idx::from_usize(i).unwrap()));
    }
    debug_assert_eq!(n, scratch_slice.len());

    // SAFETY: all `n` pairs were initialized by the loop above (assuming `v`
    // yielded at least `n` items — see the note in the doc comment).
    let scratch_slice = unsafe { assume_init_mut(scratch_slice) };
    // Stable sort on the value component only; the index tags ride along.
    scratch_slice.sort_by(|key1, key2| key1.0.tot_cmp(&key2.0));

    // Compact the sorted pairs down to bare indices, reusing the same buffer.
    unsafe {
        let src = scratch_slice.as_ptr();

        // The pair slice is aligned for `(T, Idx)`, which is at least the
        // alignment of its `Idx` field, so no prefix is skipped here.
        let (_, scratch_slice_aligned_to_idx, _) = scratch_slice.align_to_mut::<Idx>();

        let dst = scratch_slice_aligned_to_idx.as_mut_ptr();

        // SAFETY: `dst` and `src` alias, but because
        // `size_of::<Idx>() <= size_of::<(T, Idx)>()` the write to slot `i`
        // ends at `(i + 1) * size_of::<Idx>()`, which never reaches the
        // start of the not-yet-read pair `i + 1`; pair `i` itself is read
        // before the write in the same iteration.
        for i in 0..n {
            dst.add(i).write((*src.add(i)).1);
        }

        &mut scratch_slice_aligned_to_idx[..n]
    }
}
57
/// Sort-key wrapper around `Option<T>` whose ordering is fixed at the type
/// level: `DESCENDING` reverses comparisons between present values, and
/// `NULLS_LAST` places `None` after (`true`) or before (`false`) any `Some`
/// — see the `PartialOrd`/`Ord` impls below.
///
/// `repr(transparent)` keeps the layout identical to `Option<T>`.
#[derive(PartialEq, Eq, Clone, Hash)]
#[repr(transparent)]
pub struct ReorderWithNulls<T, const DESCENDING: bool, const NULLS_LAST: bool>(pub Option<T>);
61
62impl<T, const DESCENDING: bool, const NULLS_LAST: bool>
63 ReorderWithNulls<T, DESCENDING, NULLS_LAST>
64{
65 pub fn as_deref(&self) -> ReorderWithNulls<&<T as Deref>::Target, DESCENDING, NULLS_LAST>
66 where
67 T: Deref,
68 {
69 let x = self.0.as_deref();
70 ReorderWithNulls(x)
71 }
72}
73
74impl<T: PartialOrd, const DESCENDING: bool, const NULLS_LAST: bool> PartialOrd
75 for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
76{
77 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
78 match (&self.0, &other.0) {
79 (None, None) => Some(Ordering::Equal),
80 (None, Some(_)) => {
81 if NULLS_LAST {
82 Some(Ordering::Greater)
83 } else {
84 Some(Ordering::Less)
85 }
86 },
87 (Some(_), None) => {
88 if NULLS_LAST {
89 Some(Ordering::Less)
90 } else {
91 Some(Ordering::Greater)
92 }
93 },
94 (Some(l), Some(r)) => {
95 if DESCENDING {
96 r.partial_cmp(l)
97 } else {
98 l.partial_cmp(r)
99 }
100 },
101 }
102 }
103}
104
impl<T: Ord, const DESCENDING: bool, const NULLS_LAST: bool> Ord
    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
    fn cmp(&self, other: &Self) -> Ordering {
        // Delegate to the runtime-flag comparator, passing the const
        // parameters as flags, so the const-generic wrapper and flag-based
        // call sites order values identically. Requires `Option<T>: IsNull`
        // (provided via `crate::nulls`).
        reorder_cmp(&self.0, &other.0, DESCENDING, NULLS_LAST)
    }
}
112
113#[inline]
119pub fn reorder_cmp<T: PartialOrd + IsNull>(
120 lhs: &T,
121 rhs: &T,
122 descending: bool,
123 nulls_last: bool,
124) -> Ordering {
125 match PartialOrd::partial_cmp(lhs, rhs).expect("expected total ordering") {
126 Ordering::Equal => Ordering::Equal,
127 _ if lhs.is_null() && nulls_last => Ordering::Greater,
128 _ if rhs.is_null() && nulls_last => Ordering::Less,
129 _ if lhs.is_null() => Ordering::Less,
130 _ if rhs.is_null() => Ordering::Greater,
131 ord if descending => ord.reverse(),
132 ord => ord,
133 }
134}