// polars_core/chunked_array/ops/sort/arg_sort.rs

use polars_utils::itertools::Itertools;

use self::row_encode::_get_rows_encoded;
use super::*;

6// Reduce monomorphisation.
7fn sort_impl<T>(vals: &mut [(IdxSize, T)], options: SortOptions)
8where
9    T: TotalOrd + Send + Sync,
10{
11    sort_by_branch(
12        vals,
13        options.descending,
14        |a, b| a.1.tot_cmp(&b.1),
15        options.multithreaded,
16    );
17}
18// Compute the indexes after reversing a sorted array, maintaining
19// the order of equal elements, in linear time. Faster than sort_impl
20//  as we avoid allocating extra memory.
21pub(super) fn reverse_stable_no_nulls<I, J, T>(iters: I, len: usize) -> Vec<IdxSize>
22where
23    I: IntoIterator<Item = J>,
24    J: IntoIterator<Item = T>,
25    T: TotalOrd + Send + Sync,
26{
27    let mut current_start: IdxSize = 0;
28    let mut current_end: IdxSize = 0;
29    let mut rev_idx: Vec<IdxSize> = Vec::with_capacity(len);
30    let mut i: IdxSize;
31    // We traverse the array, comparing consecutive elements.
32    // We maintain the start and end indice of elements with same value.
33    // When we see a new element we push the previous indices in reverse order.
34    // We do a final reverse to get stable reverse index.
35    // Example -
36    // 1 2 2 3 3 3 4
37    // 0 1 2 3 4 5 6
38    // We get start and end position of equal values -
39    // 0 1-2 3-5 6
40    // We insert the indexes of equal elements in reverse
41    // 0 2 1 5 4 3 6
42    // Then do a final reverse
43    // 6 3 4 5 1 2 0
44    let mut previous_element: Option<T> = None;
45    for arr_iter in iters {
46        for current_element in arr_iter {
47            match &previous_element {
48                None => {
49                    //There is atleast one element
50                    current_end = 1;
51                },
52                Some(prev) => {
53                    if current_element.tot_cmp(prev) == Ordering::Equal {
54                        current_end += 1;
55                    } else {
56                        // Insert in reverse order
57                        i = current_end;
58                        while i > current_start {
59                            i -= 1;
60                            //SAFETY - we allocated enough
61                            unsafe { rev_idx.push_unchecked(i) };
62                        }
63                        current_start = current_end;
64                        current_end += 1;
65                    }
66                },
67            }
68            previous_element = Some(current_element);
69        }
70    }
71    // If there are no elements this does nothing
72    i = current_end;
73    while i > current_start {
74        i -= 1;
75        unsafe { rev_idx.push_unchecked(i) };
76    }
77    // Final reverse
78    rev_idx.reverse();
79    rev_idx
80}
81
82pub(super) fn arg_sort<I, J, T>(
83    name: PlSmallStr,
84    iters: I,
85    options: SortOptions,
86    null_count: usize,
87    mut len: usize,
88    is_sorted_flag: IsSorted,
89    first_element_null: bool,
90) -> IdxCa
91where
92    I: IntoIterator<Item = J>,
93    J: IntoIterator<Item = Option<T>>,
94    T: TotalOrd + Send + Sync,
95{
96    let nulls_last = options.nulls_last;
97    let null_cap = if nulls_last { null_count } else { len };
98
99    // Fast path
100    // Only if array is already sorted in the required ordered and
101    // nulls are also in the correct position
102    if ((options.descending && is_sorted_flag == IsSorted::Descending)
103        || (!options.descending && is_sorted_flag == IsSorted::Ascending))
104        && ((nulls_last && !first_element_null) || (!nulls_last && first_element_null))
105    {
106        len = options
107            .limit
108            .map_or(len, |limit| std::cmp::min(limit.try_into().unwrap(), len));
109        return ChunkedArray::with_chunk(
110            name,
111            IdxArr::from_data_default(
112                Buffer::from((0..(len as IdxSize)).collect::<Vec<IdxSize>>()),
113                None,
114            ),
115        );
116    }
117
118    let mut vals = Vec::with_capacity(len - null_count);
119    let mut nulls_idx = Vec::with_capacity(null_cap);
120    let mut count: IdxSize = 0;
121
122    for arr_iter in iters {
123        let iter = arr_iter.into_iter().filter_map(|v| {
124            let i = count;
125            count += 1;
126            match v {
127                Some(v) => Some((i, v)),
128                None => {
129                    // SAFETY: we allocated enough.
130                    unsafe { nulls_idx.push_unchecked(i) };
131                    None
132                },
133            }
134        });
135        vals.extend(iter);
136    }
137
138    let vals = if let Some(limit) = options.limit {
139        let limit = limit as usize;
140        // Overwrite output len.
141        len = limit;
142        let out = if limit >= vals.len() {
143            vals.as_mut_slice()
144        } else {
145            let (lower, _el, _upper) = if options.descending {
146                vals.as_mut_slice()
147                    .select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1))
148            } else {
149                vals.as_mut_slice()
150                    .select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1))
151            };
152            lower
153        };
154
155        sort_impl(out, options);
156        out
157    } else {
158        sort_impl(vals.as_mut_slice(), options);
159        vals.as_slice()
160    };
161
162    let iter = vals.iter().map(|(idx, _v)| idx).copied();
163    let idx = if nulls_last {
164        let mut idx = Vec::with_capacity(len);
165        idx.extend(iter);
166
167        let nulls_idx = if options.limit.is_some() {
168            &nulls_idx[..len - idx.len()]
169        } else {
170            &nulls_idx
171        };
172        idx.extend_from_slice(nulls_idx);
173        idx
174    } else if options.limit.is_some() {
175        nulls_idx.extend(iter.take(len - nulls_idx.len()));
176        nulls_idx
177    } else {
178        let ptr = nulls_idx.as_ptr() as usize;
179        nulls_idx.extend(iter);
180        // We had a realloc.
181        debug_assert_eq!(nulls_idx.as_ptr() as usize, ptr);
182        nulls_idx
183    };
184
185    ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None))
186}
187
188pub(super) fn arg_sort_no_nulls<I, J, T>(
189    name: PlSmallStr,
190    iters: I,
191    options: SortOptions,
192    len: usize,
193    is_sorted_flag: IsSorted,
194) -> IdxCa
195where
196    I: IntoIterator<Item = J>,
197    J: IntoIterator<Item = T>,
198    T: TotalOrd + Send + Sync,
199{
200    // Fast path
201    // 1) If array is already sorted in the required ordered .
202    // 2) If array is reverse sorted -> we do a stable reverse.
203    if is_sorted_flag != IsSorted::Not {
204        let len_final = options
205            .limit
206            .map_or(len, |limit| std::cmp::min(limit.try_into().unwrap(), len));
207        if (options.descending && is_sorted_flag == IsSorted::Descending)
208            || (!options.descending && is_sorted_flag == IsSorted::Ascending)
209        {
210            return ChunkedArray::with_chunk(
211                name,
212                IdxArr::from_data_default(
213                    Buffer::from((0..(len_final as IdxSize)).collect::<Vec<IdxSize>>()),
214                    None,
215                ),
216            );
217        } else if (options.descending && is_sorted_flag == IsSorted::Ascending)
218            || (!options.descending && is_sorted_flag == IsSorted::Descending)
219        {
220            let idx = reverse_stable_no_nulls(iters, len);
221            let idx = Buffer::from(idx).sliced(..len_final);
222            return ChunkedArray::with_chunk(name, IdxArr::from_data_default(idx, None));
223        }
224    }
225
226    let mut vals = Vec::with_capacity(len);
227
228    let mut count: IdxSize = 0;
229    for arr_iter in iters {
230        vals.extend(arr_iter.into_iter().map(|v| {
231            let idx = count;
232            count += 1;
233            (idx, v)
234        }));
235    }
236
237    let vals = if let Some(limit) = options.limit {
238        let limit = limit as usize;
239        let out = if limit >= vals.len() {
240            vals.as_mut_slice()
241        } else {
242            let (lower, _el, _upper) = if options.descending {
243                vals.as_mut_slice()
244                    .select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1))
245            } else {
246                vals.as_mut_slice()
247                    .select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1))
248            };
249            lower
250        };
251        sort_impl(out, options);
252        out
253    } else {
254        sort_impl(vals.as_mut_slice(), options);
255        vals.as_slice()
256    };
257
258    let iter = vals.iter().map(|(idx, _v)| idx).copied();
259    let idx: Vec<_> = iter.collect_trusted();
260
261    ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None))
262}
263
264pub(crate) fn arg_sort_row_fmt(
265    by: &[Column],
266    descending: bool,
267    nulls_last: bool,
268    parallel: bool,
269) -> PolarsResult<IdxCa> {
270    let rows_encoded = _get_rows_encoded(by, &[descending], &[nulls_last])?;
271    let mut items: Vec<_> = rows_encoded.iter().enumerate_idx().collect();
272
273    if parallel {
274        RAYON.install(|| items.par_sort_by(|a, b| a.1.cmp(b.1)));
275    } else {
276        items.sort_by(|a, b| a.1.cmp(b.1));
277    }
278
279    let ca: NoNull<IdxCa> = items.into_iter().map(|tpl| tpl.0).collect();
280    Ok(ca.into_inner())
281}
#[cfg(test)]
mod test {
    use sort::arg_sort::reverse_stable_no_nulls;

