polars_utils/sort.rs

use std::cmp::Ordering;
use std::mem::MaybeUninit;

use num_traits::FromPrimitive;
use rayon::ThreadPool;
use rayon::prelude::*;

use crate::IdxSize;
use crate::total_ord::TotalOrd;

/// A perfect sort, particularly useful for an arg_sort of an arg_sort.
/// The second arg_sort sorts indices from `0` to `len`, so each value can simply be
/// assigned to its new index location.
///
/// Since all indices are unique they cannot alias, which lets us parallelize the writes.
///
/// This sort does not sort in place and will allocate.
///
/// - The right indices are used for sorting.
/// - The left indices are placed at the location the right index points to.
///
/// # Safety
/// The caller must ensure that the right indices of `&[(_, IdxSize)]` are integers covering the range `0..idx.len()`.
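///
/// # Example
/// A minimal usage sketch (illustrative, not taken from the original source), assuming a
/// rayon `ThreadPool` is available and the right indices form a permutation of `0..idx.len()`:
/// ```ignore
/// let pool = rayon::ThreadPoolBuilder::new().num_threads(2).build().unwrap();
/// // (value, destination) pairs; the destinations are exactly 0..3.
/// let idx: Vec<(IdxSize, IdxSize)> = vec![(10, 2), (11, 0), (12, 1)];
/// let mut out = Vec::new();
/// // SAFETY: the destinations cover 0..idx.len() without gaps or duplicates.
/// unsafe { perfect_sort(&pool, &idx, &mut out) };
/// assert_eq!(out, vec![11, 12, 10]);
/// ```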
#[cfg(any(target_os = "emscripten", not(target_family = "wasm")))]
pub unsafe fn perfect_sort(pool: &ThreadPool, idx: &[(IdxSize, IdxSize)], out: &mut Vec<IdxSize>) {
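    // The chunk size is at least `current_num_threads()`, so small inputs are not
    // split into many tiny parallel tasks.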
    let chunk_size = std::cmp::max(
        idx.len() / pool.current_num_threads(),
        pool.current_num_threads(),
    );

    out.reserve(idx.len());
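    // Keep the pointer as a `usize` so it can be captured by the parallel closure
    // below (raw pointers are not `Send`).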
    let ptr = out.as_mut_ptr() as *const IdxSize as usize;

    pool.install(|| {
        idx.par_chunks(chunk_size).for_each(|indices| {
            let ptr = ptr as *mut IdxSize;
            for (idx_val, idx_location) in indices {
                // SAFETY:
                // idx_location is in bounds by invariant of this function
                // and we ensured we have at least `idx.len()` capacity
                unsafe { *ptr.add(*idx_location as usize) = *idx_val };
            }
        });
    });
    // SAFETY:
    // all elements are written
    unsafe { out.set_len(idx.len()) };
}

// wasm alternative with different signature
#[cfg(all(not(target_os = "emscripten"), target_family = "wasm"))]
pub unsafe fn perfect_sort(
    pool: &crate::wasm::Pool,
    idx: &[(IdxSize, IdxSize)],
    out: &mut Vec<IdxSize>,
) {
    let chunk_size = std::cmp::max(
        idx.len() / pool.current_num_threads(),
        pool.current_num_threads(),
    );

    out.reserve(idx.len());
    let ptr = out.as_mut_ptr() as *const IdxSize as usize;

    pool.install(|| {
        idx.par_chunks(chunk_size).for_each(|indices| {
            let ptr = ptr as *mut IdxSize;
            for (idx_val, idx_location) in indices {
                // SAFETY:
                // idx_location is in bounds by invariant of this function
                // and we ensured we have at least `idx.len()` capacity
                *ptr.add(*idx_location as usize) = *idx_val;
            }
        });
    });
    // SAFETY:
    // all elements are written
    out.set_len(idx.len());
}

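/// Transmutes a slice of `MaybeUninit<T>` into a slice of `T`.
///
/// # Safety
/// The caller must ensure that every element of `slice` has been initialized.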
unsafe fn assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    unsafe { &mut *(slice as *mut [MaybeUninit<T>] as *mut [T]) }
}

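/// Computes the ascending arg-sort of `v`, using `scratch` as backing storage.
///
/// Returns a slice of length `n` such that indexing `v` by it yields its values in
/// ascending (total) order. `n` must equal the number of items yielded by `v`.
///
/// # Example
/// A minimal sketch (illustrative), assuming `f64` implements `TotalOrd` and `u32` is
/// used as the index type:
/// ```ignore
/// let mut scratch = Vec::new();
/// let values = [3.0_f64, 1.0, 2.0];
/// let order: &mut [u32] = arg_sort_ascending(values.iter().copied(), &mut scratch, values.len());
/// assert_eq!(order, &[1, 2, 0]);
/// ```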
pub fn arg_sort_ascending<'a, T: TotalOrd + Copy + 'a, Idx, I: IntoIterator<Item = T>>(
    v: I,
    scratch: &'a mut Vec<u8>,
    n: usize,
) -> &'a mut [Idx]
where
    Idx: FromPrimitive + Copy,
{
    // Needed to be able to write back to back in the same buffer.
    debug_assert_eq!(align_of::<T>(), align_of::<(T, Idx)>());
    let size = size_of::<(T, Idx)>();
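    // Reserve one extra element so that `n` aligned `(T, Idx)` pairs still fit after
    // the alignment shift below.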
    let upper_bound = size * n + size;
    scratch.reserve(upper_bound);
    let scratch_slice = unsafe {
        let cap_slice = scratch.spare_capacity_mut();
        let (_, scratch_slice, _) = cap_slice.align_to_mut::<MaybeUninit<(T, Idx)>>();
        &mut scratch_slice[..n]
    };

    for ((i, v), dst) in v.into_iter().enumerate().zip(scratch_slice.iter_mut()) {
        *dst = MaybeUninit::new((v, Idx::from_usize(i).unwrap()));
    }
    debug_assert_eq!(n, scratch_slice.len());

    let scratch_slice = unsafe { assume_init_mut(scratch_slice) };
    scratch_slice.sort_by(|key1, key2| key1.0.tot_cmp(&key2.0));

    // Now write the indices back into the same buffer, i.e. go from `(T, Idx)` pairs
    // to plain `Idx` values. This is sound because `Idx` is no larger than `(T, Idx)`,
    // so the write for element `i` never overlaps a not-yet-read element `j > i`.
    unsafe {
        let src = scratch_slice.as_ptr();

        let (_, scratch_slice_aligned_to_idx, _) = scratch_slice.align_to_mut::<Idx>();

        let dst = scratch_slice_aligned_to_idx.as_mut_ptr();

        for i in 0..n {
            dst.add(i).write((*src.add(i)).1);
        }

        &mut scratch_slice_aligned_to_idx[..n]
    }
}

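/// A wrapper around `Option<T>` whose ordering is controlled by const generics:
/// `None` (a null) sorts after every value when `NULLS_LAST` is true and before every
/// value otherwise, while values compare in reverse order when `DESCENDING` is true.
///
/// # Example
/// A minimal sketch (illustrative): descending on the values, with nulls last:
/// ```ignore
/// let mut xs = vec![
///     ReorderWithNulls::<i32, true, true>(Some(1)),
///     ReorderWithNulls::<i32, true, true>(None),
///     ReorderWithNulls::<i32, true, true>(Some(3)),
/// ];
/// xs.sort();
/// assert_eq!(
///     xs.iter().map(|x| x.0).collect::<Vec<_>>(),
///     vec![Some(3), Some(1), None]
/// );
/// ```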
#[derive(PartialEq, Eq, Clone, Hash)]
#[repr(transparent)]
pub struct ReorderWithNulls<T, const DESCENDING: bool, const NULLS_LAST: bool>(pub Option<T>);

impl<T: PartialOrd, const DESCENDING: bool, const NULLS_LAST: bool> PartialOrd
    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match (&self.0, &other.0) {
            (None, None) => Some(Ordering::Equal),
            (None, Some(_)) => {
                if NULLS_LAST {
                    Some(Ordering::Greater)
                } else {
                    Some(Ordering::Less)
                }
            },
            (Some(_), None) => {
                if NULLS_LAST {
                    Some(Ordering::Less)
                } else {
                    Some(Ordering::Greater)
                }
            },
            (Some(l), Some(r)) => {
                if DESCENDING {
                    r.partial_cmp(l)
                } else {
                    l.partial_cmp(r)
                }
            },
        }
    }
}

impl<T: Ord, const DESCENDING: bool, const NULLS_LAST: bool> Ord
    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
    fn cmp(&self, other: &Self) -> Ordering {
        match (&self.0, &other.0) {
            (None, None) => Ordering::Equal,
            (None, Some(_)) => {
                if NULLS_LAST {
                    Ordering::Greater
                } else {
                    Ordering::Less
                }
            },
            (Some(_), None) => {
                if NULLS_LAST {
                    Ordering::Less
                } else {
                    Ordering::Greater
                }
            },
            (Some(l), Some(r)) => {
                if DESCENDING {
                    r.cmp(l)
                } else {
                    l.cmp(r)
                }
            },
        }
    }
}