// polars_core/chunked_array/logical/categorical.rs

1use std::marker::PhantomData;
2
3use arrow::bitmap::BitmapBuilder;
4use num_traits::Zero;
5
6use crate::chunked_array::cast::CastOptions;
7use crate::chunked_array::flags::StatisticsFlags;
8use crate::chunked_array::ops::ChunkFullNull;
9use crate::prelude::*;
10use crate::series::IsSorted;
11use crate::utils::handle_casting_failures;
12
13pub type CategoricalChunked<T> = Logical<T, <T as PolarsCategoricalType>::PolarsPhysical>;
14pub type Categorical8Chunked = CategoricalChunked<Categorical8Type>;
15pub type Categorical16Chunked = CategoricalChunked<Categorical16Type>;
16pub type Categorical32Chunked = CategoricalChunked<Categorical32Type>;
17
18pub trait CategoricalPhysicalDtypeExt {
19    fn dtype(&self) -> DataType;
20}
21
22impl CategoricalPhysicalDtypeExt for CategoricalPhysical {
23    fn dtype(&self) -> DataType {
24        match self {
25            Self::U8 => DataType::UInt8,
26            Self::U16 => DataType::UInt16,
27            Self::U32 => DataType::UInt32,
28        }
29    }
30}
31
32impl<T: PolarsCategoricalType> CategoricalChunked<T> {
33    pub fn is_enum(&self) -> bool {
34        matches!(self.dtype(), DataType::Enum(_, _))
35    }
36
37    pub(crate) fn get_flags(&self) -> StatisticsFlags {
38        // If we use lexical ordering then physical sortedness does not imply
39        // our sortedness.
40        let mut flags = self.phys.get_flags();
41        if self.uses_lexical_ordering() {
42            flags.set_sorted(IsSorted::Not);
43        }
44        flags
45    }
46
47    /// Set flags for the ChunkedArray.
48    pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) {
49        // We should not set the sorted flag if we are sorting in lexical order.
50        if self.uses_lexical_ordering() {
51            flags.set_sorted(IsSorted::Not)
52        }
53        self.physical_mut().set_flags(flags)
54    }
55
56    /// Return whether or not the [`CategoricalChunked`] uses the lexical order
57    /// of the string values when sorting.
58    pub fn uses_lexical_ordering(&self) -> bool {
59        !self.is_enum()
60    }
61
62    pub fn full_null_with_dtype(name: PlSmallStr, length: usize, dtype: DataType) -> Self {
63        let phys =
64            ChunkedArray::<<T as PolarsCategoricalType>::PolarsPhysical>::full_null(name, length);
65        unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) }
66    }
67
68    /// Create a [`CategoricalChunked`] from a physical array and dtype.
69    ///
70    /// Checks that all the category ids are valid, mapping invalid ones to nulls.
71    pub fn from_cats_and_dtype(
72        mut cat_ids: ChunkedArray<T::PolarsPhysical>,
73        dtype: DataType,
74    ) -> Self {
75        let (DataType::Enum(_, mapping) | DataType::Categorical(_, mapping)) = &dtype else {
76            panic!("from_cats_and_dtype called on non-categorical type")
77        };
78        assert!(dtype.cat_physical().ok() == Some(T::physical()));
79
80        unsafe {
81            let mut invariants_violated = false;
82            let mut validity = BitmapBuilder::new();
83            for arr in cat_ids.downcast_iter_mut() {
84                validity.reserve(arr.len());
85                if arr.has_nulls() {
86                    for opt_cat_id in arr.iter() {
87                        if let Some(cat_id) = opt_cat_id {
88                            validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
89                        } else {
90                            validity.push_unchecked(false);
91                        }
92                    }
93                } else {
94                    for cat_id in arr.values_iter() {
95                        validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
96                    }
97                }
98
99                if arr.null_count() != validity.unset_bits() {
100                    invariants_violated = true;
101                    arr.set_validity(core::mem::take(&mut validity).into_opt_validity());
102                } else {
103                    validity.clear();
104                }
105            }
106
107            if invariants_violated {
108                cat_ids.set_flags(StatisticsFlags::empty());
109                cat_ids.compute_len();
110            }
111        }
112
113        Self {
114            phys: cat_ids,
115            dtype,
116            _phantom: PhantomData,
117        }
118    }
119
120    /// Create a [`CategoricalChunked`] from a physical array and dtype.
121    ///
122    /// # Safety
123    /// It's not checked that the indices are in-bounds or that the dtype is correct.
124    pub unsafe fn from_cats_and_dtype_unchecked(
125        cat_ids: ChunkedArray<T::PolarsPhysical>,
126        dtype: DataType,
127    ) -> Self {
128        debug_assert!(dtype.cat_physical().ok() == Some(T::physical()));
129
130        Self {
131            phys: cat_ids,
132            dtype,
133            _phantom: PhantomData,
134        }
135    }
136
137    /// Get a reference to the mapping of categorical types to the string values.
138    pub fn get_mapping(&self) -> &Arc<CategoricalMapping> {
139        let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = self.dtype() else {
140            unreachable!()
141        };
142        mapping
143    }
144
145    /// Create an [`Iterator`] that iterates over the `&str` values of the [`CategoricalChunked`].
146    pub fn iter_str(&self) -> impl PolarsIterator<Item = Option<&str>> {
147        let mapping = self.get_mapping();
148        self.phys
149            .iter()
150            .map(|cat| unsafe { Some(mapping.cat_to_str_unchecked(cat?.as_cat())) })
151    }
152
153    /// Converts from strings to this CategoricalChunked.
154    ///
155    /// If this dtype is an Enum any non-existing strings get mapped to null.
156    pub fn from_str_iter<'a, I: IntoIterator<Item = Option<&'a str>>>(
157        name: PlSmallStr,
158        dtype: DataType,
159        strings: I,
160    ) -> PolarsResult<Self> {
161        let strings = strings.into_iter();
162
163        let hint = strings.size_hint().0;
164        let mut cat_ids = Vec::with_capacity(hint);
165        let mut validity = BitmapBuilder::with_capacity(hint);
166
167        match &dtype {
168            DataType::Categorical(cats, mapping) => {
169                assert!(cats.physical() == T::physical());
170                for opt_s in strings {
171                    cat_ids.push(if let Some(s) = opt_s {
172                        T::Native::from_cat(mapping.insert_cat(s)?)
173                    } else {
174                        T::Native::zero()
175                    });
176                    validity.push(opt_s.is_some());
177                }
178            },
179            DataType::Enum(fcats, mapping) => {
180                assert!(fcats.physical() == T::physical());
181                for opt_s in strings {
182                    cat_ids.push(if let Some(cat) = opt_s.and_then(|s| mapping.get_cat(s)) {
183                        validity.push(true);
184                        T::Native::from_cat(cat)
185                    } else {
186                        validity.push(false);
187                        T::Native::zero()
188                    });
189                }
190            },
191            _ => panic!("from_strings_and_dtype_strict called on non-categorical type"),
192        }
193
194        let arr = <T::PolarsPhysical as PolarsDataType>::Array::from_vec(cat_ids)
195            .with_validity(validity.into_opt_validity());
196        let phys = ChunkedArray::<T::PolarsPhysical>::with_chunk(name, arr);
197        Ok(unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) })
198    }
199
200    pub fn to_arrow(&self, compat_level: CompatLevel) -> DictionaryArray<T::Native> {
201        let keys = self.physical().rechunk();
202        let keys = keys.downcast_as_array();
203        let values = self
204            .get_mapping()
205            .to_arrow(compat_level != CompatLevel::oldest());
206        let values_dtype = Box::new(values.dtype().clone());
207        let dtype =
208            ArrowDataType::Dictionary(<T::Native as DictionaryKey>::KEY_TYPE, values_dtype, false);
209        unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() }
210    }
211}
212
213impl<T: PolarsCategoricalType> LogicalType for CategoricalChunked<T> {
214    fn dtype(&self) -> &DataType {
215        &self.dtype
216    }
217
218    fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
219        polars_ensure!(i < self.len(), oob = i, self.len());
220        Ok(unsafe { self.get_any_value_unchecked(i) })
221    }
222
223    unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
224        match self.phys.get_unchecked(i) {
225            Some(i) => match &self.dtype {
226                DataType::Enum(_, mapping) => AnyValue::Enum(i.as_cat(), mapping),
227                DataType::Categorical(_, mapping) => AnyValue::Categorical(i.as_cat(), mapping),
228                _ => unreachable!(),
229            },
230            None => AnyValue::Null,
231        }
232    }
233
234    fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
235        if &self.dtype == dtype {
236            return Ok(self.clone().into_series());
237        }
238
239        match dtype {
240            DataType::String => {
241                let mapping = self.get_mapping();
242
243                // TODO @ cat-rework:, if len >= mapping.upper_bound(), cast categories to ViewArray, then construct array of Views.
244
245                let mut builder = StringChunkedBuilder::new(self.phys.name().clone(), self.len());
246                let to_str = |cat_id: CatSize| unsafe { mapping.cat_to_str_unchecked(cat_id) };
247                if !self.phys.has_nulls() {
248                    for cat_id in self.phys.into_no_null_iter() {
249                        builder.append_value(to_str(cat_id.as_cat()));
250                    }
251                } else {
252                    for opt_cat_id in self.phys.into_iter() {
253                        let opt_cat_id: Option<_> = opt_cat_id;
254                        builder.append_option(opt_cat_id.map(|c| to_str(c.as_cat())));
255                    }
256                }
257
258                let ca = builder.finish();
259                Ok(ca.into_series())
260            },
261
262            DataType::Enum(fcats, _mapping) => {
263                // TODO @ cat-rework: if len >= self.mapping().upper_bound(), remap categories then index into array.
264                let ret = with_match_categorical_physical_type!(fcats.physical(), |$C| {
265                    CategoricalChunked::<$C>::from_str_iter(
266                        self.name().clone(),
267                        dtype.clone(),
268                        self.iter_str()
269                    )?.into_series()
270                });
271
272                if options.is_strict() && self.null_count() != ret.null_count() {
273                    handle_casting_failures(&self.clone().into_series(), &ret)?;
274                }
275
276                Ok(ret)
277            },
278
279            DataType::Categorical(cats, _mapping) => {
280                // TODO @ cat-rework: if len >= self.mapping().upper_bound(), remap categories then index into array.
281                Ok(
282                    with_match_categorical_physical_type!(cats.physical(), |$C| {
283                        CategoricalChunked::<$C>::from_str_iter(
284                            self.name().clone(),
285                            dtype.clone(),
286                            self.iter_str()
287                        )?.into_series()
288                    }),
289                )
290            },
291
292            // LEGACY
293            // TODO @ cat-rework: remove after exposing to/from physical functions.
294            dt if dt.is_integer() => self.phys.clone().cast_with_options(dtype, options),
295
296            _ => polars_bail!(ComputeError: "cannot cast categorical types to {dtype:?}"),
297        }
298    }
299}