    use crate::prelude::*;

    /// Runs `reverse_stable_no_nulls` over all values of `ca`.
    fn reversed(ca: &Int32Chunked) -> Vec<IdxSize> {
        reverse_stable_no_nulls(ca.iter(), ca.len())
    }

    /// Arg-sorts `ca` single-threaded with a limit of 3 and collects the resulting indexes.
    fn arg_sort_top3(ca: &Int32Chunked, descending: bool, nulls_last: bool) -> Vec<IdxSize> {
        let options = SortOptions {
            descending,
            nulls_last,
            multithreaded: false,
            limit: Some(3),
            ..Default::default()
        };
        let sorted = ca.arg_sort(options);
        let idx: Vec<IdxSize> = sorted.into_no_null_iter().collect();
        idx
    }

    #[test]
    fn test_reverse_stable_no_nulls() {
        // Runs of equal values keep their original relative order.
        //        index:  0        1        2        3        4        5        6
        let with_runs = Int32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                Some(1),
                Some(2),
                Some(2),
                Some(3),
                Some(3),
                Some(3),
                Some(4),
            ],
        );
        assert_eq!(reversed(&with_runs), [6, 3, 4, 5, 1, 2, 0]);

        // All values distinct: a plain reversal.
        let distinct = Int32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                Some(1),
                Some(2),
                Some(3),
                Some(4),
                Some(5),
                Some(6),
                Some(7),
            ],
        );
        assert_eq!(reversed(&distinct), [6, 5, 4, 3, 2, 1, 0]);

        // A single element.
        let single = Int32Chunked::new(PlSmallStr::from_static("a"), &[Some(1)]);
        assert_eq!(reversed(&single), [0]);

        // No elements at all.
        let empty_array: [i32; 0] = [];
        let empty = Int32Chunked::new(PlSmallStr::from_static("a"), &empty_array);
        assert_eq!(reversed(&empty).len(), 0);
    }

    #[test]
    fn test_arg_sort_descending_with_limit() {
        let a = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 2, 5, 1, 3]);
        assert_eq!(arg_sort_top3(&a, true, false), vec![2, 0, 4]);
    }

    #[test]
    fn test_arg_sort_asc_with_limit() {
        let a = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 2, 5, 1, 3]);
        assert_eq!(arg_sort_top3(&a, false, false), vec![3, 1, 4]);
    }

    #[test]
    fn test_arg_sort_desc_limit_nulls() {
        let a = Int32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(4), None, Some(5), Some(1), None, Some(3)],
        );
        assert_eq!(arg_sort_top3(&a, true, true), vec![2, 0, 5]);
    }
}