1use std::cmp::Ordering;
2use std::mem::MaybeUninit;
3
4use num_traits::FromPrimitive;
5use rayon::ThreadPool;
6use rayon::prelude::*;
7
8use crate::IdxSize;
9use crate::total_ord::TotalOrd;
10
11#[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
25pub unsafe fn perfect_sort(pool: &ThreadPool, idx: &[(IdxSize, IdxSize)], out: &mut Vec<IdxSize>) {
26 let chunk_size = std::cmp::max(
27 idx.len() / pool.current_num_threads(),
28 pool.current_num_threads(),
29 );
30
31 out.reserve(idx.len());
32 let ptr = out.as_mut_ptr() as *const IdxSize as usize;
33
34 pool.install(|| {
35 idx.par_chunks(chunk_size).for_each(|indices| {
36 let ptr = ptr as *mut IdxSize;
37 for (idx_val, idx_location) in indices {
38 unsafe { *ptr.add(*idx_location as usize) = *idx_val };
42 }
43 });
44 });
45 unsafe { out.set_len(idx.len()) };
48}
49
/// Scatters the pairs of `idx` into `out`: for every `(value, location)` pair, `value` is
/// written at offset `location` from the start of `out`'s buffer. Chunks of `idx` are
/// processed in parallel on `pool`, and the length of `out` is set to `idx.len()` at the end.
///
/// # Safety
/// - Every location (second tuple element) must be in-bounds, i.e. smaller than `idx.len()`.
/// - The locations must be distinct and together cover `0..idx.len()`; otherwise chunks
///   write to the same slot concurrently or `out` ends up exposing unwritten elements.
#[cfg(all(not(target_os = "emscripten"), target_family = "wasm"))]
pub unsafe fn perfect_sort(
    pool: &crate::wasm::Pool,
    idx: &[(IdxSize, IdxSize)],
    out: &mut Vec<IdxSize>,
) {
    let chunk_size = std::cmp::max(
        idx.len() / pool.current_num_threads(),
        pool.current_num_threads(),
    );

    out.reserve(idx.len());
    // The address travels into the closure as a `usize` and is turned back into a pointer
    // there; raw pointers themselves are neither `Send` nor `Sync`.
    let ptr = out.as_mut_ptr() as *const IdxSize as usize;

    pool.install(|| {
        idx.par_chunks(chunk_size).for_each(|indices| {
            let ptr = ptr as *mut IdxSize;
            for (idx_val, idx_location) in indices {
                // SAFETY: the caller guarantees `idx_location` is in-bounds of the reserved
                // buffer and that no other pair targets the same slot.
                unsafe { *ptr.add(*idx_location as usize) = *idx_val };
            }
        });
    });
    // SAFETY: capacity for `idx.len()` elements was reserved above and, per the caller's
    // contract, every slot in `0..idx.len()` has been written.
    unsafe { out.set_len(idx.len()) };
}
80
/// Reinterprets a mutable slice of `MaybeUninit<T>` as a mutable slice of `T` covering the
/// same memory and the same number of elements.
///
/// # Safety
/// Every element of `slice` must already hold an initialized `T`.
unsafe fn assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    let len = slice.len();
    let first = slice.as_mut_ptr().cast::<T>();
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, the pointer and
    // length come from a live exclusive slice, and the caller guarantees initialization.
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
84
85pub fn arg_sort_ascending<'a, T: TotalOrd + Copy + 'a, Idx, I: IntoIterator<Item = T>>(
86 v: I,
87 scratch: &'a mut Vec<u8>,
88 n: usize,
89) -> &'a mut [Idx]
90where
91 Idx: FromPrimitive + Copy,
92{
93 debug_assert_eq!(align_of::<T>(), align_of::<(T, Idx)>());
95 let size = size_of::<(T, Idx)>();
96 let upper_bound = size * n + size;
97 scratch.reserve(upper_bound);
98 let scratch_slice = unsafe {
99 let cap_slice = scratch.spare_capacity_mut();
100 let (_, scratch_slice, _) = cap_slice.align_to_mut::<MaybeUninit<(T, Idx)>>();
101 &mut scratch_slice[..n]
102 };
103
104 for ((i, v), dst) in v.into_iter().enumerate().zip(scratch_slice.iter_mut()) {
105 *dst = MaybeUninit::new((v, Idx::from_usize(i).unwrap()));
106 }
107 debug_assert_eq!(n, scratch_slice.len());
108
109 let scratch_slice = unsafe { assume_init_mut(scratch_slice) };
110 scratch_slice.sort_by(|key1, key2| key1.0.tot_cmp(&key2.0));
111
112 unsafe {
115 let src = scratch_slice.as_ptr();
116
117 let (_, scratch_slice_aligned_to_idx, _) = scratch_slice.align_to_mut::<Idx>();
118
119 let dst = scratch_slice_aligned_to_idx.as_mut_ptr();
120
121 for i in 0..n {
122 dst.add(i).write((*src.add(i)).1);
123 }
124
125 &mut scratch_slice_aligned_to_idx[..n]
126 }
127}
128
/// Wrapper around `Option<T>` whose ordering is selected at compile time.
///
/// - `NULLS_LAST` decides whether `None` compares greater (`true`) or less (`false`) than
///   any `Some`; this placement is not affected by `DESCENDING`.
/// - `DESCENDING` reverses the comparison between two `Some` values.
#[derive(PartialEq, Eq, Clone, Hash)]
#[repr(transparent)]
pub struct ReorderWithNulls<T, const DESCENDING: bool, const NULLS_LAST: bool>(pub Option<T>);

impl<T: PartialOrd, const DESCENDING: bool, const NULLS_LAST: bool> PartialOrd
    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
    /// Partial comparison following the `DESCENDING` / `NULLS_LAST` configuration; the
    /// result is `None` only when the two inner values themselves are incomparable.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // How a null on the left-hand side relates to a value on the right-hand side.
        let null_vs_value = if NULLS_LAST {
            Ordering::Greater
        } else {
            Ordering::Less
        };

        match (self.0.as_ref(), other.0.as_ref()) {
            (Some(lhs), Some(rhs)) if DESCENDING => rhs.partial_cmp(lhs),
            (Some(lhs), Some(rhs)) => lhs.partial_cmp(rhs),
            (None, None) => Some(Ordering::Equal),
            (None, Some(_)) => Some(null_vs_value),
            (Some(_), None) => Some(null_vs_value.reverse()),
        }
    }
}

impl<T: Ord, const DESCENDING: bool, const NULLS_LAST: bool> Ord
    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
    /// Total comparison following the `DESCENDING` / `NULLS_LAST` configuration.
    fn cmp(&self, other: &Self) -> Ordering {
        // How a null on the left-hand side relates to a value on the right-hand side.
        let null_vs_value = if NULLS_LAST {
            Ordering::Greater
        } else {
            Ordering::Less
        };

        match (self.0.as_ref(), other.0.as_ref()) {
            (Some(lhs), Some(rhs)) if DESCENDING => rhs.cmp(lhs),
            (Some(lhs), Some(rhs)) => lhs.cmp(rhs),
            (None, None) => Ordering::Equal,
            (None, Some(_)) => null_vs_value,
            (Some(_), None) => null_vs_value.reverse(),
        }
    }
}