1use polars_utils::itertools::Itertools;
2
3use self::row_encode::_get_rows_encoded;
4use super::*;
5
6fn sort_impl<T>(vals: &mut [(IdxSize, T)], options: SortOptions)
8where
9 T: TotalOrd + Send + Sync,
10{
11 sort_by_branch(
12 vals,
13 options.descending,
14 |a, b| a.1.tot_cmp(&b.1),
15 options.multithreaded,
16 );
17}
18pub(super) fn reverse_stable_no_nulls<I, J, T>(iters: I, len: usize) -> Vec<IdxSize>
22where
23 I: IntoIterator<Item = J>,
24 J: IntoIterator<Item = T>,
25 T: TotalOrd + Send + Sync,
26{
27 let mut current_start: IdxSize = 0;
28 let mut current_end: IdxSize = 0;
29 let mut rev_idx: Vec<IdxSize> = Vec::with_capacity(len);
30 let mut i: IdxSize;
31 let mut previous_element: Option<T> = None;
45 for arr_iter in iters {
46 for current_element in arr_iter {
47 match &previous_element {
48 None => {
49 current_end = 1;
51 },
52 Some(prev) => {
53 if current_element.tot_cmp(prev) == Ordering::Equal {
54 current_end += 1;
55 } else {
56 i = current_end;
58 while i > current_start {
59 i -= 1;
60 unsafe { rev_idx.push_unchecked(i) };
62 }
63 current_start = current_end;
64 current_end += 1;
65 }
66 },
67 }
68 previous_element = Some(current_element);
69 }
70 }
71 i = current_end;
73 while i > current_start {
74 i -= 1;
75 unsafe { rev_idx.push_unchecked(i) };
76 }
77 rev_idx.reverse();
79 rev_idx
80}
81
82pub(super) fn arg_sort<I, J, T>(
83 name: PlSmallStr,
84 iters: I,
85 options: SortOptions,
86 null_count: usize,
87 mut len: usize,
88 is_sorted_flag: IsSorted,
89 first_element_null: bool,
90) -> IdxCa
91where
92 I: IntoIterator<Item = J>,
93 J: IntoIterator<Item = Option<T>>,
94 T: TotalOrd + Send + Sync,
95{
96 let nulls_last = options.nulls_last;
97 let null_cap = if nulls_last { null_count } else { len };
98
99 if ((options.descending && is_sorted_flag == IsSorted::Descending)
103 || (!options.descending && is_sorted_flag == IsSorted::Ascending))
104 && ((nulls_last && !first_element_null) || (!nulls_last && first_element_null))
105 {
106 len = options
107 .limit
108 .map_or(len, |limit| std::cmp::min(limit.try_into().unwrap(), len));
109 return ChunkedArray::with_chunk(
110 name,
111 IdxArr::from_data_default(
112 Buffer::from((0..(len as IdxSize)).collect::<Vec<IdxSize>>()),
113 None,
114 ),
115 );
116 }
117
118 let mut vals = Vec::with_capacity(len - null_count);
119 let mut nulls_idx = Vec::with_capacity(null_cap);
120 let mut count: IdxSize = 0;
121
122 for arr_iter in iters {
123 let iter = arr_iter.into_iter().filter_map(|v| {
124 let i = count;
125 count += 1;
126 match v {
127 Some(v) => Some((i, v)),
128 None => {
129 unsafe { nulls_idx.push_unchecked(i) };
131 None
132 },
133 }
134 });
135 vals.extend(iter);
136 }
137
138 let vals = if let Some(limit) = options.limit {
139 let limit = limit as usize;
140 len = limit;
142 let out = if limit >= vals.len() {
143 vals.as_mut_slice()
144 } else {
145 let (lower, _el, _upper) = if options.descending {
146 vals.as_mut_slice()
147 .select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1))
148 } else {
149 vals.as_mut_slice()
150 .select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1))
151 };
152 lower
153 };
154
155 sort_impl(out, options);
156 out
157 } else {
158 sort_impl(vals.as_mut_slice(), options);
159 vals.as_slice()
160 };
161
162 let iter = vals.iter().map(|(idx, _v)| idx).copied();
163 let idx = if nulls_last {
164 let mut idx = Vec::with_capacity(len);
165 idx.extend(iter);
166
167 let nulls_idx = if options.limit.is_some() {
168 &nulls_idx[..len - idx.len()]
169 } else {
170 &nulls_idx
171 };
172 idx.extend_from_slice(nulls_idx);
173 idx
174 } else if options.limit.is_some() {
175 nulls_idx.extend(iter.take(len - nulls_idx.len()));
176 nulls_idx
177 } else {
178 let ptr = nulls_idx.as_ptr() as usize;
179 nulls_idx.extend(iter);
180 debug_assert_eq!(nulls_idx.as_ptr() as usize, ptr);
182 nulls_idx
183 };
184
185 ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None))
186}
187
188pub(super) fn arg_sort_no_nulls<I, J, T>(
189 name: PlSmallStr,
190 iters: I,
191 options: SortOptions,
192 len: usize,
193 is_sorted_flag: IsSorted,
194) -> IdxCa
195where
196 I: IntoIterator<Item = J>,
197 J: IntoIterator<Item = T>,
198 T: TotalOrd + Send + Sync,
199{
200 if is_sorted_flag != IsSorted::Not {
204 let len_final = options
205 .limit
206 .map_or(len, |limit| std::cmp::min(limit.try_into().unwrap(), len));
207 if (options.descending && is_sorted_flag == IsSorted::Descending)
208 || (!options.descending && is_sorted_flag == IsSorted::Ascending)
209 {
210 return ChunkedArray::with_chunk(
211 name,
212 IdxArr::from_data_default(
213 Buffer::from((0..(len_final as IdxSize)).collect::<Vec<IdxSize>>()),
214 None,
215 ),
216 );
217 } else if (options.descending && is_sorted_flag == IsSorted::Ascending)
218 || (!options.descending && is_sorted_flag == IsSorted::Descending)
219 {
220 let idx = reverse_stable_no_nulls(iters, len);
221 let idx = Buffer::from(idx).sliced(..len_final);
222 return ChunkedArray::with_chunk(name, IdxArr::from_data_default(idx, None));
223 }
224 }
225
226 let mut vals = Vec::with_capacity(len);
227
228 let mut count: IdxSize = 0;
229 for arr_iter in iters {
230 vals.extend(arr_iter.into_iter().map(|v| {
231 let idx = count;
232 count += 1;
233 (idx, v)
234 }));
235 }
236
237 let vals = if let Some(limit) = options.limit {
238 let limit = limit as usize;
239 let out = if limit >= vals.len() {
240 vals.as_mut_slice()
241 } else {
242 let (lower, _el, _upper) = if options.descending {
243 vals.as_mut_slice()
244 .select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1))
245 } else {
246 vals.as_mut_slice()
247 .select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1))
248 };
249 lower
250 };
251 sort_impl(out, options);
252 out
253 } else {
254 sort_impl(vals.as_mut_slice(), options);
255 vals.as_slice()
256 };
257
258 let iter = vals.iter().map(|(idx, _v)| idx).copied();
259 let idx: Vec<_> = iter.collect_trusted();
260
261 ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None))
262}
263
264pub(crate) fn arg_sort_row_fmt(
265 by: &[Column],
266 descending: bool,
267 nulls_last: bool,
268 parallel: bool,
269) -> PolarsResult<IdxCa> {
270 let rows_encoded = _get_rows_encoded(by, &[descending], &[nulls_last])?;
271 let mut items: Vec<_> = rows_encoded.iter().enumerate_idx().collect();
272
273 if parallel {
274 RAYON.install(|| items.par_sort_by(|a, b| a.1.cmp(b.1)));
275 } else {
276 items.sort_by(|a, b| a.1.cmp(b.1));
277 }
278
279 let ca: NoNull<IdxCa> = items.into_iter().map(|tpl| tpl.0).collect();
280 Ok(ca.into_inner())
281}
#[cfg(test)]
mod test {
    use sort::arg_sort::reverse_stable_no_nulls;

