polars_core/chunked_array/logical/categorical.rs

1use std::marker::PhantomData;
2
3use arrow::bitmap::BitmapBuilder;
4use num_traits::Zero;
5
6use crate::chunked_array::cast::CastOptions;
7use crate::chunked_array::flags::StatisticsFlags;
8use crate::chunked_array::ops::ChunkFullNull;
9use crate::prelude::*;
10use crate::series::IsSorted;
11use crate::utils::handle_casting_failures;
12
13pub type CategoricalChunked<T> = Logical<T, <T as PolarsCategoricalType>::PolarsPhysical>;
14pub type Categorical8Chunked = CategoricalChunked<Categorical8Type>;
15pub type Categorical16Chunked = CategoricalChunked<Categorical16Type>;
16pub type Categorical32Chunked = CategoricalChunked<Categorical32Type>;
17
18pub trait CategoricalPhysicalDtypeExt {
19    fn dtype(&self) -> DataType;
20}
21
22impl CategoricalPhysicalDtypeExt for CategoricalPhysical {
23    fn dtype(&self) -> DataType {
24        match self {
25            Self::U8 => DataType::UInt8,
26            Self::U16 => DataType::UInt16,
27            Self::U32 => DataType::UInt32,
28        }
29    }
30}
31
32impl<T: PolarsCategoricalType> CategoricalChunked<T> {
33    pub fn is_enum(&self) -> bool {
34        matches!(self.dtype(), DataType::Enum(_, _))
35    }
36
37    pub(crate) fn get_flags(&self) -> StatisticsFlags {
38        // If we use lexical ordering then physical sortedness does not imply
39        // our sortedness.
40        let mut flags = self.phys.get_flags();
41        if self.uses_lexical_ordering() {
42            flags.set_sorted(IsSorted::Not);
43        }
44        flags
45    }
46
47    /// Set flags for the ChunkedArray.
48    pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) {
49        // We should not set the sorted flag if we are sorting in lexical order.
50        if self.uses_lexical_ordering() {
51            flags.set_sorted(IsSorted::Not)
52        }
53        self.physical_mut().set_flags(flags)
54    }
55
56    /// Return whether or not the [`CategoricalChunked`] uses the lexical order
57    /// of the string values when sorting.
58    pub fn uses_lexical_ordering(&self) -> bool {
59        !self.is_enum()
60    }
61
62    pub fn full_null_with_dtype(name: PlSmallStr, length: usize, dtype: DataType) -> Self {
63        let phys =
64            ChunkedArray::<<T as PolarsCategoricalType>::PolarsPhysical>::full_null(name, length);
65        unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) }
66    }
67
68    /// Create a [`CategoricalChunked`] from a physical array and dtype.
69    ///
70    /// Checks that all the category ids are valid, mapping invalid ones to nulls.
71    pub fn from_cats_and_dtype(
72        mut cat_ids: ChunkedArray<T::PolarsPhysical>,
73        dtype: DataType,
74    ) -> Self {
75        let (DataType::Enum(_, mapping) | DataType::Categorical(_, mapping)) = &dtype else {
76            panic!("from_cats_and_dtype called on non-categorical type")
77        };
78        assert!(dtype.cat_physical().ok() == Some(T::physical()));
79
80        unsafe {
81            let mut validity = BitmapBuilder::new();
82            for arr in cat_ids.downcast_iter_mut() {
83                validity.reserve(arr.len());
84                if arr.has_nulls() {
85                    for opt_cat_id in arr.iter() {
86                        if let Some(cat_id) = opt_cat_id {
87                            validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
88                        } else {
89                            validity.push_unchecked(false);
90                        }
91                    }
92                } else {
93                    for cat_id in arr.values_iter() {
94                        validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
95                    }
96                }
97
98                if arr.null_count() != validity.unset_bits() {
99                    arr.set_validity(core::mem::take(&mut validity).into_opt_validity());
100                } else {
101                    validity.clear();
102                }
103            }
104        }
105
106        Self {
107            phys: cat_ids,
108            dtype,
109            _phantom: PhantomData,
110        }
111    }
112
113    /// Create a [`CategoricalChunked`] from a physical array and dtype.
114    ///
115    /// # Safety
116    /// It's not checked that the indices are in-bounds or that the dtype is correct.
117    pub unsafe fn from_cats_and_dtype_unchecked(
118        cat_ids: ChunkedArray<T::PolarsPhysical>,
119        dtype: DataType,
120    ) -> Self {
121        debug_assert!(dtype.cat_physical().ok() == Some(T::physical()));
122
123        Self {
124            phys: cat_ids,
125            dtype,
126            _phantom: PhantomData,
127        }
128    }
129
130    /// Get a reference to the mapping of categorical types to the string values.
131    pub fn get_mapping(&self) -> &Arc<CategoricalMapping> {
132        let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = self.dtype() else {
133            unreachable!()
134        };
135        mapping
136    }
137
138    /// Create an [`Iterator`] that iterates over the `&str` values of the [`CategoricalChunked`].
139    pub fn iter_str(&self) -> impl PolarsIterator<Item = Option<&str>> {
140        let mapping = self.get_mapping();
141        self.phys
142            .iter()
143            .map(|cat| unsafe { Some(mapping.cat_to_str_unchecked(cat?.as_cat())) })
144    }
145
146    /// Converts from strings to this CategoricalChunked.
147    ///
148    /// If this dtype is an Enum any non-existing strings get mapped to null.
149    pub fn from_str_iter<'a, I: IntoIterator<Item = Option<&'a str>>>(
150        name: PlSmallStr,
151        dtype: DataType,
152        strings: I,
153    ) -> PolarsResult<Self> {
154        let strings = strings.into_iter();
155
156        let hint = strings.size_hint().0;
157        let mut cat_ids = Vec::with_capacity(hint);
158        let mut validity = BitmapBuilder::with_capacity(hint);
159
160        match &dtype {
161            DataType::Categorical(cats, mapping) => {
162                assert!(cats.physical() == T::physical());
163                for opt_s in strings {
164                    cat_ids.push(if let Some(s) = opt_s {
165                        T::Native::from_cat(mapping.insert_cat(s)?)
166                    } else {
167                        T::Native::zero()
168                    });
169                    validity.push(opt_s.is_some());
170                }
171            },
172            DataType::Enum(fcats, mapping) => {
173                assert!(fcats.physical() == T::physical());
174                for opt_s in strings {
175                    cat_ids.push(if let Some(cat) = opt_s.and_then(|s| mapping.get_cat(s)) {
176                        validity.push(true);
177                        T::Native::from_cat(cat)
178                    } else {
179                        validity.push(false);
180                        T::Native::zero()
181                    });
182                }
183            },
184            _ => panic!("from_strings_and_dtype_strict called on non-categorical type"),
185        }
186
187        let arr = <T::PolarsPhysical as PolarsDataType>::Array::from_vec(cat_ids)
188            .with_validity(validity.into_opt_validity());
189        let phys = ChunkedArray::<T::PolarsPhysical>::with_chunk(name, arr);
190        Ok(unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) })
191    }
192
193    pub fn to_arrow(&self, compat_level: CompatLevel) -> DictionaryArray<T::Native> {
194        let keys = self.physical().rechunk();
195        let keys = keys.downcast_as_array();
196        let values = self
197            .get_mapping()
198            .to_arrow(compat_level != CompatLevel::oldest());
199        let values_dtype = Box::new(values.dtype().clone());
200        let dtype =
201            ArrowDataType::Dictionary(<T::Native as DictionaryKey>::KEY_TYPE, values_dtype, false);
202        unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() }
203    }
204}
205
206impl<T: PolarsCategoricalType> LogicalType for CategoricalChunked<T> {
207    fn dtype(&self) -> &DataType {
208        &self.dtype
209    }
210
211    fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
212        polars_ensure!(i < self.len(), oob = i, self.len());
213        Ok(unsafe { self.get_any_value_unchecked(i) })
214    }
215
216    unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
217        match self.phys.get_unchecked(i) {
218            Some(i) => match &self.dtype {
219                DataType::Enum(_, mapping) => AnyValue::Enum(i.as_cat(), mapping),
220                DataType::Categorical(_, mapping) => AnyValue::Categorical(i.as_cat(), mapping),
221                _ => unreachable!(),
222            },
223            None => AnyValue::Null,
224        }
225    }
226
227    fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
228        if &self.dtype == dtype {
229            return Ok(self.clone().into_series());
230        }
231
232        match dtype {
233            DataType::String => {
234                let mapping = self.get_mapping();
235
236                // TODO @ cat-rework:, if len >= mapping.upper_bound(), cast categories to ViewArray, then construct array of Views.
237
238                let mut builder = StringChunkedBuilder::new(self.phys.name().clone(), self.len());
239                let to_str = |cat_id: CatSize| unsafe { mapping.cat_to_str_unchecked(cat_id) };
240                if !self.phys.has_nulls() {
241                    for cat_id in self.phys.into_no_null_iter() {
242                        builder.append_value(to_str(cat_id.as_cat()));
243                    }
244                } else {
245                    for opt_cat_id in self.phys.into_iter() {
246                        let opt_cat_id: Option<_> = opt_cat_id;
247                        builder.append_option(opt_cat_id.map(|c| to_str(c.as_cat())));
248                    }
249                }
250
251                let ca = builder.finish();
252                Ok(ca.into_series())
253            },
254
255            DataType::Enum(fcats, _mapping) => {
256                // TODO @ cat-rework: if len >= self.mapping().upper_bound(), remap categories then index into array.
257                let ret = with_match_categorical_physical_type!(fcats.physical(), |$C| {
258                    CategoricalChunked::<$C>::from_str_iter(
259                        self.name().clone(),
260                        dtype.clone(),
261                        self.iter_str()
262                    )?.into_series()
263                });
264
265                if options.is_strict() && self.null_count() != ret.null_count() {
266                    handle_casting_failures(&self.clone().into_series(), &ret)?;
267                }
268
269                Ok(ret)
270            },
271
272            DataType::Categorical(cats, _mapping) => {
273                // TODO @ cat-rework: if len >= self.mapping().upper_bound(), remap categories then index into array.
274                Ok(
275                    with_match_categorical_physical_type!(cats.physical(), |$C| {
276                        CategoricalChunked::<$C>::from_str_iter(
277                            self.name().clone(),
278                            dtype.clone(),
279                            self.iter_str()
280                        )?.into_series()
281                    }),
282                )
283            },
284
285            // LEGACY
286            // TODO @ cat-rework: remove after exposing to/from physical functions.
287            dt if dt.is_integer() => self.phys.clone().cast_with_options(dtype, options),
288
289            _ => polars_bail!(ComputeError: "cannot cast categorical types to {dtype:?}"),
290        }
291    }
292}