polars_core/chunked_array/logical/categorical/
mod.rs

1mod builder;
2mod from;
3mod merge;
4mod ops;
5pub mod revmap;
6pub mod string_cache;
7
8use bitflags::bitflags;
9pub use builder::*;
10pub use merge::*;
11use polars_utils::itertools::Itertools;
12use polars_utils::sync::SyncPtr;
13pub use revmap::*;
14
15use super::*;
16use crate::chunked_array::cast::CastOptions;
17use crate::chunked_array::flags::StatisticsFlags;
18use crate::prelude::*;
19use crate::series::IsSorted;
20use crate::using_string_cache;
21
22bitflags! {
23    #[derive(Default, Clone)]
24    struct BitSettings: u8 {
25        const ORIGINAL = 0x01;
26    }
27}
28
29#[derive(Clone)]
30pub struct CategoricalChunked {
31    physical: Logical<CategoricalType, UInt32Type>,
32    /// 1st bit: original local categorical
33    ///             meaning that n_unique is the same as the cat map length
34    bit_settings: BitSettings,
35}
36
37impl CategoricalChunked {
38    pub(crate) fn field(&self) -> Field {
39        let name = self.physical().name();
40        Field::new(name.clone(), self.dtype().clone())
41    }
42
43    pub fn is_empty(&self) -> bool {
44        self.len() == 0
45    }
46
47    #[inline]
48    pub fn len(&self) -> usize {
49        self.physical.len()
50    }
51
52    #[inline]
53    pub fn null_count(&self) -> usize {
54        self.physical.null_count()
55    }
56
57    pub fn name(&self) -> &PlSmallStr {
58        self.physical.name()
59    }
60
61    /// Get the physical array (the category indexes).
62    pub fn into_physical(self) -> UInt32Chunked {
63        self.physical.phys
64    }
65
66    // TODO: Rename this
67    /// Get a reference to the physical array (the categories).
68    pub fn physical(&self) -> &UInt32Chunked {
69        &self.physical
70    }
71
72    /// Get a mutable reference to the physical array (the categories).
73    pub(crate) fn physical_mut(&mut self) -> &mut UInt32Chunked {
74        &mut self.physical
75    }
76
77    pub fn is_enum(&self) -> bool {
78        matches!(self.dtype(), DataType::Enum(_, _))
79    }
80
81    /// Convert a categorical column to its local representation.
82    pub fn to_local(&self) -> Self {
83        let rev_map = self.get_rev_map();
84        let (physical_map, categories) = match rev_map.as_ref() {
85            RevMapping::Global(m, c, _) => (m, c),
86            RevMapping::Local(_, _) if !self.is_enum() => return self.clone(),
87            RevMapping::Local(_, _) => {
88                // Change dtype from Enum to Categorical
89                let mut local = self.clone();
90                local.physical.dtype =
91                    DataType::Categorical(Some(rev_map.clone()), self.get_ordering());
92                return local;
93            },
94        };
95
96        let local_rev_map = RevMapping::build_local(categories.clone());
97        // TODO: A fast path can possibly be implemented here:
98        // if all physical map keys are equal to their values,
99        // we can skip the apply and only update the rev_map
100        let local_ca = self
101            .physical()
102            .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap()));
103
104        let mut out = unsafe {
105            Self::from_cats_and_rev_map_unchecked(
106                local_ca,
107                local_rev_map.into(),
108                false,
109                self.get_ordering(),
110            )
111        };
112        out.set_fast_unique(self._can_fast_unique());
113
114        out
115    }
116
117    pub fn to_global(&self) -> PolarsResult<Self> {
118        polars_ensure!(using_string_cache(), string_cache_mismatch);
119        // Fast path
120        let categories = match &**self.get_rev_map() {
121            RevMapping::Global(_, _, _) => return Ok(self.clone()),
122            RevMapping::Local(categories, _) => categories,
123        };
124
125        // SAFETY: keys and values are in bounds
126        unsafe {
127            Ok(CategoricalChunked::from_keys_and_values_global(
128                self.name().clone(),
129                self.physical(),
130                self.len(),
131                categories,
132                self.get_ordering(),
133            ))
134        }
135    }
136
137    // Convert to fixed enum. Values not in categories are mapped to None.
138    pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> Self {
139        // Fast paths
140        match self.get_rev_map().as_ref() {
141            RevMapping::Local(_, cur_hash) if hash == *cur_hash => {
142                return unsafe {
143                    CategoricalChunked::from_cats_and_rev_map_unchecked(
144                        self.physical().clone(),
145                        self.get_rev_map().clone(),
146                        true,
147                        self.get_ordering(),
148                    )
149                };
150            },
151            _ => (),
152        };
153        // Make a mapping from old idx to new idx
154        let old_rev_map = self.get_rev_map();
155
156        // Create map of old category -> idx for fast lookup.
157        let old_categories = old_rev_map.get_categories();
158        let old_idx_map: PlHashMap<&str, u32> = old_categories
159            .values_iter()
160            .zip(0..old_categories.len() as u32)
161            .collect();
162
163        #[allow(clippy::unnecessary_cast)]
164        let idx_map: PlHashMap<u32, u32> = categories
165            .values_iter()
166            .enumerate_idx()
167            .filter_map(|(new_idx, s)| old_idx_map.get(s).map(|old_idx| (*old_idx, new_idx as u32)))
168            .collect();
169
170        // Loop over the physicals and try get new idx
171        let new_phys: UInt32Chunked = self
172            .physical()
173            .into_iter()
174            .map(|opt_v: Option<u32>| opt_v.and_then(|v| idx_map.get(&v).copied()))
175            .collect();
176
177        // SAFETY: we created the physical from the enum categories
178        unsafe {
179            CategoricalChunked::from_cats_and_rev_map_unchecked(
180                new_phys,
181                Arc::new(RevMapping::Local(categories.clone(), hash)),
182                true,
183                self.get_ordering(),
184            )
185        }
186    }
187
188    pub(crate) fn get_flags(&self) -> StatisticsFlags {
189        self.physical().get_flags()
190    }
191
192    /// Set flags for the Chunked Array
193    pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) {
194        // We should not set the sorted flag if we are sorting in lexical order
195        if self.uses_lexical_ordering() {
196            flags.set_sorted(IsSorted::Not)
197        }
198        self.physical_mut().set_flags(flags)
199    }
200
201    /// Return whether or not the [`CategoricalChunked`] uses the lexical order
202    /// of the string values when sorting.
203    pub fn uses_lexical_ordering(&self) -> bool {
204        self.get_ordering() == CategoricalOrdering::Lexical
205    }
206
207    pub fn get_ordering(&self) -> CategoricalOrdering {
208        if let DataType::Categorical(_, ordering) | DataType::Enum(_, ordering) =
209            &self.physical.dtype
210        {
211            *ordering
212        } else {
213            panic!("implementation error")
214        }
215    }
216
217    /// Create a [`CategoricalChunked`] from a physical array and dtype.
218    ///
219    /// # Safety
220    /// It's not checked that the indices are in-bounds or that the dtype is
221    /// correct.
222    pub unsafe fn from_cats_and_dtype_unchecked(idx: UInt32Chunked, dtype: DataType) -> Self {
223        debug_assert!(matches!(
224            dtype,
225            DataType::Enum { .. } | DataType::Categorical { .. }
226        ));
227        Self {
228            physical: Logical::new_logical(idx, dtype),
229            bit_settings: Default::default(),
230        }
231    }
232
233    /// Create a [`CategoricalChunked`] from an array of `idx` and an existing [`RevMapping`]:  `rev_map`.
234    ///
235    /// # Safety
236    /// Invariant in `v < rev_map.len() for v in idx` must hold.
237    pub unsafe fn from_cats_and_rev_map_unchecked(
238        idx: UInt32Chunked,
239        rev_map: Arc<RevMapping>,
240        is_enum: bool,
241        ordering: CategoricalOrdering,
242    ) -> Self {
243        let dtype = if is_enum {
244            DataType::Enum(Some(rev_map), ordering)
245        } else {
246            DataType::Categorical(Some(rev_map), ordering)
247        };
248        Self {
249            physical: Logical::new_logical(idx, dtype),
250            bit_settings: Default::default(),
251        }
252    }
253
254    pub(crate) fn set_ordering(
255        mut self,
256        ordering: CategoricalOrdering,
257        keep_fast_unique: bool,
258    ) -> Self {
259        self.physical.dtype = match self.dtype() {
260            DataType::Enum(_, _) => DataType::Enum(Some(self.get_rev_map().clone()), ordering),
261            DataType::Categorical(_, _) => {
262                DataType::Categorical(Some(self.get_rev_map().clone()), ordering)
263            },
264            _ => panic!("implementation error"),
265        };
266
267        if !keep_fast_unique {
268            self.set_fast_unique(false)
269        }
270        self
271    }
272
273    /// # Safety
274    /// The existing index values must be in bounds of the new [`RevMapping`].
275    pub(crate) unsafe fn set_rev_map(&mut self, rev_map: Arc<RevMapping>, keep_fast_unique: bool) {
276        self.physical.dtype = match self.dtype() {
277            DataType::Enum(_, _) => DataType::Enum(Some(rev_map), self.get_ordering()),
278            DataType::Categorical(_, _) => {
279                DataType::Categorical(Some(rev_map), self.get_ordering())
280            },
281            _ => panic!("implementation error"),
282        };
283
284        if !keep_fast_unique {
285            self.set_fast_unique(false)
286        }
287    }
288
289    /// True if all categories are represented in this array. When this is the case, the unique
290    /// values of the array are the categories.
291    pub fn _can_fast_unique(&self) -> bool {
292        self.bit_settings.contains(BitSettings::ORIGINAL)
293            && self.physical.chunks.len() == 1
294            && self.null_count() == 0
295    }
296
297    pub(crate) fn set_fast_unique(&mut self, toggle: bool) {
298        if toggle {
299            self.bit_settings.insert(BitSettings::ORIGINAL);
300        } else {
301            self.bit_settings.remove(BitSettings::ORIGINAL);
302        }
303    }
304
305    /// Set `FAST_UNIQUE` metadata
306    /// # Safety
307    /// This invariant must hold `unique(categories) == unique(self)`
308    pub(crate) unsafe fn with_fast_unique(mut self, toggle: bool) -> Self {
309        self.set_fast_unique(toggle);
310        self
311    }
312
313    /// Set `FAST_UNIQUE` metadata
314    /// # Safety
315    /// This invariant must hold `unique(categories) == unique(self)`
316    pub unsafe fn _with_fast_unique(self, toggle: bool) -> Self {
317        self.with_fast_unique(toggle)
318    }
319
320    /// Get a reference to the mapping of categorical types to the string values.
321    pub fn get_rev_map(&self) -> &Arc<RevMapping> {
322        if let DataType::Categorical(Some(rev_map), _) | DataType::Enum(Some(rev_map), _) =
323            &self.physical.dtype
324        {
325            rev_map
326        } else {
327            panic!("implementation error")
328        }
329    }
330
331    /// Create an [`Iterator`] that iterates over the `&str` values of the [`CategoricalChunked`].
332    pub fn iter_str(&self) -> CatIter<'_> {
333        let iter = self.physical().into_iter();
334        CatIter {
335            rev: self.get_rev_map(),
336            iter,
337        }
338    }
339}
340
341impl LogicalType for CategoricalChunked {
342    fn dtype(&self) -> &DataType {
343        &self.physical.dtype
344    }
345
346    fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
347        polars_ensure!(i < self.len(), oob = i, self.len());
348        Ok(unsafe { self.get_any_value_unchecked(i) })
349    }
350
351    unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
352        match self.physical.phys.get_unchecked(i) {
353            Some(i) => match self.dtype() {
354                DataType::Enum(_, _) => AnyValue::Enum(i, self.get_rev_map(), SyncPtr::new_null()),
355                DataType::Categorical(_, _) => {
356                    AnyValue::Categorical(i, self.get_rev_map(), SyncPtr::new_null())
357                },
358                _ => unimplemented!(),
359            },
360            None => AnyValue::Null,
361        }
362    }
363
364    fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
365        match dtype {
366            DataType::String => {
367                let mapping = &**self.get_rev_map();
368
369                let mut builder =
370                    StringChunkedBuilder::new(self.physical.name().clone(), self.len());
371
372                let f = |idx: u32| mapping.get(idx);
373
374                if !self.physical.has_nulls() {
375                    self.physical
376                        .into_no_null_iter()
377                        .for_each(|idx| builder.append_value(f(idx)));
378                } else {
379                    self.physical.into_iter().for_each(|opt_idx| {
380                        builder.append_option(opt_idx.map(f));
381                    });
382                }
383
384                let ca = builder.finish();
385                Ok(ca.into_series())
386            },
387            DataType::UInt32 => {
388                let ca = unsafe {
389                    UInt32Chunked::from_chunks(
390                        self.physical.name().clone(),
391                        self.physical.chunks.clone(),
392                    )
393                };
394                Ok(ca.into_series())
395            },
396            #[cfg(feature = "dtype-categorical")]
397            DataType::Enum(Some(rev_map), ordering) => {
398                let RevMapping::Local(categories, hash) = &**rev_map else {
399                    polars_bail!(ComputeError: "can not cast to enum with global mapping")
400                };
401                Ok(self
402                    .to_enum(categories, *hash)
403                    .set_ordering(*ordering, true)
404                    .into_series()
405                    .with_name(self.name().clone()))
406            },
407            DataType::Enum(None, _) => {
408                polars_bail!(ComputeError: "can not cast to enum without categories present")
409            },
410            #[cfg(feature = "dtype-categorical")]
411            DataType::Categorical(rev_map, ordering) => {
412                // Casting from an Enum to a local or global
413                if matches!(self.dtype(), DataType::Enum(_, _)) && rev_map.is_none() {
414                    if using_string_cache() {
415                        return Ok(self
416                            .to_global()?
417                            .set_ordering(*ordering, true)
418                            .into_series());
419                    } else {
420                        return Ok(self.to_local().set_ordering(*ordering, true).into_series());
421                    }
422                }
423                // If casting to lexical categorical, set sorted flag as not set
424
425                let mut ca = self.clone().set_ordering(*ordering, true);
426                if ca.uses_lexical_ordering() {
427                    ca.physical.set_sorted_flag(IsSorted::Not);
428                }
429                Ok(ca.into_series())
430            },
431            dt if dt.is_primitive_numeric() => {
432                // Apply the cast to the categories and then index into the casted series.
433                // This has to be local for the gather.
434                let slf = self.to_local();
435                let categories = StringChunked::with_chunk(
436                    slf.physical.name().clone(),
437                    slf.get_rev_map().get_categories().clone(),
438                );
439                let casted_series = categories.cast_with_options(dtype, options)?;
440
441                #[cfg(feature = "bigidx")]
442                {
443                    let s = slf.physical.cast_with_options(&DataType::UInt64, options)?;
444                    Ok(unsafe { casted_series.take_unchecked(s.u64()?) })
445                }
446                #[cfg(not(feature = "bigidx"))]
447                {
448                    // SAFETY: Invariant of categorical means indices are in bound
449                    Ok(unsafe { casted_series.take_unchecked(&slf.physical) })
450                }
451            },
452            _ => self.physical.cast_with_options(dtype, options),
453        }
454    }
455}
456
457pub struct CatIter<'a> {
458    rev: &'a RevMapping,
459    iter: Box<dyn PolarsIterator<Item = Option<u32>> + 'a>,
460}
461
462unsafe impl TrustedLen for CatIter<'_> {}
463
464impl<'a> Iterator for CatIter<'a> {
465    type Item = Option<&'a str>;
466
467    fn next(&mut self) -> Option<Self::Item> {
468        self.iter.next().map(|item| {
469            item.map(|idx| {
470                // SAFETY:
471                // all categories are in bound
472                unsafe { self.rev.get_unchecked(idx) }
473            })
474        })
475    }
476
477    fn size_hint(&self) -> (usize, Option<usize>) {
478        self.iter.size_hint()
479    }
480}
481
482impl DoubleEndedIterator for CatIter<'_> {
483    fn next_back(&mut self) -> Option<Self::Item> {
484        self.iter.next_back().map(|item| {
485            item.map(|idx| {
486                // SAFETY:
487                // all categories are in bound
488                unsafe { self.rev.get_unchecked(idx) }
489            })
490        })
491    }
492}
493
494impl ExactSizeIterator for CatIter<'_> {}
495
496#[cfg(test)]
497mod test {
498    use super::*;
499    use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache};
500
501    #[test]
502    fn test_categorical_round_trip() -> PolarsResult<()> {
503        let _lock = SINGLE_LOCK.lock();
504        disable_string_cache();
505        let slice = &[
506            Some("foo"),
507            None,
508            Some("bar"),
509            Some("foo"),
510            Some("foo"),
511            Some("bar"),
512        ];
513        let ca = StringChunked::new(PlSmallStr::from_static("a"), slice);
514        let ca = ca.cast(&DataType::Categorical(None, Default::default()))?;
515        let ca = ca.categorical().unwrap();
516
517        let arr = ca.to_arrow(CompatLevel::newest(), false);
518        let s = Series::try_from((PlSmallStr::from_static("foo"), arr))?;
519        assert!(matches!(s.dtype(), &DataType::Categorical(_, _)));
520        assert_eq!(s.null_count(), 1);
521        assert_eq!(s.len(), 6);
522
523        Ok(())
524    }
525
526    #[test]
527    fn test_append_categorical() {
528        let _lock = SINGLE_LOCK.lock();
529        disable_string_cache();
530        enable_string_cache();
531
532        let mut s1 = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"])
533            .cast(&DataType::Categorical(None, Default::default()))
534            .unwrap();
535        let s2 = Series::new(PlSmallStr::from_static("2"), vec!["a", "x", "y"])
536            .cast(&DataType::Categorical(None, Default::default()))
537            .unwrap();
538        let appended = s1.append(&s2).unwrap();
539        assert_eq!(appended.str_value(0).unwrap(), "a");
540        assert_eq!(appended.str_value(1).unwrap(), "b");
541        assert_eq!(appended.str_value(4).unwrap(), "x");
542        assert_eq!(appended.str_value(5).unwrap(), "y");
543    }
544
545    #[test]
546    fn test_fast_unique() {
547        let _lock = SINGLE_LOCK.lock();
548        let s = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"])
549            .cast(&DataType::Categorical(None, Default::default()))
550            .unwrap();
551
552        assert_eq!(s.n_unique().unwrap(), 3);
553        // Make sure that it does not take the fast path after take/slice.
554        let out = s.take(&IdxCa::new(PlSmallStr::EMPTY, [1, 2])).unwrap();
555        assert_eq!(out.n_unique().unwrap(), 2);
556        let out = s.slice(1, 2);
557        assert_eq!(out.n_unique().unwrap(), 2);
558    }
559
560    #[test]
561    fn test_categorical_flow() -> PolarsResult<()> {
562        let _lock = SINGLE_LOCK.lock();
563        disable_string_cache();
564
565        // tests several things that may lose the dtype information
566        let s = Series::new(PlSmallStr::from_static("a"), vec!["a", "b", "c"])
567            .cast(&DataType::Categorical(None, Default::default()))?;
568
569        assert_eq!(
570            s.field().into_owned(),
571            Field::new(
572                PlSmallStr::from_static("a"),
573                DataType::Categorical(None, Default::default())
574            )
575        );
576        assert!(matches!(
577            s.get(0)?,
578            AnyValue::Categorical(0, RevMapping::Local(_, _), _)
579        ));
580
581        let groups = s.group_tuples(false, true);
582        let aggregated = unsafe { s.agg_list(&groups?) };
583        match aggregated.get(0)? {
584            AnyValue::List(s) => {
585                assert!(matches!(s.dtype(), DataType::Categorical(_, _)));
586                let str_s = s.cast(&DataType::String).unwrap();
587                assert_eq!(str_s.get(0)?, AnyValue::String("a"));
588                assert_eq!(s.len(), 1);
589            },
590            _ => panic!(),
591        }
592        let flat = aggregated.explode(false)?;
593        let ca = flat.categorical().unwrap();
594        let vals = ca.iter_str().map(|v| v.unwrap()).collect::<Vec<_>>();
595        assert_eq!(vals, &["a", "b", "c"]);
596        Ok(())
597    }
598}