polars_core/chunked_array/ops/sort/
mod.rs

1mod arg_sort;
2
3pub mod arg_sort_multiple;
4
5pub mod arg_bottom_k;
6pub mod options;
7
8#[cfg(feature = "dtype-categorical")]
9mod categorical;
10
11use std::cmp::Ordering;
12
13pub(crate) use arg_sort::arg_sort_row_fmt;
14pub(crate) use arg_sort_multiple::argsort_multiple_row_fmt;
15use arrow::bitmap::{Bitmap, BitmapBuilder};
16use arrow::buffer::Buffer;
17use arrow::legacy::trusted_len::TrustedLenPush;
18use compare_inner::NonNull;
19use rayon::prelude::*;
20pub use slice::*;
21
22use super::row_encode::_get_rows_encoded_ca;
23use crate::POOL;
24use crate::prelude::compare_inner::TotalOrdInner;
25use crate::prelude::sort::arg_sort_multiple::*;
26use crate::prelude::*;
27use crate::series::IsSorted;
28use crate::utils::NoNull;
29
30fn partition_nulls<T: Copy>(
31    values: &mut [T],
32    mut validity: Option<Bitmap>,
33    options: SortOptions,
34) -> (&mut [T], Option<Bitmap>) {
35    let partitioned = if let Some(bitmap) = &validity {
36        // Partition null last first
37        let mut out_len = 0;
38        for idx in bitmap.true_idx_iter() {
39            unsafe { *values.get_unchecked_mut(out_len) = *values.get_unchecked(idx) };
40            out_len += 1;
41        }
42        let valid_count = out_len;
43        let null_count = values.len() - valid_count;
44        validity = Some(create_validity(
45            bitmap.len(),
46            bitmap.unset_bits(),
47            options.nulls_last,
48        ));
49
50        // Views are correctly partitioned.
51        if options.nulls_last {
52            &mut values[..valid_count]
53        }
54        // We need to swap the ends.
55        else {
56            // swap nulls with end
57            let mut end = values.len() - 1;
58
59            for i in 0..null_count {
60                unsafe { *values.get_unchecked_mut(end) = *values.get_unchecked(i) };
61                end = end.saturating_sub(1);
62            }
63            &mut values[null_count..]
64        }
65    } else {
66        values
67    };
68    (partitioned, validity)
69}
70
71pub(crate) fn sort_by_branch<T, C>(slice: &mut [T], descending: bool, cmp: C, parallel: bool)
72where
73    T: Send,
74    C: Send + Sync + Fn(&T, &T) -> Ordering,
75{
76    if parallel {
77        POOL.install(|| match descending {
78            true => slice.par_sort_by(|a, b| cmp(b, a)),
79            false => slice.par_sort_by(cmp),
80        })
81    } else {
82        match descending {
83            true => slice.sort_by(|a, b| cmp(b, a)),
84            false => slice.sort_by(cmp),
85        }
86    }
87}
88
89fn sort_unstable_by_branch<T, C>(slice: &mut [T], options: SortOptions, cmp: C)
90where
91    T: Send,
92    C: Send + Sync + Fn(&T, &T) -> Ordering,
93{
94    if options.multithreaded {
95        POOL.install(|| match options.descending {
96            true => slice.par_sort_unstable_by(|a, b| cmp(b, a)),
97            false => slice.par_sort_unstable_by(cmp),
98        })
99    } else {
100        match options.descending {
101            true => slice.sort_unstable_by(|a, b| cmp(b, a)),
102            false => slice.sort_unstable_by(cmp),
103        }
104    }
105}
106
107// Reduce monomorphisation.
108fn sort_impl_unstable<T>(vals: &mut [T], options: SortOptions)
109where
110    T: TotalOrd + Send + Sync,
111{
112    sort_unstable_by_branch(vals, options, TotalOrd::tot_cmp);
113}
114
115fn create_validity(len: usize, null_count: usize, nulls_last: bool) -> Bitmap {
116    let mut validity = BitmapBuilder::with_capacity(len);
117    if nulls_last {
118        validity.extend_constant(len - null_count, true);
119        validity.extend_constant(null_count, false);
120    } else {
121        validity.extend_constant(null_count, false);
122        validity.extend_constant(len - null_count, true);
123    }
124    validity.freeze()
125}
126
// Early-return shortcuts for `sort_with` implementations. The macro `return`s from the
// *calling* function when the result can be produced without sorting:
//  * empty input                                  -> clone
//  * flagged as sorted in the requested direction -> clone, provided there are no nulls or
//    the nulls already sit on the requested side
//  * flagged as sorted in the opposite direction and free of nulls -> reverse
// In every other case it falls through and the caller performs the real sort.
macro_rules! sort_with_fast_path {
    ($ca:ident, $options:expr) => {{
        if $ca.is_empty() {
            return $ca.clone();
        }

        let flagged_in_requested_order = if $options.descending {
            $ca.is_sorted_descending_flag()
        } else {
            $ca.is_sorted_ascending_flag()
        };

        if flagged_in_requested_order {
            // With nulls present, probe the end at which the nulls are requested: a null
            // there means they are already in place.
            let has_nulls = $ca.null_count() > 0;
            let null_probe = if $options.nulls_last { $ca.len() - 1 } else { 0 };
            if !has_nulls || $ca.get(null_probe).is_none() {
                return $ca.clone();
            }
            // nulls are not at the right place
            // continue w/ sorting
            // TODO: we can optimize here and just put the null at the correct place
        } else {
            // Not flagged in the requested direction; a flag for the other direction lets us
            // reverse instead of sorting, as long as no nulls have to be repositioned.
            let flagged_in_opposite_order = if $options.descending {
                $ca.is_sorted_ascending_flag()
            } else {
                $ca.is_sorted_descending_flag()
            };
            if flagged_in_opposite_order && $ca.null_count() == 0 {
                return $ca.reverse();
            }
        }
    }};
}
159
// Early-return shortcut for `arg_sort` implementations. When no `limit` is requested and the
// array is flagged as sorted in the requested direction, the identity permutation `0..len`
// is `return`ed from the *calling* function, provided there are no nulls or the nulls already
// sit on the requested side. Otherwise the macro falls through.
macro_rules! arg_sort_fast_path {
    ($ca:ident,  $options:expr) => {{
        let flagged_in_requested_order = if $options.descending {
            $ca.is_sorted_descending_flag()
        } else {
            $ca.is_sorted_ascending_flag()
        };

        if $options.limit.is_none() && flagged_in_requested_order {
            // Only probe an element when there are nulls; a null at the requested end means
            // they are already in place.
            let nulls_in_place = $ca.null_count() == 0
                || $ca
                    .get(if $options.nulls_last { $ca.len() - 1 } else { 0 })
                    .is_none();

            if nulls_in_place {
                let identity: Vec<IdxSize> = (0..($ca.len() as IdxSize)).collect();
                return ChunkedArray::with_chunk(
                    $ca.name().clone(),
                    IdxArr::from_data_default(Buffer::from(identity), None),
                );
            }
            // nulls are not at the right place
            // continue w/ sorting
            // TODO: we can optimize here and just put the null at the correct place
        }
    }};
}
186
187fn sort_with_numeric<T>(ca: &ChunkedArray<T>, options: SortOptions) -> ChunkedArray<T>
188where
189    T: PolarsNumericType,
190{
191    sort_with_fast_path!(ca, options);
192    if ca.null_count() == 0 {
193        let mut vals = ca.to_vec_null_aware().left().unwrap();
194
195        sort_impl_unstable(vals.as_mut_slice(), options);
196
197        let mut ca = ChunkedArray::from_vec(ca.name().clone(), vals);
198        let s = if options.descending {
199            IsSorted::Descending
200        } else {
201            IsSorted::Ascending
202        };
203        ca.set_sorted_flag(s);
204        ca
205    } else {
206        let null_count = ca.null_count();
207        let len = ca.len();
208
209        let mut vals = Vec::with_capacity(ca.len());
210
211        if !options.nulls_last {
212            let iter = std::iter::repeat_n(T::Native::default(), null_count);
213            vals.extend(iter);
214        }
215
216        ca.downcast_iter().for_each(|arr| {
217            let iter = arr.iter().filter_map(|v| v.copied());
218            vals.extend(iter);
219        });
220        let mut_slice = if options.nulls_last {
221            &mut vals[..len - null_count]
222        } else {
223            &mut vals[null_count..]
224        };
225
226        sort_impl_unstable(mut_slice, options);
227
228        if options.nulls_last {
229            vals.extend(std::iter::repeat_n(T::Native::default(), ca.null_count()));
230        }
231
232        let arr = PrimitiveArray::new(
233            T::get_static_dtype().to_arrow(CompatLevel::newest()),
234            vals.into(),
235            Some(create_validity(len, null_count, options.nulls_last)),
236        );
237        let mut new_ca = ChunkedArray::with_chunk(ca.name().clone(), arr);
238        let s = if options.descending {
239            IsSorted::Descending
240        } else {
241            IsSorted::Ascending
242        };
243        new_ca.set_sorted_flag(s);
244        new_ca
245    }
246}
247
248fn arg_sort_numeric<T>(ca: &ChunkedArray<T>, mut options: SortOptions) -> IdxCa
249where
250    T: PolarsNumericType,
251{
252    options.multithreaded &= POOL.current_num_threads() > 1;
253    arg_sort_fast_path!(ca, options);
254    if ca.null_count() == 0 {
255        let iter = ca
256            .downcast_iter()
257            .map(|arr| arr.values().as_slice().iter().copied());
258        arg_sort::arg_sort_no_nulls(
259            ca.name().clone(),
260            iter,
261            options,
262            ca.len(),
263            ca.is_sorted_flag(),
264        )
265    } else {
266        let iter = ca
267            .downcast_iter()
268            .map(|arr| arr.iter().map(|opt| opt.copied()));
269        arg_sort::arg_sort(
270            ca.name().clone(),
271            iter,
272            options,
273            ca.null_count(),
274            ca.len(),
275            ca.is_sorted_flag(),
276            ca.get(0).is_none(),
277        )
278    }
279}
280
281fn arg_sort_multiple_numeric<T: PolarsNumericType>(
282    ca: &ChunkedArray<T>,
283    by: &[Column],
284    options: &SortMultipleOptions,
285) -> PolarsResult<IdxCa> {
286    args_validate(ca, by, &options.descending, "descending")?;
287    args_validate(ca, by, &options.nulls_last, "nulls_last")?;
288    let mut count: IdxSize = 0;
289
290    let no_nulls = ca.null_count() == 0;
291
292    if no_nulls {
293        let mut vals = Vec::with_capacity(ca.len());
294        for arr in ca.downcast_iter() {
295            vals.extend_trusted_len(arr.values().as_slice().iter().map(|v| {
296                let i = count;
297                count += 1;
298                (i, NonNull(*v))
299            }))
300        }
301        arg_sort_multiple_impl(vals, by, options)
302    } else {
303        let mut vals = Vec::with_capacity(ca.len());
304        for arr in ca.downcast_iter() {
305            vals.extend_trusted_len(arr.into_iter().map(|v| {
306                let i = count;
307                count += 1;
308                (i, v.copied())
309            }));
310        }
311        arg_sort_multiple_impl(vals, by, options)
312    }
313}
314
315impl<T> ChunkSort<T> for ChunkedArray<T>
316where
317    T: PolarsNumericType,
318{
319    fn sort_with(&self, mut options: SortOptions) -> ChunkedArray<T> {
320        options.multithreaded &= POOL.current_num_threads() > 1;
321        sort_with_numeric(self, options)
322    }
323
324    fn sort(&self, descending: bool) -> ChunkedArray<T> {
325        self.sort_with(SortOptions {
326            descending,
327            ..Default::default()
328        })
329    }
330
331    fn arg_sort(&self, options: SortOptions) -> IdxCa {
332        arg_sort_numeric(self, options)
333    }
334
335    /// # Panics
336    ///
337    /// This function is very opinionated.
338    /// We assume that all numeric `Series` are of the same type, if not it will panic
339    fn arg_sort_multiple(
340        &self,
341        by: &[Column],
342        options: &SortMultipleOptions,
343    ) -> PolarsResult<IdxCa> {
344        arg_sort_multiple_numeric(self, by, options)
345    }
346}
347
348fn ordering_other_columns<'a>(
349    compare_inner: &'a [Box<dyn TotalOrdInner + 'a>],
350    descending: &[bool],
351    nulls_last: &[bool],
352    idx_a: usize,
353    idx_b: usize,
354) -> Ordering {
355    for ((cmp, descending), null_last) in compare_inner.iter().zip(descending).zip(nulls_last) {
356        // SAFETY: indices are in bounds
357        let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b, null_last ^ descending) };
358        match (ordering, descending) {
359            (Ordering::Equal, _) => continue,
360            (_, true) => return ordering.reverse(),
361            _ => return ordering,
362        }
363    }
364    // all arrays/columns exhausted, ordering equal it is.
365    Ordering::Equal
366}
367
368impl ChunkSort<StringType> for StringChunked {
369    fn sort_with(&self, options: SortOptions) -> ChunkedArray<StringType> {
370        unsafe { self.as_binary().sort_with(options).to_string_unchecked() }
371    }
372
373    fn sort(&self, descending: bool) -> StringChunked {
374        self.sort_with(SortOptions {
375            descending,
376            nulls_last: false,
377            multithreaded: true,
378            maintain_order: false,
379            limit: None,
380        })
381    }
382
383    fn arg_sort(&self, options: SortOptions) -> IdxCa {
384        self.as_binary().arg_sort(options)
385    }
386
387    /// # Panics
388    ///
389    /// This function is very opinionated. On the implementation of `ChunkedArray<T>` for numeric types,
390    /// we assume that all numeric `Series` are of the same type.
391    ///
392    /// In this case we assume that all numeric `Series` are `f64` types. The caller needs to
393    /// uphold this contract. If not, it will panic.
394    ///
395    fn arg_sort_multiple(
396        &self,
397        by: &[Column],
398        options: &SortMultipleOptions,
399    ) -> PolarsResult<IdxCa> {
400        self.as_binary().arg_sort_multiple(by, options)
401    }
402}
403
404impl ChunkSort<BinaryType> for BinaryChunked {
405    fn sort_with(&self, mut options: SortOptions) -> ChunkedArray<BinaryType> {
406        options.multithreaded &= POOL.current_num_threads() > 1;
407        sort_with_fast_path!(self, options);
408        // We will sort by the views and reconstruct with sorted views. We leave the buffers as is.
409        // We must rechunk to ensure that all views point into the proper buffers.
410        let ca = self.rechunk();
411        let arr = ca.downcast_as_array().clone();
412
413        let (views, buffers, validity, total_bytes_len, total_buffer_len) = arr.into_inner();
414        let mut views = views.make_mut();
415
416        let (partitioned_part, validity) = partition_nulls(&mut views, validity, options);
417
418        sort_unstable_by_branch(partitioned_part, options, |a, b| unsafe {
419            a.get_slice_unchecked(&buffers)
420                .tot_cmp(&b.get_slice_unchecked(&buffers))
421        });
422
423        let array = unsafe {
424            BinaryViewArray::new_unchecked(
425                ArrowDataType::BinaryView,
426                views.into(),
427                buffers,
428                validity,
429                total_bytes_len,
430                total_buffer_len,
431            )
432        };
433
434        let mut out = Self::with_chunk_like(self, array);
435
436        let s = if options.descending {
437            IsSorted::Descending
438        } else {
439            IsSorted::Ascending
440        };
441        out.set_sorted_flag(s);
442        out
443    }
444
445    fn sort(&self, descending: bool) -> ChunkedArray<BinaryType> {
446        self.sort_with(SortOptions {
447            descending,
448            nulls_last: false,
449            multithreaded: true,
450            maintain_order: false,
451            limit: None,
452        })
453    }
454
455    fn arg_sort(&self, options: SortOptions) -> IdxCa {
456        arg_sort_fast_path!(self, options);
457        if self.null_count() == 0 {
458            arg_sort::arg_sort_no_nulls(
459                self.name().clone(),
460                self.downcast_iter().map(|arr| arr.values_iter()),
461                options,
462                self.len(),
463                self.is_sorted_flag(),
464            )
465        } else {
466            arg_sort::arg_sort(
467                self.name().clone(),
468                self.downcast_iter().map(|arr| arr.iter()),
469                options,
470                self.null_count(),
471                self.len(),
472                self.is_sorted_flag(),
473                self.get(0).is_none(),
474            )
475        }
476    }
477
478    fn arg_sort_multiple(
479        &self,
480        by: &[Column],
481        options: &SortMultipleOptions,
482    ) -> PolarsResult<IdxCa> {
483        args_validate(self, by, &options.descending, "descending")?;
484        args_validate(self, by, &options.nulls_last, "nulls_last")?;
485        let mut count: IdxSize = 0;
486
487        let mut vals = Vec::with_capacity(self.len());
488        for arr in self.downcast_iter() {
489            for v in arr {
490                let i = count;
491                count += 1;
492                vals.push((i, v))
493            }
494        }
495
496        arg_sort_multiple_impl(vals, by, options)
497    }
498}
499
500impl ChunkSort<BinaryOffsetType> for BinaryOffsetChunked {
501    fn sort_with(&self, mut options: SortOptions) -> BinaryOffsetChunked {
502        options.multithreaded &= POOL.current_num_threads() > 1;
503        sort_with_fast_path!(self, options);
504
505        let mut v: Vec<&[u8]> = Vec::with_capacity(self.len());
506        for arr in self.downcast_iter() {
507            v.extend(arr.non_null_values_iter());
508        }
509
510        sort_impl_unstable(v.as_mut_slice(), options);
511
512        let mut values = Vec::<u8>::with_capacity(self.get_values_size());
513        let mut offsets = Vec::<i64>::with_capacity(self.len() + 1);
514        let mut length_so_far = 0i64;
515        offsets.push(length_so_far);
516
517        let len = self.len();
518        let null_count = self.null_count();
519        let mut ca: Self = match (null_count, options.nulls_last) {
520            (0, _) => {
521                for val in v {
522                    values.extend_from_slice(val);
523                    length_so_far = values.len() as i64;
524                    offsets.push(length_so_far);
525                }
526                // SAFETY: offsets are correctly created.
527                let arr = unsafe {
528                    BinaryArray::from_data_unchecked_default(offsets.into(), values.into(), None)
529                };
530                ChunkedArray::with_chunk(self.name().clone(), arr)
531            },
532            (_, true) => {
533                for val in v {
534                    values.extend_from_slice(val);
535                    length_so_far = values.len() as i64;
536                    offsets.push(length_so_far);
537                }
538                offsets.extend(std::iter::repeat_n(length_so_far, null_count));
539
540                // SAFETY: offsets are correctly created.
541                let arr = unsafe {
542                    BinaryArray::from_data_unchecked_default(
543                        offsets.into(),
544                        values.into(),
545                        Some(create_validity(len, null_count, true)),
546                    )
547                };
548                ChunkedArray::with_chunk(self.name().clone(), arr)
549            },
550            (_, false) => {
551                offsets.extend(std::iter::repeat_n(length_so_far, null_count));
552
553                for val in v {
554                    values.extend_from_slice(val);
555                    length_so_far = values.len() as i64;
556                    offsets.push(length_so_far);
557                }
558
559                // SAFETY: we pass valid UTF-8.
560                let arr = unsafe {
561                    BinaryArray::from_data_unchecked_default(
562                        offsets.into(),
563                        values.into(),
564                        Some(create_validity(len, null_count, false)),
565                    )
566                };
567                ChunkedArray::with_chunk(self.name().clone(), arr)
568            },
569        };
570
571        let s = if options.descending {
572            IsSorted::Descending
573        } else {
574            IsSorted::Ascending
575        };
576        ca.set_sorted_flag(s);
577        ca
578    }
579
580    fn sort(&self, descending: bool) -> BinaryOffsetChunked {
581        self.sort_with(SortOptions {
582            descending,
583            nulls_last: false,
584            multithreaded: true,
585            maintain_order: false,
586            limit: None,
587        })
588    }
589
590    fn arg_sort(&self, mut options: SortOptions) -> IdxCa {
591        options.multithreaded &= POOL.current_num_threads() > 1;
592        let ca = self.rechunk();
593        let arr = ca.downcast_as_array();
594        let mut idx = (0..(arr.len() as IdxSize)).collect::<Vec<_>>();
595
596        let argsort = |args| {
597            if options.maintain_order {
598                sort_by_branch(
599                    args,
600                    options.descending,
601                    |a, b| unsafe {
602                        let a = arr.value_unchecked(*a as usize);
603                        let b = arr.value_unchecked(*b as usize);
604                        a.tot_cmp(&b)
605                    },
606                    options.multithreaded,
607                );
608            } else {
609                sort_unstable_by_branch(args, options, |a, b| unsafe {
610                    let a = arr.value_unchecked(*a as usize);
611                    let b = arr.value_unchecked(*b as usize);
612                    a.tot_cmp(&b)
613                });
614            }
615        };
616
617        if self.null_count() == 0 {
618            argsort(&mut idx);
619            IdxCa::from_vec(self.name().clone(), idx)
620        } else {
621            // This branch (almost?) never gets called as the row-encoding also encodes nulls.
622            let (partitioned_part, validity) =
623                partition_nulls(&mut idx, arr.validity().cloned(), options);
624            argsort(partitioned_part);
625            IdxCa::with_chunk(
626                self.name().clone(),
627                IdxArr::from_data_default(idx.into(), validity),
628            )
629        }
630    }
631
632    /// # Panics
633    ///
634    /// This function is very opinionated. On the implementation of `ChunkedArray<T>` for numeric types,
635    /// we assume that all numeric `Series` are of the same type.
636    ///
637    /// In this case we assume that all numeric `Series` are `f64` types. The caller needs to
638    /// uphold this contract. If not, it will panic.
639    fn arg_sort_multiple(
640        &self,
641        by: &[Column],
642        options: &SortMultipleOptions,
643    ) -> PolarsResult<IdxCa> {
644        args_validate(self, by, &options.descending, "descending")?;
645        args_validate(self, by, &options.nulls_last, "nulls_last")?;
646        let mut count: IdxSize = 0;
647
648        let mut vals = Vec::with_capacity(self.len());
649        for arr in self.downcast_iter() {
650            for v in arr {
651                let i = count;
652                count += 1;
653                vals.push((i, v))
654            }
655        }
656
657        arg_sort_multiple_impl(vals, by, options)
658    }
659}
660
/// Sorting for struct arrays, implemented on top of the row encoding.
#[cfg(feature = "dtype-struct")]
impl ChunkSort<StructType> for StructChunked {
    /// Sorts by gathering the rows in `arg_sort` order; the result is flagged as sorted in
    /// the requested direction.
    fn sort_with(&self, mut options: SortOptions) -> ChunkedArray<StructType> {
        options.multithreaded &= POOL.current_num_threads() > 1;
        let idx = self.arg_sort(options);
        // SAFETY: `idx` was just produced by `arg_sort` on this same array.
        let mut out = unsafe { self.take_unchecked(&idx) };

        let s = if options.descending {
            IsSorted::Descending
        } else {
            IsSorted::Ascending
        };
        out.set_sorted_flag(s);
        out
    }

