// polars_core/chunked_array/logical/categorical/mod.rs

1mod builder;
2mod from;
3mod merge;
4mod ops;
5pub mod revmap;
6pub mod string_cache;
7
8use bitflags::bitflags;
9pub use builder::*;
10pub use merge::*;
11use polars_utils::itertools::Itertools;
12use polars_utils::sync::SyncPtr;
13pub use revmap::*;
14
15use super::*;
16use crate::chunked_array::cast::CastOptions;
17use crate::chunked_array::flags::StatisticsFlags;
18use crate::prelude::*;
19use crate::series::IsSorted;
20use crate::using_string_cache;
21
22bitflags! {
23    #[derive(Default, Clone)]
24    struct BitSettings: u8 {
25        const ORIGINAL = 0x01;
26    }
27}
28
29#[derive(Default, Clone)]
30pub struct CategoricalChunked {
31    physical: Logical<CategoricalType, UInt32Type>,
32    /// 1st bit: original local categorical
33    ///             meaning that n_unique is the same as the cat map length
34    bit_settings: BitSettings,
35}
36
37impl CategoricalChunked {
38    pub(crate) fn field(&self) -> Field {
39        let name = self.physical().name();
40        Field::new(name.clone(), self.dtype().clone())
41    }
42
43    pub fn is_empty(&self) -> bool {
44        self.len() == 0
45    }
46
47    #[inline]
48    pub fn len(&self) -> usize {
49        self.physical.len()
50    }
51
52    #[inline]
53    pub fn null_count(&self) -> usize {
54        self.physical.null_count()
55    }
56
57    pub fn name(&self) -> &PlSmallStr {
58        self.physical.name()
59    }
60
61    // TODO: Rename this
62    /// Get a reference to the physical array (the categories).
63    pub fn physical(&self) -> &UInt32Chunked {
64        &self.physical
65    }
66
67    /// Get a mutable reference to the physical array (the categories).
68    pub(crate) fn physical_mut(&mut self) -> &mut UInt32Chunked {
69        &mut self.physical
70    }
71
72    pub fn is_enum(&self) -> bool {
73        matches!(self.dtype(), DataType::Enum(_, _))
74    }
75
76    /// Convert a categorical column to its local representation.
77    pub fn to_local(&self) -> Self {
78        let rev_map = self.get_rev_map();
79        let (physical_map, categories) = match rev_map.as_ref() {
80            RevMapping::Global(m, c, _) => (m, c),
81            RevMapping::Local(_, _) if !self.is_enum() => return self.clone(),
82            RevMapping::Local(_, _) => {
83                // Change dtype from Enum to Categorical
84                let mut local = self.clone();
85                local.physical.2 = Some(DataType::Categorical(
86                    Some(rev_map.clone()),
87                    self.get_ordering(),
88                ));
89                return local;
90            },
91        };
92
93        let local_rev_map = RevMapping::build_local(categories.clone());
94        // TODO: A fast path can possibly be implemented here:
95        // if all physical map keys are equal to their values,
96        // we can skip the apply and only update the rev_map
97        let local_ca = self
98            .physical()
99            .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap()));
100
101        let mut out = unsafe {
102            Self::from_cats_and_rev_map_unchecked(
103                local_ca,
104                local_rev_map.into(),
105                false,
106                self.get_ordering(),
107            )
108        };
109        out.set_fast_unique(self._can_fast_unique());
110
111        out
112    }
113
114    pub fn to_global(&self) -> PolarsResult<Self> {
115        polars_ensure!(using_string_cache(), string_cache_mismatch);
116        // Fast path
117        let categories = match &**self.get_rev_map() {
118            RevMapping::Global(_, _, _) => return Ok(self.clone()),
119            RevMapping::Local(categories, _) => categories,
120        };
121
122        // SAFETY: keys and values are in bounds
123        unsafe {
124            Ok(CategoricalChunked::from_keys_and_values_global(
125                self.name().clone(),
126                self.physical(),
127                self.len(),
128                categories,
129                self.get_ordering(),
130            ))
131        }
132    }
133
134    // Convert to fixed enum. Values not in categories are mapped to None.
135    pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> Self {
136        // Fast paths
137        match self.get_rev_map().as_ref() {
138            RevMapping::Local(_, cur_hash) if hash == *cur_hash => {
139                return unsafe {
140                    CategoricalChunked::from_cats_and_rev_map_unchecked(
141                        self.physical().clone(),
142                        self.get_rev_map().clone(),
143                        true,
144                        self.get_ordering(),
145                    )
146                };
147            },
148            _ => (),
149        };
150        // Make a mapping from old idx to new idx
151        let old_rev_map = self.get_rev_map();
152        #[allow(clippy::unnecessary_cast)]
153        let idx_map: PlHashMap<u32, u32> = categories
154            .values_iter()
155            .enumerate_idx()
156            .filter_map(|(new_idx, s)| old_rev_map.find(s).map(|old_idx| (old_idx, new_idx as u32)))
157            .collect();
158
159        // Loop over the physicals and try get new idx
160        let new_phys: UInt32Chunked = self
161            .physical()
162            .into_iter()
163            .map(|opt_v: Option<u32>| opt_v.and_then(|v| idx_map.get(&v).copied()))
164            .collect();
165
166        // SAFETY: we created the physical from the enum categories
167        unsafe {
168            CategoricalChunked::from_cats_and_rev_map_unchecked(
169                new_phys,
170                Arc::new(RevMapping::Local(categories.clone(), hash)),
171                true,
172                self.get_ordering(),
173            )
174        }
175    }
176
177    pub(crate) fn get_flags(&self) -> StatisticsFlags {
178        self.physical().get_flags()
179    }
180
181    /// Set flags for the Chunked Array
182    pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) {
183        // We should not set the sorted flag if we are sorting in lexical order
184        if self.uses_lexical_ordering() {
185            flags.set_sorted(IsSorted::Not)
186        }
187        self.physical_mut().set_flags(flags)
188    }
189
190    /// Return whether or not the [`CategoricalChunked`] uses the lexical order
191    /// of the string values when sorting.
192    pub fn uses_lexical_ordering(&self) -> bool {
193        self.get_ordering() == CategoricalOrdering::Lexical
194    }
195
196    pub(crate) fn get_ordering(&self) -> CategoricalOrdering {
197        if let DataType::Categorical(_, ordering) | DataType::Enum(_, ordering) =
198            &self.physical.2.as_ref().unwrap()
199        {
200            *ordering
201        } else {
202            panic!("implementation error")
203        }
204    }
205
206    /// Create a [`CategoricalChunked`] from a physical array and dtype.
207    ///
208    /// # Safety
209    /// It's not checked that the indices are in-bounds or that the dtype is
210    /// correct.
211    pub unsafe fn from_cats_and_dtype_unchecked(idx: UInt32Chunked, dtype: DataType) -> Self {
212        debug_assert!(matches!(
213            dtype,
214            DataType::Enum { .. } | DataType::Categorical { .. }
215        ));
216        let mut logical = Logical::<UInt32Type, _>::new_logical::<CategoricalType>(idx);
217        logical.2 = Some(dtype);
218        Self {
219            physical: logical,
220            bit_settings: Default::default(),
221        }
222    }
223
224    /// Create a [`CategoricalChunked`] from an array of `idx` and an existing [`RevMapping`]:  `rev_map`.
225    ///
226    /// # Safety
227    /// Invariant in `v < rev_map.len() for v in idx` must hold.
228    pub unsafe fn from_cats_and_rev_map_unchecked(
229        idx: UInt32Chunked,
230        rev_map: Arc<RevMapping>,
231        is_enum: bool,
232        ordering: CategoricalOrdering,
233    ) -> Self {
234        let mut logical = Logical::<UInt32Type, _>::new_logical::<CategoricalType>(idx);
235        if is_enum {
236            logical.2 = Some(DataType::Enum(Some(rev_map), ordering));
237        } else {
238            logical.2 = Some(DataType::Categorical(Some(rev_map), ordering));
239        }
240        Self {
241            physical: logical,
242            bit_settings: Default::default(),
243        }
244    }
245
246    pub(crate) fn set_ordering(
247        mut self,
248        ordering: CategoricalOrdering,
249        keep_fast_unique: bool,
250    ) -> Self {
251        self.physical.2 = match self.dtype() {
252            DataType::Enum(_, _) => {
253                Some(DataType::Enum(Some(self.get_rev_map().clone()), ordering))
254            },
255            DataType::Categorical(_, _) => Some(DataType::Categorical(
256                Some(self.get_rev_map().clone()),
257                ordering,
258            )),
259            _ => panic!("implementation error"),
260        };
261
262        if !keep_fast_unique {
263            self.set_fast_unique(false)
264        }
265        self
266    }
267
268    /// # Safety
269    /// The existing index values must be in bounds of the new [`RevMapping`].
270    pub(crate) unsafe fn set_rev_map(&mut self, rev_map: Arc<RevMapping>, keep_fast_unique: bool) {
271        self.physical.2 = match self.dtype() {
272            DataType::Enum(_, _) => Some(DataType::Enum(Some(rev_map), self.get_ordering())),
273            DataType::Categorical(_, _) => {
274                Some(DataType::Categorical(Some(rev_map), self.get_ordering()))
275            },
276            _ => panic!("implementation error"),
277        };
278
279        if !keep_fast_unique {
280            self.set_fast_unique(false)
281        }
282    }
283
284    /// True if all categories are represented in this array. When this is the case, the unique
285    /// values of the array are the categories.
286    pub fn _can_fast_unique(&self) -> bool {
287        self.bit_settings.contains(BitSettings::ORIGINAL)
288            && self.physical.chunks.len() == 1
289            && self.null_count() == 0
290    }
291
292    pub(crate) fn set_fast_unique(&mut self, toggle: bool) {
293        if toggle {
294            self.bit_settings.insert(BitSettings::ORIGINAL);
295        } else {
296            self.bit_settings.remove(BitSettings::ORIGINAL);
297        }
298    }
299
300    /// Set `FAST_UNIQUE` metadata
301    /// # Safety
302    /// This invariant must hold `unique(categories) == unique(self)`
303    pub(crate) unsafe fn with_fast_unique(mut self, toggle: bool) -> Self {
304        self.set_fast_unique(toggle);
305        self
306    }
307
308    /// Set `FAST_UNIQUE` metadata
309    /// # Safety
310    /// This invariant must hold `unique(categories) == unique(self)`
311    pub unsafe fn _with_fast_unique(self, toggle: bool) -> Self {
312        self.with_fast_unique(toggle)
313    }
314
315    /// Get a reference to the mapping of categorical types to the string values.
316    pub fn get_rev_map(&self) -> &Arc<RevMapping> {
317        if let DataType::Categorical(Some(rev_map), _) | DataType::Enum(Some(rev_map), _) =
318            &self.physical.2.as_ref().unwrap()
319        {
320            rev_map
321        } else {
322            panic!("implementation error")
323        }
324    }
325
326    /// Create an [`Iterator`] that iterates over the `&str` values of the [`CategoricalChunked`].
327    pub fn iter_str(&self) -> CatIter<'_> {
328        let iter = self.physical().into_iter();
329        CatIter {
330            rev: self.get_rev_map(),
331            iter,
332        }
333    }
334}
335
336impl LogicalType for CategoricalChunked {
337    fn dtype(&self) -> &DataType {
338        self.physical.2.as_ref().unwrap()
339    }
340
341    fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
342        polars_ensure!(i < self.len(), oob = i, self.len());
343        Ok(unsafe { self.get_any_value_unchecked(i) })
344    }
345
346    unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
347        match self.physical.0.get_unchecked(i) {
348            Some(i) => match self.dtype() {
349                DataType::Enum(_, _) => AnyValue::Enum(i, self.get_rev_map(), SyncPtr::new_null()),
350                DataType::Categorical(_, _) => {
351                    AnyValue::Categorical(i, self.get_rev_map(), SyncPtr::new_null())
352                },
353                _ => unimplemented!(),
354            },
355            None => AnyValue::Null,
356        }
357    }
358
359    fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
360        match dtype {
361            DataType::String => {
362                let mapping = &**self.get_rev_map();
363
364                let mut builder =
365                    StringChunkedBuilder::new(self.physical.name().clone(), self.len());
366
367                let f = |idx: u32| mapping.get(idx);
368
369                if !self.physical.has_nulls() {
370                    self.physical
371                        .into_no_null_iter()
372                        .for_each(|idx| builder.append_value(f(idx)));
373                } else {
374                    self.physical.into_iter().for_each(|opt_idx| {
375                        builder.append_option(opt_idx.map(f));
376                    });
377                }
378
379                let ca = builder.finish();
380                Ok(ca.into_series())
381            },
382            DataType::UInt32 => {
383                let ca = unsafe {
384                    UInt32Chunked::from_chunks(
385                        self.physical.name().clone(),
386                        self.physical.chunks.clone(),
387                    )
388                };
389                Ok(ca.into_series())
390            },
391            #[cfg(feature = "dtype-categorical")]
392            DataType::Enum(Some(rev_map), ordering) => {
393                let RevMapping::Local(categories, hash) = &**rev_map else {
394                    polars_bail!(ComputeError: "can not cast to enum with global mapping")
395                };
396                Ok(self
397                    .to_enum(categories, *hash)
398                    .set_ordering(*ordering, true)
399                    .into_series()
400                    .with_name(self.name().clone()))
401            },
402            DataType::Enum(None, _) => {
403                polars_bail!(ComputeError: "can not cast to enum without categories present")
404            },
405            #[cfg(feature = "dtype-categorical")]
406            DataType::Categorical(rev_map, ordering) => {
407                // Casting from an Enum to a local or global
408                if matches!(self.dtype(), DataType::Enum(_, _)) && rev_map.is_none() {
409                    if using_string_cache() {
410                        return Ok(self
411                            .to_global()?
412                            .set_ordering(*ordering, true)
413                            .into_series());
414                    } else {
415                        return Ok(self.to_local().set_ordering(*ordering, true).into_series());
416                    }
417                }
418                // If casting to lexical categorical, set sorted flag as not set
419
420                let mut ca = self.clone().set_ordering(*ordering, true);
421                if ca.uses_lexical_ordering() {
422                    ca.physical.set_sorted_flag(IsSorted::Not);
423                }
424                Ok(ca.into_series())
425            },
426            dt if dt.is_primitive_numeric() => {
427                // Apply the cast to the categories and then index into the casted series.
428                // This has to be local for the gather.
429                let slf = self.to_local();
430                let categories = StringChunked::with_chunk(
431                    slf.physical.name().clone(),
432                    slf.get_rev_map().get_categories().clone(),
433                );
434                let casted_series = categories.cast_with_options(dtype, options)?;
435
436                #[cfg(feature = "bigidx")]
437                {
438                    let s = slf.physical.cast_with_options(&DataType::UInt64, options)?;
439                    Ok(unsafe { casted_series.take_unchecked(s.u64()?) })
440                }
441                #[cfg(not(feature = "bigidx"))]
442                {
443                    // SAFETY: Invariant of categorical means indices are in bound
444                    Ok(unsafe { casted_series.take_unchecked(&slf.physical) })
445                }
446            },
447            _ => self.physical.cast_with_options(dtype, options),
448        }
449    }
450}
451
452pub struct CatIter<'a> {
453    rev: &'a RevMapping,
454    iter: Box<dyn PolarsIterator<Item = Option<u32>> + 'a>,
455}
456
457unsafe impl TrustedLen for CatIter<'_> {}
458
459impl<'a> Iterator for CatIter<'a> {
460    type Item = Option<&'a str>;
461
462    fn next(&mut self) -> Option<Self::Item> {
463        self.iter.next().map(|item| {
464            item.map(|idx| {
465                // SAFETY:
466                // all categories are in bound
467                unsafe { self.rev.get_unchecked(idx) }
468            })
469        })
470    }
471
472    fn size_hint(&self) -> (usize, Option<usize>) {
473        self.iter.size_hint()
474    }
475}
476
477impl DoubleEndedIterator for CatIter<'_> {
478    fn next_back(&mut self) -> Option<Self::Item> {
479        self.iter.next_back().map(|item| {
480            item.map(|idx| {
481                // SAFETY:
482                // all categories are in bound
483                unsafe { self.rev.get_unchecked(idx) }
484            })
485        })
486    }
487}
488
489impl ExactSizeIterator for CatIter<'_> {}
490
#[cfg(test)]
mod test {
    use super::*;
    use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache};

    /// String -> Categorical -> Arrow -> Series keeps dtype, nulls and length.
    #[test]
    fn test_categorical_round_trip() -> PolarsResult<()> {
        let _lock = SINGLE_LOCK.lock();
        disable_string_cache();
        let slice = &[
            Some("foo"),
            None,
            Some("bar"),
            Some("foo"),
            Some("foo"),
            Some("bar"),
        ];
        let ca = StringChunked::new(PlSmallStr::from_static("a"), slice);
        let ca = ca.cast(&DataType::Categorical(None, Default::default()))?;
        let ca = ca.categorical().unwrap();

        let arr = ca.to_arrow(CompatLevel::newest(), false);
        let s = Series::try_from((PlSmallStr::from_static("foo"), arr))?;
        assert!(matches!(s.dtype(), &DataType::Categorical(_, _)));
        assert_eq!(s.null_count(), 1);
        assert_eq!(s.len(), 6);

        Ok(())
    }

    /// Appending two categoricals built under the global string cache.
    #[test]
    fn test_append_categorical() {
        let _lock = SINGLE_LOCK.lock();
        disable_string_cache();
        enable_string_cache();

        let mut s1 = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"])
            .cast(&DataType::Categorical(None, Default::default()))
            .unwrap();
        let s2 = Series::new(PlSmallStr::from_static("2"), vec!["a", "x", "y"])
            .cast(&DataType::Categorical(None, Default::default()))
            .unwrap();
        let appended = s1.append(&s2).unwrap();
        assert_eq!(appended.str_value(0).unwrap(), "a");
        assert_eq!(appended.str_value(1).unwrap(), "b");
        assert_eq!(appended.str_value(4).unwrap(), "x");
        assert_eq!(appended.str_value(5).unwrap(), "y");

        // The string cache is process-global: do not leak the enabled state into
        // tests that run afterwards.
        disable_string_cache();
    }

    /// `n_unique` must not take the fast path once rows were taken or sliced away.
    #[test]
    fn test_fast_unique() {
        let _lock = SINGLE_LOCK.lock();
        // Start from a known string-cache state rather than whatever the
        // previously run test left behind.
        disable_string_cache();

        let s = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"])
            .cast(&DataType::Categorical(None, Default::default()))
            .unwrap();

        assert_eq!(s.n_unique().unwrap(), 3);
        // Make sure that it does not take the fast path after take/slice.
        let out = s.take(&IdxCa::new(PlSmallStr::EMPTY, [1, 2])).unwrap();
        assert_eq!(out.n_unique().unwrap(), 2);
        let out = s.slice(1, 2);
        assert_eq!(out.n_unique().unwrap(), 2);
    }

    /// Operations that could drop the categorical dtype information keep it.
    #[test]
    fn test_categorical_flow() -> PolarsResult<()> {
        let _lock = SINGLE_LOCK.lock();
        disable_string_cache();

        // tests several things that may lose the dtype information
        let s = Series::new(PlSmallStr::from_static("a"), vec!["a", "b", "c"])
            .cast(&DataType::Categorical(None, Default::default()))?;

        assert_eq!(
            s.field().into_owned(),
            Field::new(
                PlSmallStr::from_static("a"),
                DataType::Categorical(None, Default::default())
            )
        );
        assert!(matches!(
            s.get(0)?,
            AnyValue::Categorical(0, RevMapping::Local(_, _), _)
        ));

        let groups = s.group_tuples(false, true);
        let aggregated = unsafe { s.agg_list(&groups?) };
        match aggregated.get(0)? {
            AnyValue::List(s) => {
                assert!(matches!(s.dtype(), DataType::Categorical(_, _)));
                let str_s = s.cast(&DataType::String).unwrap();
                assert_eq!(str_s.get(0)?, AnyValue::String("a"));
                assert_eq!(s.len(), 1);
            },
            _ => panic!(),
        }
        let flat = aggregated.explode()?;
        let ca = flat.categorical().unwrap();
        let vals = ca.iter_str().map(|v| v.unwrap()).collect::<Vec<_>>();
        assert_eq!(vals, &["a", "b", "c"]);
        Ok(())
    }
}