polars_core/chunked_array/ops/sort/
arg_bottom_k.rs1use polars_utils::itertools::Itertools;
2
3use super::*;
4use crate::chunked_array::ops::row_encode::_get_rows_encoded;
5
6#[derive(Eq)]
7struct CompareRow<'a> {
8 idx: IdxSize,
9 bytes: &'a [u8],
10}
11
12impl PartialEq for CompareRow<'_> {
13 fn eq(&self, other: &Self) -> bool {
14 self.bytes == other.bytes
15 }
16}
17
18impl Ord for CompareRow<'_> {
19 fn cmp(&self, other: &Self) -> Ordering {
20 self.bytes.cmp(other.bytes)
21 }
22}
23
24impl PartialOrd for CompareRow<'_> {
25 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
26 Some(self.cmp(other))
27 }
28}
29
30pub fn _arg_bottom_k(
34 k: usize,
35 by_column: &[Column],
36 sort_options: &mut SortMultipleOptions,
37) -> PolarsResult<NoNull<IdxCa>> {
38 let from_n_rows = by_column[0].len();
39 _broadcast_bools(by_column.len(), &mut sort_options.descending);
40 _broadcast_bools(by_column.len(), &mut sort_options.nulls_last);
41
42 if by_column.len() == 1 && sort_options.limit.is_some() && !sort_options.maintain_order {
44 return Ok(NoNull::new(by_column[0].arg_sort((&*sort_options).into())));
45 }
46
47 let encoded = _get_rows_encoded(
48 by_column,
49 &sort_options.descending,
50 &sort_options.nulls_last,
51 )?;
52 let arr = encoded.into_array();
53 let mut rows = arr
54 .values_iter()
55 .enumerate_idx()
56 .map(|(idx, bytes)| CompareRow { idx, bytes })
57 .collect::<Vec<_>>();
58
59 let sorted = if k >= from_n_rows {
60 match (sort_options.multithreaded, sort_options.maintain_order) {
61 (true, true) => RAYON.install(|| {
62 rows.par_sort();
63 }),
64 (true, false) => RAYON.install(|| {
65 rows.par_sort_unstable();
66 }),
67 (false, true) => rows.sort(),
68 (false, false) => rows.sort_unstable(),
69 }
70 &rows
71 } else if sort_options.maintain_order {
72 if sort_options.multithreaded {
74 RAYON.install(|| {
75 rows.par_sort();
76 })
77 } else {
78 rows.sort();
79 }
80 &rows[..k]
81 } else {
82 let (lower, _el, _upper) = rows.select_nth_unstable(k);
84 if sort_options.multithreaded {
85 RAYON.install(|| {
86 lower.par_sort_unstable();
87 })
88 } else {
89 lower.sort_unstable();
90 }
91 &*lower
92 };
93
94 let idx: NoNull<IdxCa> = sorted.iter().map(|cmp_row| cmp_row.idx).collect();
95 Ok(idx)
96}