Skip to main content

polars_core/chunked_array/logical/
categorical.rs

1use std::hash::BuildHasher;
2use std::marker::PhantomData;
3
4use arrow::bitmap::BitmapBuilder;
5use num_traits::Zero;
6use polars_utils::hashing::{_boost_hash_combine, folded_multiply};
7
8use crate::chunked_array::cast::CastOptions;
9use crate::chunked_array::flags::StatisticsFlags;
10use crate::chunked_array::ops::ChunkFullNull;
11use crate::hashing::get_null_hash_value;
12use crate::prelude::*;
13use crate::series::IsSorted;
14use crate::utils::handle_casting_failures;
15
16pub type CategoricalChunked<T> = Logical<T, <T as PolarsCategoricalType>::PolarsPhysical>;
17pub type Categorical8Chunked = CategoricalChunked<Categorical8Type>;
18pub type Categorical16Chunked = CategoricalChunked<Categorical16Type>;
19pub type Categorical32Chunked = CategoricalChunked<Categorical32Type>;
20
21pub trait CategoricalPhysicalDtypeExt {
22    fn dtype(&self) -> DataType;
23}
24
25impl CategoricalPhysicalDtypeExt for CategoricalPhysical {
26    fn dtype(&self) -> DataType {
27        match self {
28            Self::U8 => DataType::UInt8,
29            Self::U16 => DataType::UInt16,
30            Self::U32 => DataType::UInt32,
31        }
32    }
33}
34
35impl<T: PolarsCategoricalType> CategoricalChunked<T> {
36    pub fn is_enum(&self) -> bool {
37        matches!(self.dtype(), DataType::Enum(_, _))
38    }
39
40    pub(crate) fn get_flags(&self) -> StatisticsFlags {
41        // If we use lexical ordering then physical sortedness does not imply
42        // our sortedness.
43        let mut flags = self.phys.get_flags();
44        if self.uses_lexical_ordering() {
45            flags.set_sorted(IsSorted::Not);
46        }
47        flags
48    }
49
50    /// Set flags for the ChunkedArray.
51    pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) {
52        // We should not set the sorted flag if we are sorting in lexical order.
53        if self.uses_lexical_ordering() {
54            flags.set_sorted(IsSorted::Not)
55        }
56        self.physical_mut().set_flags(flags)
57    }
58
59    /// Return whether or not the [`CategoricalChunked`] uses the lexical order
60    /// of the string values when sorting.
61    pub fn uses_lexical_ordering(&self) -> bool {
62        !self.is_enum()
63    }
64
65    pub fn full_null_with_dtype(name: PlSmallStr, length: usize, dtype: DataType) -> Self {
66        let phys =
67            ChunkedArray::<<T as PolarsCategoricalType>::PolarsPhysical>::full_null(name, length);
68        unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) }
69    }
70
71    /// Create a [`CategoricalChunked`] from a physical array and dtype.
72    ///
73    /// Checks that all the category ids are valid, mapping invalid ones to nulls.
74    pub fn from_cats_and_dtype(
75        mut cat_ids: ChunkedArray<T::PolarsPhysical>,
76        dtype: DataType,
77    ) -> Self {
78        let (DataType::Enum(_, mapping) | DataType::Categorical(_, mapping)) = &dtype else {
79            panic!("from_cats_and_dtype called on non-categorical type")
80        };
81        assert!(dtype.cat_physical().ok() == Some(T::physical()));
82
83        unsafe {
84            let mut invariants_violated = false;
85            let mut validity = BitmapBuilder::new();
86            for arr in cat_ids.downcast_iter_mut() {
87                validity.reserve(arr.len());
88                if arr.has_nulls() {
89                    for opt_cat_id in arr.iter() {
90                        if let Some(cat_id) = opt_cat_id {
91                            validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
92                        } else {
93                            validity.push_unchecked(false);
94                        }
95                    }
96                } else {
97                    for cat_id in arr.values_iter() {
98                        validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
99                    }
100                }
101
102                if arr.null_count() != validity.unset_bits() {
103                    invariants_violated = true;
104                    arr.set_validity(core::mem::take(&mut validity).into_opt_validity());
105                } else {
106                    validity.clear();
107                }
108            }
109
110            if invariants_violated {
111                cat_ids.set_flags(StatisticsFlags::empty());
112                cat_ids.compute_len();
113            }
114        }
115
116        Self {
117            phys: cat_ids,
118            dtype,
119            _phantom: PhantomData,
120        }
121    }
122
123    /// Create a [`CategoricalChunked`] from a physical array and dtype.
124    ///
125    /// # Safety
126    /// It's not checked that the indices are in-bounds or that the dtype is correct.
127    pub unsafe fn from_cats_and_dtype_unchecked(
128        cat_ids: ChunkedArray<T::PolarsPhysical>,
129        dtype: DataType,
130    ) -> Self {
131        debug_assert!(dtype.cat_physical().ok() == Some(T::physical()));
132
133        Self {
134            phys: cat_ids,
135            dtype,
136            _phantom: PhantomData,
137        }
138    }
139
140    /// Get a reference to the mapping of categorical types to the string values.
141    pub fn get_mapping(&self) -> &Arc<CategoricalMapping> {
142        let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = self.dtype() else {
143            unreachable!()
144        };
145        mapping
146    }
147
148    /// Create an [`Iterator`] that iterates over the `&str` values of the [`CategoricalChunked`].
149    pub fn iter_str(&self) -> impl PolarsIterator<Item = Option<&str>> {
150        let mapping = self.get_mapping();
151        self.phys
152            .iter()
153            .map(|cat| unsafe { Some(mapping.cat_to_str_unchecked(cat?.as_cat())) })
154    }
155
156    /// Converts from strings to this CategoricalChunked.
157    ///
158    /// If this dtype is an Enum any non-existing strings get mapped to null.
159    pub fn from_str_iter<'a, I: IntoIterator<Item = Option<&'a str>>>(
160        name: PlSmallStr,
161        dtype: DataType,
162        strings: I,
163    ) -> PolarsResult<Self> {
164        let strings = strings.into_iter();
165
166        let hint = strings.size_hint().0;
167        let mut cat_ids = Vec::with_capacity(hint);
168        let mut validity = BitmapBuilder::with_capacity(hint);
169
170        match &dtype {
171            DataType::Categorical(cats, mapping) => {
172                assert!(cats.physical() == T::physical());
173                for opt_s in strings {
174                    cat_ids.push(if let Some(s) = opt_s {
175                        T::Native::from_cat(mapping.insert_cat(s)?)
176                    } else {
177                        T::Native::zero()
178                    });
179                    validity.push(opt_s.is_some());
180                }
181            },
182            DataType::Enum(fcats, mapping) => {
183                assert!(fcats.physical() == T::physical());
184                for opt_s in strings {
185                    cat_ids.push(if let Some(cat) = opt_s.and_then(|s| mapping.get_cat(s)) {
186                        validity.push(true);
187                        T::Native::from_cat(cat)
188                    } else {
189                        validity.push(false);
190                        T::Native::zero()
191                    });
192                }
193            },
194            _ => panic!("from_strings_and_dtype_strict called on non-categorical type"),
195        }
196
197        let arr = <T::PolarsPhysical as PolarsDataType>::Array::from_vec(cat_ids)
198            .with_validity(validity.into_opt_validity());
199        let phys = ChunkedArray::<T::PolarsPhysical>::with_chunk(name, arr);
200        Ok(unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) })
201    }
202
203    pub fn to_arrow(&self, compat_level: CompatLevel) -> DictionaryArray<T::Native> {
204        let keys = self.physical().rechunk();
205        let keys = keys.downcast_as_array();
206        let values = self
207            .get_mapping()
208            .to_arrow(compat_level.uses_binview_types());
209        let values_dtype = Box::new(values.dtype().clone());
210        let dtype = ArrowDataType::Dictionary(
211            <T::Native as DictionaryKey>::KEY_TYPE,
212            values_dtype,
213            self.is_enum(),
214        );
215        unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() }
216    }
217}
218
219impl<T: PolarsCategoricalType> LogicalType for CategoricalChunked<T> {
220    fn dtype(&self) -> &DataType {
221        &self.dtype
222    }
223
224    fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
225        polars_ensure!(i < self.len(), oob = i, self.len());
226        Ok(unsafe { self.get_any_value_unchecked(i) })
227    }
228
229    unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
230        match self.phys.get_unchecked(i) {
231            Some(i) => match &self.dtype {
232                DataType::Enum(_, mapping) => AnyValue::Enum(i.as_cat(), mapping),
233                DataType::Categorical(_, mapping) => AnyValue::Categorical(i.as_cat(), mapping),
234                _ => unreachable!(),
235            },
236            None => AnyValue::Null,
237        }
238    }
239
240    fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
241        if &self.dtype == dtype {
242            return Ok(self.clone().into_series());
243        }
244
245        match dtype {
246            DataType::String => {
247                let mapping = self.get_mapping();
248
249                // TODO @ cat-rework:, if len >= mapping.upper_bound(), cast categories to ViewArray, then construct array of Views.
250
251                let mut builder = StringChunkedBuilder::new(self.phys.name().clone(), self.len());
252                let to_str = |cat_id: CatSize| unsafe { mapping.cat_to_str_unchecked(cat_id) };
253                if !self.phys.has_nulls() {
254                    for cat_id in self.phys.into_no_null_iter() {
255                        builder.append_value(to_str(cat_id.as_cat()));
256                    }
257                } else {
258                    for opt_cat_id in self.phys.into_iter() {
259                        let opt_cat_id: Option<_> = opt_cat_id;
260                        builder.append_option(opt_cat_id.map(|c| to_str(c.as_cat())));
261                    }
262                }
263
264                let ca = builder.finish();
265                Ok(ca.into_series())
266            },
267
268            DataType::Enum(fcats, _mapping) => {
269                // TODO @ cat-rework: if len >= self.mapping().upper_bound(), remap categories then index into array.
270                let ret = with_match_categorical_physical_type!(fcats.physical(), |$C| {
271                    CategoricalChunked::<$C>::from_str_iter(
272                        self.name().clone(),
273                        dtype.clone(),
274                        self.iter_str()
275                    )?.into_series()
276                });
277
278                if options.is_strict() && self.null_count() != ret.null_count() {
279                    handle_casting_failures(&self.clone().into_series(), &ret)?;
280                }
281
282                Ok(ret)
283            },
284
285            DataType::Categorical(cats, _mapping) => {
286                // TODO @ cat-rework: if len >= self.mapping().upper_bound(), remap categories then index into array.
287                Ok(
288                    with_match_categorical_physical_type!(cats.physical(), |$C| {
289                        CategoricalChunked::<$C>::from_str_iter(
290                            self.name().clone(),
291                            dtype.clone(),
292                            self.iter_str()
293                        )?.into_series()
294                    }),
295                )
296            },
297
298            // LEGACY
299            // TODO @ cat-rework: remove after exposing to/from physical functions.
300            dt if dt.is_integer() => self.phys.clone().cast_with_options(dtype, options),
301
302            _ => polars_bail!(ComputeError: "cannot cast categorical types to {dtype:?}"),
303        }
304    }
305}
306
307impl<T: PolarsCategoricalType> VecHash for CategoricalChunked<T>
308where
309    ChunkedArray<<T as PolarsCategoricalType>::PolarsPhysical>: VecHash,
310{
311    fn vec_hash(
312        &self,
313        random_state: PlSeedableRandomStateQuality,
314        buf: &mut Vec<u64>,
315    ) -> PolarsResult<()> {
316        if self.is_enum() {
317            self.phys.vec_hash(random_state, buf)
318        } else {
319            buf.clear();
320            buf.reserve(self.phys.len());
321            let mult = random_state.hash_one(0);
322            let null = get_null_hash_value(&random_state);
323
324            let mapping = self.get_mapping();
325            for opt_cat in self.phys.iter() {
326                if let Some(cat) = opt_cat {
327                    let base_h = unsafe { mapping.cat_to_hash_unchecked(cat.as_cat()) };
328                    buf.push(folded_multiply(base_h, mult));
329                } else {
330                    buf.push(null);
331                }
332            }
333            Ok(())
334        }
335    }
336
337    fn vec_hash_combine(
338        &self,
339        random_state: PlSeedableRandomStateQuality,
340        hashes: &mut [u64],
341    ) -> PolarsResult<()> {
342        if self.is_enum() {
343            self.phys.vec_hash_combine(random_state, hashes)
344        } else {
345            let mult = random_state.hash_one(0);
346            let null = get_null_hash_value(&random_state);
347
348            let mapping = self.get_mapping();
349            assert!(self.phys.len() == hashes.len());
350            for (opt_cat, h) in self.phys.iter().zip(hashes.iter_mut()) {
351                let our_h = if let Some(cat) = opt_cat {
352                    let base_h = unsafe { mapping.cat_to_hash_unchecked(cat.as_cat()) };
353                    folded_multiply(base_h, mult)
354                } else {
355                    null
356                };
357                *h = _boost_hash_combine(our_h, *h);
358            }
359            Ok(())
360        }
361    }
362}