polars_core/chunked_array/ops/sort/
categorical.rs

1use super::*;
2
3impl CategoricalChunked {
4    #[must_use]
5    pub fn sort_with(&self, options: SortOptions) -> CategoricalChunked {
6        assert!(
7            !options.nulls_last,
8            "null last not yet supported for categorical dtype"
9        );
10
11        if self.uses_lexical_ordering() {
12            let mut vals = self
13                .physical()
14                .into_iter()
15                .zip(self.iter_str())
16                .collect_trusted::<Vec<_>>();
17
18            sort_unstable_by_branch(vals.as_mut_slice(), options, |a, b| a.1.cmp(&b.1));
19            let cats: UInt32Chunked = vals
20                .into_iter()
21                .map(|(idx, _v)| idx)
22                .collect_ca_trusted(self.name().clone());
23
24            // SAFETY:
25            // we only reordered the indexes so we are still in bounds
26            return unsafe {
27                CategoricalChunked::from_cats_and_rev_map_unchecked(
28                    cats,
29                    self.get_rev_map().clone(),
30                    self.is_enum(),
31                    self.get_ordering(),
32                )
33            };
34        }
35        let cats = self.physical().sort_with(options);
36        // SAFETY:
37        // we only reordered the indexes so we are still in bounds
38        unsafe {
39            CategoricalChunked::from_cats_and_rev_map_unchecked(
40                cats,
41                self.get_rev_map().clone(),
42                self.is_enum(),
43                self.get_ordering(),
44            )
45        }
46    }
47
48    /// Returned a sorted `ChunkedArray`.
49    #[must_use]
50    pub fn sort(&self, descending: bool) -> CategoricalChunked {
51        self.sort_with(SortOptions {
52            nulls_last: false,
53            descending,
54            multithreaded: true,
55            maintain_order: false,
56            limit: None,
57        })
58    }
59
60    /// Retrieve the indexes needed to sort this array.
61    pub fn arg_sort(&self, options: SortOptions) -> IdxCa {
62        if self.uses_lexical_ordering() {
63            let iters = [self.iter_str()];
64            arg_sort::arg_sort(
65                self.name().clone(),
66                iters,
67                options,
68                self.physical().null_count(),
69                self.len(),
70                IsSorted::Not,
71                false,
72            )
73        } else {
74            self.physical().arg_sort(options)
75        }
76    }
77
78    /// Retrieve the indices needed to sort this and the other arrays.
79    pub(crate) fn arg_sort_multiple(
80        &self,
81        by: &[Column],
82        options: &SortMultipleOptions,
83    ) -> PolarsResult<IdxCa> {
84        if self.uses_lexical_ordering() {
85            args_validate(self.physical(), by, &options.descending, "descending")?;
86            args_validate(self.physical(), by, &options.nulls_last, "nulls_last")?;
87            let mut count: IdxSize = 0;
88
89            // we use bytes to save a monomorphisized str impl
90            // as bytes already is used for binary and string sorting
91            let vals: Vec<_> = self
92                .iter_str()
93                .map(|v| {
94                    let i = count;
95                    count += 1;
96                    (i, v.map(|v| v.as_bytes()))
97                })
98                .collect_trusted();
99
100            arg_sort_multiple_impl(vals, by, options)
101        } else {
102            self.physical().arg_sort_multiple(by, options)
103        }
104    }
105}
106
#[cfg(test)]
mod test {
    use crate::prelude::*;
    use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache};

    /// Casts `ca` to strings and checks that the values equal `cmp`, in order.
    fn assert_order(ca: &CategoricalChunked, cmp: &[&str]) {
        let as_string = ca.cast(&DataType::String).unwrap();
        let values: Vec<_> = as_string.str().unwrap().into_no_null_iter().collect();
        assert_eq!(values, cmp);
    }

    /// Disables the string cache, then enables it again when `enabled` is true.
    fn set_string_cache(enabled: bool) {
        disable_string_cache();
        if enabled {
            enable_string_cache();
        }
    }

    #[test]
    fn test_cat_lexical_sort() -> PolarsResult<()> {
        let input = &["c", "b", "a", "d"];

        // Serialize with other tests that toggle the string cache.
        let _guard = SINGLE_LOCK.lock();
        for cache_enabled in [true, false] {
            set_string_cache(cache_enabled);

            // Lexical ordering sorts by string value.
            let lexical_series = Series::new(PlSmallStr::EMPTY, input)
                .cast(&DataType::Categorical(None, CategoricalOrdering::Lexical))?;
            let lexical = lexical_series.categorical()?.clone();
            assert_order(&lexical.sort(false), &["a", "b", "c", "d"]);

            // Default ordering is expected to keep the input order here.
            let default_series = Series::new(PlSmallStr::EMPTY, input)
                .cast(&DataType::Categorical(None, Default::default()))?;
            let default_cat = default_series.categorical()?;
            assert_order(&default_cat.sort(false), input);

            let sort_idx = lexical.arg_sort(SortOptions {
                descending: false,
                ..Default::default()
            });
            let positions: Vec<_> = sort_idx.into_no_null_iter().collect();
            assert_eq!(positions, &[2, 1, 0, 3]);
        }

        Ok(())
    }

    #[test]
    fn test_cat_lexical_sort_multiple() -> PolarsResult<()> {
        let input = &["c", "b", "a", "a"];

        // Serialize with other tests that toggle the string cache.
        let _guard = SINGLE_LOCK.lock();
        for cache_enabled in [true, false] {
            set_string_cache(cache_enabled);

            let lexical_series = Series::new(PlSmallStr::EMPTY, input)
                .cast(&DataType::Categorical(None, CategoricalOrdering::Lexical))?;
            let lexical: CategoricalChunked = lexical_series.categorical()?.clone();
            let series = lexical.into_series();

            let df = df![
                "cat" => &series,
                "vals" => [1, 1, 2, 2]
            ]?;

            // Categorical column as the primary key.
            let by_cat = df.sort(
                ["cat", "vals"],
                SortMultipleOptions::default().with_order_descending_multi([false, false]),
            )?;
            let cat_col = by_cat.column("cat")?;
            assert_order(
                cat_col.as_materialized_series().categorical()?,
                &["a", "a", "b", "c"],
            );

            // Categorical column as the tie-breaker.
            let by_vals = df.sort(
                ["vals", "cat"],
                SortMultipleOptions::default().with_order_descending_multi([false, false]),
            )?;
            let cat_col = by_vals.column("cat")?;
            assert_order(
                cat_col.as_materialized_series().categorical()?,
                &["b", "c", "a", "a"],
            );
        }
        Ok(())
    }
}