    use crate::prelude::*;

    /// Runs `arg_sort` with `options` and collects the resulting indices.
    fn arg_sorted(ca: &Int32Chunked, options: SortOptions) -> Vec<IdxSize> {
        let sorted = ca.arg_sort(options);
        let idx: Vec<IdxSize> = sorted.into_no_null_iter().collect();
        idx
    }

    #[test]
    fn test_reverse_stable_no_nulls() {
        // Runs of equal values keep their relative order after the reverse.
        let with_runs = Int32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                Some(1), // 0
                Some(2), // 1
                Some(2), // 2
                Some(3), // 3
                Some(3), // 4
                Some(3), // 5
                Some(4), // 6
            ],
        );
        let got = reverse_stable_no_nulls(with_runs.iter(), 7);
        assert_eq!(got, [6, 3, 4, 5, 1, 2, 0]);

        // All values distinct: a plain reversal.
        let distinct = Int32Chunked::new(
            PlSmallStr::from_static("a"),
            &[
                Some(1), // 0
                Some(2), // 1
                Some(3), // 2
                Some(4), // 3
                Some(5), // 4
                Some(6), // 5
                Some(7), // 6
            ],
        );
        let got = reverse_stable_no_nulls(distinct.iter(), 7);
        assert_eq!(got, [6, 5, 4, 3, 2, 1, 0]);

        // A single element.
        let single = Int32Chunked::new(PlSmallStr::from_static("a"), &[Some(1)]);
        let got = reverse_stable_no_nulls(single.iter(), 1);
        assert_eq!(got, [0]);

        // No elements at all.
        let no_values: [i32; 0] = [];
        let empty = Int32Chunked::new(PlSmallStr::from_static("a"), &no_values);
        let got = reverse_stable_no_nulls(empty.iter(), 0);
        assert_eq!(got.len(), 0);
    }

    #[test]
    fn test_arg_sort_descending_with_limit() {
        let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 2, 5, 1, 3]);
        let options = SortOptions {
            limit: Some(3),
            multithreaded: false,
            nulls_last: false,
            descending: true,
            ..Default::default()
        };
        // Largest three values are 5, 4, 3.
        assert_eq!(arg_sorted(&ca, options), vec![2, 0, 4]);
    }

    #[test]
    fn test_arg_sort_asc_with_limit() {
        let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 2, 5, 1, 3]);
        let options = SortOptions {
            limit: Some(3),
            multithreaded: false,
            nulls_last: false,
            descending: false,
            ..Default::default()
        };
        // Smallest three values are 1, 2, 3.
        assert_eq!(arg_sorted(&ca, options), vec![3, 1, 4]);
    }

    #[test]
    fn test_arg_sort_desc_limit_nulls() {
        let ca = Int32Chunked::new(
            PlSmallStr::from_static("a"),
            &[Some(4), None, Some(5), Some(1), None, Some(3)],
        );
        let options = SortOptions {
            limit: Some(3),
            multithreaded: false,
            nulls_last: true,
            descending: true,
            ..Default::default()
        };
        // Nulls go last, so the limit is filled by 5, 4, 3.
        assert_eq!(arg_sorted(&ca, options), vec![2, 0, 5]);
    }
}