Skip to main content

polars_utils/
sort.rs

1use std::cmp::Ordering;
2use std::mem::MaybeUninit;
3use std::ops::Deref;
4
5use num_traits::FromPrimitive;
6
7use crate::total_ord::TotalOrd;
8
/// Reinterprets a mutable slice of `MaybeUninit<T>` as a mutable slice of `T`.
///
/// The returned slice covers exactly the same memory and has the same length
/// as `slice`.
///
/// # Safety
///
/// Every element of `slice` must already hold an initialized, valid `T`.
unsafe fn assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    let len = slice.len();
    let first = slice.as_mut_ptr().cast::<T>();
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, the
    // pointer/length pair comes from a live exclusive borrow, and the caller
    // guarantees that all `len` elements are initialized.
    unsafe { std::slice::from_raw_parts_mut(first, len) }
}
12
13pub fn arg_sort_ascending<'a, T: TotalOrd + Copy + 'a, Idx, I: IntoIterator<Item = T>>(
14    v: I,
15    scratch: &'a mut Vec<u8>,
16    n: usize,
17) -> &'a mut [Idx]
18where
19    Idx: FromPrimitive + Copy,
20{
21    // Needed to be able to write back to back in the same buffer.
22    debug_assert_eq!(align_of::<T>(), align_of::<(T, Idx)>());
23    let size = size_of::<(T, Idx)>();
24    let upper_bound = size * n + size;
25    scratch.reserve(upper_bound);
26    let scratch_slice = unsafe {
27        let cap_slice = scratch.spare_capacity_mut();
28        let (_, scratch_slice, _) = cap_slice.align_to_mut::<MaybeUninit<(T, Idx)>>();
29        &mut scratch_slice[..n]
30    };
31
32    for ((i, v), dst) in v.into_iter().enumerate().zip(scratch_slice.iter_mut()) {
33        *dst = MaybeUninit::new((v, Idx::from_usize(i).unwrap()));
34    }
35    debug_assert_eq!(n, scratch_slice.len());
36
37    let scratch_slice = unsafe { assume_init_mut(scratch_slice) };
38    scratch_slice.sort_by(|key1, key2| key1.0.tot_cmp(&key2.0));
39
40    // now we write the indexes in the same array.
41    // So from <T, Idxsize> to <IdxSize>
42    unsafe {
43        let src = scratch_slice.as_ptr();
44
45        let (_, scratch_slice_aligned_to_idx, _) = scratch_slice.align_to_mut::<Idx>();
46
47        let dst = scratch_slice_aligned_to_idx.as_mut_ptr();
48
49        for i in 0..n {
50            dst.add(i).write((*src.add(i)).1);
51        }
52
53        &mut scratch_slice_aligned_to_idx[..n]
54    }
55}
56
/// An `Option<T>` wrapper whose ordering is configured through const parameters.
///
/// * `DESCENDING` — when `true`, two present values are compared in reverse.
/// * `NULLS_LAST` — when `true`, `None` orders after every `Some`; otherwise it
///   orders before every `Some`. This placement does not depend on `DESCENDING`.
#[derive(PartialEq, Eq, Clone, Hash)]
#[repr(transparent)]
pub struct ReorderWithNulls<T, const DESCENDING: bool, const NULLS_LAST: bool>(pub Option<T>);

impl<T, const DESCENDING: bool, const NULLS_LAST: bool>
    ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
    /// Borrows the wrapped value through `Deref`, keeping the same ordering
    /// configuration (for example `ReorderWithNulls<String, ..>` becomes
    /// `ReorderWithNulls<&str, ..>`).
    pub fn as_deref(&self) -> ReorderWithNulls<&<T as Deref>::Target, DESCENDING, NULLS_LAST>
    where
        T: Deref,
    {
        ReorderWithNulls(self.0.as_deref())
    }
}

impl<T: PartialOrd, const DESCENDING: bool, const NULLS_LAST: bool> PartialOrd
    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
    /// Compares two wrapped options: present values use `T::partial_cmp`
    /// (operands swapped when `DESCENDING`), while the position of `None` is
    /// decided solely by `NULLS_LAST`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match (&self.0, &other.0) {
            // Both present: defer to the inner type, possibly reversed.
            (Some(lhs), Some(rhs)) if DESCENDING => rhs.partial_cmp(lhs),
            (Some(lhs), Some(rhs)) => lhs.partial_cmp(rhs),
            // Two nulls are always equal.
            (None, None) => Some(Ordering::Equal),
            // Exactly one null: only `NULLS_LAST` matters.
            (None, Some(_)) if NULLS_LAST => Some(Ordering::Greater),
            (None, Some(_)) => Some(Ordering::Less),
            (Some(_), None) if NULLS_LAST => Some(Ordering::Less),
            (Some(_), None) => Some(Ordering::Greater),
        }
    }
}

impl<T: Ord, const DESCENDING: bool, const NULLS_LAST: bool> Ord
    for ReorderWithNulls<T, DESCENDING, NULLS_LAST>
{
    /// Total-order counterpart of `partial_cmp`: present values use `T::cmp`
    /// (operands swapped when `DESCENDING`), while the position of `None` is
    /// decided solely by `NULLS_LAST`.
    fn cmp(&self, other: &Self) -> Ordering {
        match (&self.0, &other.0) {
            // Both present: defer to the inner type, possibly reversed.
            (Some(lhs), Some(rhs)) if DESCENDING => rhs.cmp(lhs),
            (Some(lhs), Some(rhs)) => lhs.cmp(rhs),
            // Two nulls are always equal.
            (None, None) => Ordering::Equal,
            // Exactly one null: only `NULLS_LAST` matters.
            (None, Some(_)) if NULLS_LAST => Ordering::Greater,
            (None, Some(_)) => Ordering::Less,
            (Some(_), None) if NULLS_LAST => Ordering::Less,
            (Some(_), None) => Ordering::Greater,
        }
    }
}