    /// Sorts with default options apart from the direction.
    fn sort(&self, descending: bool) -> ChunkedArray<StructType> {
        self.sort_with(SortOptions::new().with_order_descending(descending))
    }

    /// Returns the permutation that sorts this array.
    ///
    /// The rows are encoded with `options`, so direction and null placement are presumably
    /// already part of the encoded bytes and the encoded column is sorted with the default
    /// direction. `maintain_order` and `multithreaded` are forwarded so that a requested
    /// stable or single-threaded sort is honoured instead of silently dropped.
    ///
    /// # Panics
    ///
    /// Panics if the row encoding fails.
    fn arg_sort(&self, options: SortOptions) -> IdxCa {
        let bin = self.get_row_encoded(options).unwrap();
        bin.arg_sort(SortOptions {
            maintain_order: options.maintain_order,
            multithreaded: options.multithreaded,
            ..Default::default()
        })
    }
}
686
687impl ChunkSort<ListType> for ListChunked {
688    fn sort_with(&self, mut options: SortOptions) -> ListChunked {
689        options.multithreaded &= POOL.current_num_threads() > 1;
690        let idx = self.arg_sort(options);
691        let mut out = unsafe { self.take_unchecked(&idx) };
692
693        let s = if options.descending {
694            IsSorted::Descending
695        } else {
696            IsSorted::Ascending
697        };
698        out.set_sorted_flag(s);
699        out
700    }
701
702    fn sort(&self, descending: bool) -> ListChunked {
703        self.sort_with(SortOptions::new().with_order_descending(descending))
704    }
705
706    fn arg_sort(&self, options: SortOptions) -> IdxCa {
707        let bin = _get_rows_encoded_ca(
708            self.name().clone(),
709            &[self.clone().into_column()],
710            &[options.descending],
711            &[options.nulls_last],
712        )
713        .unwrap();
714        bin.arg_sort(Default::default())
715    }
716}
717
718impl ChunkSort<BooleanType> for BooleanChunked {
719    fn sort_with(&self, mut options: SortOptions) -> ChunkedArray<BooleanType> {
720        options.multithreaded &= POOL.current_num_threads() > 1;
721        sort_with_fast_path!(self, options);
722        let mut bitmap = BitmapBuilder::with_capacity(self.len());
723        let mut validity =
724            (self.null_count() > 0).then(|| BitmapBuilder::with_capacity(self.len()));
725
726        if self.null_count() > 0 && !options.nulls_last {
727            bitmap.extend_constant(self.null_count(), false);
728            if let Some(validity) = &mut validity {
729                validity.extend_constant(self.null_count(), false);
730            }
731        }
732
733        let n_valid = self.len() - self.null_count();
734        let n_set = self.sum().unwrap() as usize;
735        if options.descending {
736            bitmap.extend_constant(n_set, true);
737            bitmap.extend_constant(n_valid - n_set, false);
738        } else {
739            bitmap.extend_constant(n_valid - n_set, false);
740            bitmap.extend_constant(n_set, true);
741        }
742        if let Some(validity) = &mut validity {
743            validity.extend_constant(n_valid, true);
744        }
745
746        if self.null_count() > 0 && options.nulls_last {
747            bitmap.extend_constant(self.null_count(), false);
748            if let Some(validity) = &mut validity {
749                validity.extend_constant(self.null_count(), false);
750            }
751        }
752
753        Self::from_chunk_iter(
754            self.name().clone(),
755            Some(BooleanArray::from_data_default(
756                bitmap.freeze(),
757                validity.map(|v| v.freeze()),
758            )),
759        )
760    }
761
762    fn sort(&self, descending: bool) -> BooleanChunked {
763        self.sort_with(SortOptions {
764            descending,
765            nulls_last: false,
766            multithreaded: true,
767            maintain_order: false,
768            limit: None,
769        })
770    }
771
772    fn arg_sort(&self, options: SortOptions) -> IdxCa {
773        arg_sort_fast_path!(self, options);
774        if self.null_count() == 0 {
775            arg_sort::arg_sort_no_nulls(
776                self.name().clone(),
777                self.downcast_iter().map(|arr| arr.values_iter()),
778                options,
779                self.len(),
780                self.is_sorted_flag(),
781            )
782        } else {
783            arg_sort::arg_sort(
784                self.name().clone(),
785                self.downcast_iter().map(|arr| arr.iter()),
786                options,
787                self.null_count(),
788                self.len(),
789                self.is_sorted_flag(),
790                self.get(0).is_none(),
791            )
792        }
793    }
794    fn arg_sort_multiple(
795        &self,
796        by: &[Column],
797        options: &SortMultipleOptions,
798    ) -> PolarsResult<IdxCa> {
799        let mut vals = Vec::with_capacity(self.len());
800        let mut count: IdxSize = 0;
801        for arr in self.downcast_iter() {
802            vals.extend_trusted_len(arr.into_iter().map(|v| {
803                let i = count;
804                count += 1;
805                (i, v.map(|v| v as u8))
806            }));
807        }
808        arg_sort_multiple_impl(vals, by, options)
809    }
810}
811
/// Broadcasts a single flag to one flag per column.
///
/// When `values` holds exactly one element and `n_cols` is larger than one, that element is
/// repeated until `values` has `n_cols` entries. In every other case `values` is left
/// untouched.
pub fn _broadcast_bools(n_cols: usize, values: &mut Vec<bool>) {
    if let &[flag] = values.as_slice() {
        if n_cols > 1 {
            values.resize(n_cols, flag);
        }
    }
}
819
820pub(crate) fn prepare_arg_sort(
821    columns: Vec<Column>,
822    sort_options: &mut SortMultipleOptions,
823) -> PolarsResult<(Column, Vec<Column>)> {
824    let n_cols = columns.len();
825
826    let mut columns = columns;
827
828    _broadcast_bools(n_cols, &mut sort_options.descending);
829    _broadcast_bools(n_cols, &mut sort_options.nulls_last);
830
831    let first = columns.remove(0);
832    Ok((first, columns))
833}
834
835#[cfg(test)]
836mod test {
837    use crate::prelude::*;
838    #[test]
839    fn test_arg_sort() {
840        let a = Int32Chunked::new(
841            PlSmallStr::from_static("a"),
842            &[
843                Some(1), // 0
844                Some(5), // 1
845                None,    // 2
846                Some(1), // 3
847                None,    // 4
848                Some(4), // 5
849                Some(3), // 6
850                Some(1), // 7
851            ],
852        );
853        let idx = a.arg_sort(SortOptions {
854            descending: false,
855            ..Default::default()
856        });
857        let idx = idx.cont_slice().unwrap();
858
859        let expected = [2, 4, 0, 3, 7, 6, 5, 1];
860        assert_eq!(idx, expected);
861
862        let idx = a.arg_sort(SortOptions {
863            descending: true,
864            ..Default::default()
865        });
866        let idx = idx.cont_slice().unwrap();
867        // the duplicates are in reverse order of appearance, so we cannot reverse expected
868        let expected = [2, 4, 1, 5, 6, 0, 3, 7];
869        assert_eq!(idx, expected);
870    }
871
872    #[test]
873    fn test_sort() {
874        let a = Int32Chunked::new(
875            PlSmallStr::from_static("a"),
876            &[
877                Some(1),
878                Some(5),
879                None,
880                Some(1),
881                None,
882                Some(4),
883                Some(3),
884                Some(1),
885            ],
886        );
887        let out = a.sort_with(SortOptions {
888            descending: false,
889            nulls_last: false,
890            multithreaded: true,
891            maintain_order: false,
892            limit: None,
893        });
894        assert_eq!(
895            Vec::from(&out),
896            &[
897                None,
898                None,
899                Some(1),
900                Some(1),
901                Some(1),
902                Some(3),
903                Some(4),
904                Some(5)
905            ]
906        );
907        let out = a.sort_with(SortOptions {
908            descending: false,
909            nulls_last: true,
910            multithreaded: true,
911            maintain_order: false,
912            limit: None,
913        });
914        assert_eq!(
915            Vec::from(&out),
916            &[
917                Some(1),
918                Some(1),
919                Some(1),
920                Some(3),
921                Some(4),
922                Some(5),
923                None,
924                None
925            ]
926        );
927        let b = BooleanChunked::new(
928            PlSmallStr::from_static("b"),
929            &[Some(false), Some(true), Some(false)],
930        );
931        let out = b.sort_with(SortOptions::default().with_order_descending(true));
932        assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(false)]);
933        let out = b.sort_with(SortOptions::default().with_order_descending(false));
934        assert_eq!(Vec::from(&out), &[Some(false), Some(false), Some(true)]);
935    }
936
937    #[test]
938    #[cfg_attr(miri, ignore)]
939    fn test_arg_sort_multiple() -> PolarsResult<()> {
940        let a = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 1, 1, 3, 4, 3, 3]);
941        let b = Int64Chunked::new(PlSmallStr::from_static("b"), &[0, 1, 2, 3, 4, 5, 6, 1]);
942        let c = StringChunked::new(
943            PlSmallStr::from_static("c"),
944            &["a", "b", "c", "d", "e", "f", "g", "h"],
945        );
946        let df = DataFrame::new(vec![
947            a.into_series().into(),
948            b.into_series().into(),
949            c.into_series().into(),
950        ])?;
951
952        let out = df.sort(["a", "b", "c"], SortMultipleOptions::default())?;
953        assert_eq!(
954            Vec::from(out.column("b")?.as_series().unwrap().i64()?),
955            &[
956                Some(0),
957                Some(2),
958                Some(3),
959                Some(1),
960                Some(1),
961                Some(4),
962                Some(6),
963                Some(5)
964            ]
965        );
966
967        // now let the first sort be a string
968        let a = StringChunked::new(
969            PlSmallStr::from_static("a"),
970            &["a", "b", "c", "a", "b", "c"],
971        )
972        .into_series();
973        let b = Int32Chunked::new(PlSmallStr::from_static("b"), &[5, 4, 2, 3, 4, 5]).into_series();
974        let df = DataFrame::new(vec![a.into(), b.into()])?;
975
976        let out = df.sort(["a", "b"], SortMultipleOptions::default())?;
977        let expected = df!(
978            "a" => ["a", "a", "b", "b", "c", "c"],
979            "b" => [3, 5, 4, 4, 2, 5]
980        )?;
981        assert!(out.equals(&expected));
982
983        let df = df!(
984            "groups" => [1, 2, 3],
985            "values" => ["a", "a", "b"]
986        )?;
987
988        let out = df.sort(
989            ["groups", "values"],
990            SortMultipleOptions::default().with_order_descending_multi([true, false]),
991        )?;
992        let expected = df!(
993            "groups" => [3, 2, 1],
994            "values" => ["b", "a", "a"]
995        )?;
996        assert!(out.equals(&expected));
997
998        let out = df.sort(
999            ["values", "groups"],
1000            SortMultipleOptions::default().with_order_descending_multi([false, true]),
1001        )?;
1002        let expected = df!(
1003            "groups" => [2, 1, 3],
1004            "values" => ["a", "a", "b"]
1005        )?;
1006        assert!(out.equals(&expected));
1007
1008        Ok(())
1009    }
1010
1011    #[test]
1012    fn test_sort_string() {
1013        let ca = StringChunked::new(
1014            PlSmallStr::from_static("a"),
1015            &[Some("a"), None, Some("c"), None, Some("b")],
1016        );
1017        let out = ca.sort_with(SortOptions {
1018            descending: false,
1019            nulls_last: false,
1020            multithreaded: true,
1021            maintain_order: false,
1022            limit: None,
1023        });
1024        let expected = &[None, None, Some("a"), Some("b"), Some("c")];
1025        assert_eq!(Vec::from(&out), expected);
1026
1027        let out = ca.sort_with(SortOptions {
1028            descending: true,
1029            nulls_last: false,
1030            multithreaded: true,
1031            maintain_order: false,
1032            limit: None,
1033        });
1034
1035        let expected = &[None, None, Some("c"), Some("b"), Some("a")];
1036        assert_eq!(Vec::from(&out), expected);
1037
1038        let out = ca.sort_with(SortOptions {
1039            descending: false,
1040            nulls_last: true,
1041            multithreaded: true,
1042            maintain_order: false,
1043            limit: None,
1044        });
1045        let expected = &[Some("a"), Some("b"), Some("c"), None, None];
1046        assert_eq!(Vec::from(&out), expected);
1047
1048        let out = ca.sort_with(SortOptions {
1049            descending: true,
1050            nulls_last: true,
1051            multithreaded: true,
1052            maintain_order: false,
1053            limit: None,
1054        });
1055        let expected = &[Some("c"), Some("b"), Some("a"), None, None];
1056        assert_eq!(Vec::from(&out), expected);
1057
1058        // no nulls
1059        let ca = StringChunked::new(
1060            PlSmallStr::from_static("a"),
1061            &[Some("a"), Some("c"), Some("b")],
1062        );
1063        let out = ca.sort(false);
1064        let expected = &[Some("a"), Some("b"), Some("c")];
1065        assert_eq!(Vec::from(&out), expected);
1066
1067        let out = ca.sort(true);
1068        let expected = &[Some("c"), Some("b"), Some("a")];
1069        assert_eq!(Vec::from(&out), expected);
1070    }
1071}