Skip to main content

polars_core/chunked_array/ops/sort/
arg_bottom_k.rs

1use polars_utils::itertools::Itertools;
2
3use super::*;
4use crate::chunked_array::ops::row_encode::_get_rows_encoded;
5
6#[derive(Eq)]
7struct CompareRow<'a> {
8    idx: IdxSize,
9    bytes: &'a [u8],
10}
11
12impl PartialEq for CompareRow<'_> {
13    fn eq(&self, other: &Self) -> bool {
14        self.bytes == other.bytes
15    }
16}
17
18impl Ord for CompareRow<'_> {
19    fn cmp(&self, other: &Self) -> Ordering {
20        self.bytes.cmp(other.bytes)
21    }
22}
23
24impl PartialOrd for CompareRow<'_> {
25    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
26        Some(self.cmp(other))
27    }
28}
29
30/// Return the indices of the bottom k elements.
31///
32/// Similar to .argsort() then .slice(0, k) but with a more efficient implementation.
33pub fn _arg_bottom_k(
34    k: usize,
35    by_column: &[Column],
36    sort_options: &mut SortMultipleOptions,
37) -> PolarsResult<NoNull<IdxCa>> {
38    let from_n_rows = by_column[0].len();
39    _broadcast_bools(by_column.len(), &mut sort_options.descending);
40    _broadcast_bools(by_column.len(), &mut sort_options.nulls_last);
41
42    // Don't go into row encoding.
43    if by_column.len() == 1 && sort_options.limit.is_some() && !sort_options.maintain_order {
44        return Ok(NoNull::new(by_column[0].arg_sort((&*sort_options).into())));
45    }
46
47    let encoded = _get_rows_encoded(
48        by_column,
49        &sort_options.descending,
50        &sort_options.nulls_last,
51    )?;
52    let arr = encoded.into_array();
53    let mut rows = arr
54        .values_iter()
55        .enumerate_idx()
56        .map(|(idx, bytes)| CompareRow { idx, bytes })
57        .collect::<Vec<_>>();
58
59    let sorted = if k >= from_n_rows {
60        match (sort_options.multithreaded, sort_options.maintain_order) {
61            (true, true) => RAYON.install(|| {
62                rows.par_sort();
63            }),
64            (true, false) => RAYON.install(|| {
65                rows.par_sort_unstable();
66            }),
67            (false, true) => rows.sort(),
68            (false, false) => rows.sort_unstable(),
69        }
70        &rows
71    } else if sort_options.maintain_order {
72        // todo: maybe there is some more efficient method, comparable to select_nth_unstable
73        if sort_options.multithreaded {
74            RAYON.install(|| {
75                rows.par_sort();
76            })
77        } else {
78            rows.sort();
79        }
80        &rows[..k]
81    } else {
82        // todo: possible multi threaded `select_nth_unstable`?
83        let (lower, _el, _upper) = rows.select_nth_unstable(k);
84        if sort_options.multithreaded {
85            RAYON.install(|| {
86                lower.par_sort_unstable();
87            })
88        } else {
89            lower.sort_unstable();
90        }
91        &*lower
92    };
93
94    let idx: NoNull<IdxCa> = sorted.iter().map(|cmp_row| cmp_row.idx).collect();
95    Ok(idx)
96}