Skip to main content

polars_core/chunked_array/ops/sort/
arg_sort_multiple.rs

1use polars_utils::itertools::Itertools;
2use polars_utils::total_ord::TotalOrdWrap;
3
4use super::*;
5use crate::chunked_array::ops::row_encode::_get_rows_encoded;
6
7pub(crate) fn args_validate<T: PolarsDataType>(
8    ca: &ChunkedArray<T>,
9    other: &[Column],
10    param_value: &[bool],
11    param_name: &str,
12) -> PolarsResult<()> {
13    for s in other {
14        assert_eq!(ca.len(), s.len());
15    }
16    polars_ensure!(other.len() == (param_value.len() - 1),
17        ComputeError:
18        "the length of `{}` ({}) does not match the number of series ({})",
19        param_name, param_value.len(), other.len() + 1,
20    );
21    Ok(())
22}
23
24pub(crate) fn arg_sort_multiple_impl<T: TotalOrd + IsNull + Send + Copy>(
25    mut vals: Vec<(IdxSize, T)>,
26    by: &[Column],
27    options: &SortMultipleOptions,
28) -> PolarsResult<IdxCa> {
29    let nulls_last = &options.nulls_last;
30    let descending = &options.descending;
31
32    debug_assert_eq!(descending.len() - 1, by.len());
33    debug_assert_eq!(nulls_last.len() - 1, by.len());
34
35    let compare_inner: Vec<_> = by
36        .iter()
37        .map(|c| c.into_total_ord_inner())
38        .collect_trusted();
39
40    let compare = move |tpl_a: &(_, T), tpl_b: &(_, T)| -> Ordering {
41        match reorder_cmp(
42            &TotalOrdWrap(tpl_a.1),
43            &TotalOrdWrap(tpl_b.1),
44            descending[0],
45            nulls_last[0],
46        ) {
47            // if ordering is equal, we check the other arrays until we find a non-equal ordering
48            // if we have exhausted all arrays, we keep the equal ordering.
49            Ordering::Equal => {
50                let idx_a = tpl_a.0 as usize;
51                let idx_b = tpl_b.0 as usize;
52                unsafe {
53                    ordering_other_columns(
54                        &compare_inner,
55                        descending.get_unchecked(1..),
56                        nulls_last.get_unchecked(1..),
57                        idx_a,
58                        idx_b,
59                    )
60                }
61            },
62            ord => ord,
63        }
64    };
65
66    match (options.multithreaded, options.maintain_order) {
67        (true, true) => RAYON.install(|| {
68            vals.par_sort_by(compare);
69        }),
70        (true, false) => RAYON.install(|| {
71            vals.par_sort_unstable_by(compare);
72        }),
73        (false, true) => vals.sort_by(compare),
74        (false, false) => vals.sort_unstable_by(compare),
75    }
76
77    let ca: NoNull<IdxCa> = vals.into_iter().map(|(idx, _v)| idx).collect_trusted();
78    // Don't set to sorted. Argsort indices are not sorted.
79    Ok(ca.into_inner())
80}
81
82pub(crate) fn argsort_multiple_row_fmt(
83    by: &[Column],
84    mut descending: Vec<bool>,
85    mut nulls_last: Vec<bool>,
86    parallel: bool,
87) -> PolarsResult<IdxCa> {
88    _broadcast_bools(by.len(), &mut descending);
89    _broadcast_bools(by.len(), &mut nulls_last);
90
91    let rows_encoded = _get_rows_encoded(by, &descending, &nulls_last)?;
92    let mut items: Vec<_> = rows_encoded.iter().enumerate_idx().collect();
93
94    if parallel {
95        RAYON.install(|| items.par_sort_by_key(|i| i.1));
96    } else {
97        items.sort_by_key(|i| i.1);
98    }
99
100    let ca: NoNull<IdxCa> = items.into_iter().map(|tpl| tpl.0).collect();
101    Ok(ca.into_inner())
102}