Skip to main content

polars_core/chunked_array/logical/
categorical.rs

1use std::hash::BuildHasher;
2use std::marker::PhantomData;
3
4use arrow::bitmap::BitmapBuilder;
5use num_traits::Zero;
6use polars_utils::hashing::{_boost_hash_combine, folded_multiply};
7
8use crate::chunked_array::cast::CastOptions;
9use crate::chunked_array::flags::StatisticsFlags;
10use crate::chunked_array::iterator::PolarsIterator;
11use crate::chunked_array::ops::ChunkFullNull;
12use crate::hashing::get_null_hash_value;
13use crate::prelude::*;
14use crate::series::IsSorted;
15use crate::utils::handle_casting_failures;
16
17pub type CategoricalChunked<T> = Logical<T, <T as PolarsCategoricalType>::PolarsPhysical>;
18pub type Categorical8Chunked = CategoricalChunked<Categorical8Type>;
19pub type Categorical16Chunked = CategoricalChunked<Categorical16Type>;
20pub type Categorical32Chunked = CategoricalChunked<Categorical32Type>;
21
22pub trait CategoricalPhysicalDtypeExt {
23    fn dtype(&self) -> DataType;
24}
25
26impl CategoricalPhysicalDtypeExt for CategoricalPhysical {
27    fn dtype(&self) -> DataType {
28        match self {
29            Self::U8 => DataType::UInt8,
30            Self::U16 => DataType::UInt16,
31            Self::U32 => DataType::UInt32,
32        }
33    }
34}
35
36impl<T: PolarsCategoricalType> CategoricalChunked<T> {
37    pub fn is_enum(&self) -> bool {
38        matches!(self.dtype(), DataType::Enum(_, _))
39    }
40
41    pub(crate) fn get_flags(&self) -> StatisticsFlags {
42        // If we use lexical ordering then physical sortedness does not imply
43        // our sortedness.
44        let mut flags = self.phys.get_flags();
45        if self.uses_lexical_ordering() {
46            flags.set_sorted(IsSorted::Not);
47        }
48        flags
49    }
50
51    /// Set flags for the ChunkedArray.
52    pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) {
53        // We should not set the sorted flag if we are sorting in lexical order.
54        if self.uses_lexical_ordering() {
55            flags.set_sorted(IsSorted::Not)
56        }
57        self.physical_mut().set_flags(flags)
58    }
59
60    /// Return whether or not the [`CategoricalChunked`] uses the lexical order
61    /// of the string values when sorting.
62    pub fn uses_lexical_ordering(&self) -> bool {
63        !self.is_enum()
64    }
65
66    pub fn full_null_with_dtype(name: PlSmallStr, length: usize, dtype: DataType) -> Self {
67        let phys =
68            ChunkedArray::<<T as PolarsCategoricalType>::PolarsPhysical>::full_null(name, length);
69        unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) }
70    }
71
72    /// Create a [`CategoricalChunked`] from a physical array and dtype.
73    ///
74    /// Checks that all the category ids are valid, mapping invalid ones to nulls.
75    pub fn from_cats_and_dtype(
76        mut cat_ids: ChunkedArray<T::PolarsPhysical>,
77        dtype: DataType,
78    ) -> Self {
79        let (DataType::Enum(_, mapping) | DataType::Categorical(_, mapping)) = &dtype else {
80            panic!("from_cats_and_dtype called on non-categorical type")
81        };
82        assert!(dtype.cat_physical().ok() == Some(T::physical()));
83
84        unsafe {
85            let mut invariants_violated = false;
86            let mut validity = BitmapBuilder::new();
87            for arr in cat_ids.downcast_iter_mut() {
88                validity.reserve(arr.len());
89                if arr.has_nulls() {
90                    for opt_cat_id in arr.iter() {
91                        if let Some(cat_id) = opt_cat_id {
92                            validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
93                        } else {
94                            validity.push_unchecked(false);
95                        }
96                    }
97                } else {
98                    for cat_id in arr.values_iter() {
99                        validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
100                    }
101                }
102
103                if arr.null_count() != validity.unset_bits() {
104                    invariants_violated = true;
105                    arr.set_validity(core::mem::take(&mut validity).into_opt_validity());
106                } else {
107                    validity.clear();
108                }
109            }
110
111            if invariants_violated {
112                cat_ids.set_flags(StatisticsFlags::empty());
113                cat_ids.compute_len();
114            }
115        }
116
117        Self {
118            phys: cat_ids,
119            dtype,
120            _phantom: PhantomData,
121        }
122    }
123
124    /// Create a [`CategoricalChunked`] from a physical array and dtype.
125    ///
126    /// # Safety
127    /// It's not checked that the indices are in-bounds or that the dtype is correct.
128    pub unsafe fn from_cats_and_dtype_unchecked(
129        cat_ids: ChunkedArray<T::PolarsPhysical>,
130        dtype: DataType,
131    ) -> Self {
132        debug_assert!(dtype.cat_physical().ok() == Some(T::physical()));
133
134        Self {
135            phys: cat_ids,
136            dtype,
137            _phantom: PhantomData,
138        }
139    }
140
141    /// Get a reference to the mapping of categorical types to the string values.
142    pub fn get_mapping(&self) -> &Arc<CategoricalMapping> {
143        let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = self.dtype() else {
144            unreachable!()
145        };
146        mapping
147    }
148
149    /// Create an [`Iterator`] that iterates over the `&str` values of the [`CategoricalChunked`].
150    pub fn iter_str(&self) -> impl PolarsIterator<Item = Option<&str>> {
151        let mapping = self.get_mapping();
152        self.phys
153            .iter()
154            .map(|cat| unsafe { Some(mapping.cat_to_str_unchecked(cat?.as_cat())) })
155    }
156
157    /// Converts from strings to this CategoricalChunked.
158    ///
159    /// If this dtype is an Enum any non-existing strings get mapped to null.
160    pub fn from_str_iter<'a, I: IntoIterator<Item = Option<&'a str>>>(
161        name: PlSmallStr,
162        dtype: DataType,
163        strings: I,
164    ) -> PolarsResult<Self> {
165        let strings = strings.into_iter();
166
167        let hint = strings.size_hint().0;
168        let mut cat_ids = Vec::with_capacity(hint);
169        let mut validity = BitmapBuilder::with_capacity(hint);
170
171        match &dtype {
172            DataType::Categorical(cats, mapping) => {
173                assert!(cats.physical() == T::physical());
174                for opt_s in strings {
175                    cat_ids.push(if let Some(s) = opt_s {
176                        T::Native::from_cat(mapping.insert_cat(s)?)
177                    } else {
178                        T::Native::zero()
179                    });
180                    validity.push(opt_s.is_some());
181                }
182            },
183            DataType::Enum(fcats, mapping) => {
184                assert!(fcats.physical() == T::physical());
185                for opt_s in strings {
186                    cat_ids.push(if let Some(cat) = opt_s.and_then(|s| mapping.get_cat(s)) {
187                        validity.push(true);
188                        T::Native::from_cat(cat)
189                    } else {
190                        validity.push(false);
191                        T::Native::zero()
192                    });
193                }
194            },
195            _ => panic!("from_strings_and_dtype_strict called on non-categorical type"),
196        }
197
198        let arr = <T::PolarsPhysical as PolarsDataType>::Array::from_vec(cat_ids)
199            .with_validity(validity.into_opt_validity());
200        let phys = ChunkedArray::<T::PolarsPhysical>::with_chunk(name, arr);
201        Ok(unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) })
202    }
203
204    pub fn to_arrow(&self, compat_level: CompatLevel) -> DictionaryArray<T::Native> {
205        let keys = self.physical().rechunk();
206        let keys = keys.downcast_as_array();
207        let values = self
208            .get_mapping()
209            .to_arrow(compat_level.uses_binview_types());
210        let values_dtype = Box::new(values.dtype().clone());
211        let dtype = ArrowDataType::Dictionary(
212            <T::Native as DictionaryKey>::KEY_TYPE,
213            values_dtype,
214            self.is_enum(),
215        );
216        unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() }
217    }
218}
219
220impl<T: PolarsCategoricalType> LogicalType for CategoricalChunked<T> {
221    fn dtype(&self) -> &DataType {
222        &self.dtype
223    }
224
225    fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
226        polars_ensure!(i < self.len(), oob = i, self.len());
227        Ok(unsafe { self.get_any_value_unchecked(i) })
228    }
229
230    unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
231        match self.phys.get_unchecked(i) {
232            Some(i) => match &self.dtype {
233                DataType::Enum(_, mapping) => AnyValue::Enum(i.as_cat(), mapping),
234                DataType::Categorical(_, mapping) => AnyValue::Categorical(i.as_cat(), mapping),
235                _ => unreachable!(),
236            },
237            None => AnyValue::Null,
238        }
239    }
240
241    fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
242        if &self.dtype == dtype {
243            return Ok(self.clone().into_series());
244        }
245
246        match dtype {
247            DataType::String => {
248                let mapping = self.get_mapping();
249
250                // TODO @ cat-rework:, if len >= mapping.upper_bound(), cast categories to ViewArray, then construct array of Views.
251
252                let mut builder = StringChunkedBuilder::new(self.phys.name().clone(), self.len());
253                let to_str = |cat_id: CatSize| unsafe { mapping.cat_to_str_unchecked(cat_id) };
254                if !self.phys.has_nulls() {
255                    for cat_id in self.phys.into_no_null_iter() {
256                        builder.append_value(to_str(cat_id.as_cat()));
257                    }
258                } else {
259                    for opt_cat_id in self.phys.iter() {
260                        let opt_cat_id: Option<_> = opt_cat_id;
261                        builder.append_option(opt_cat_id.map(|c| to_str(c.as_cat())));
262                    }
263                }
264
265                let ca = builder.finish();
266                Ok(ca.into_series())
267            },
268
269            DataType::Enum(fcats, _mapping) => {
270                // TODO @ cat-rework: if len >= self.mapping().upper_bound(), remap categories then index into array.
271                let ret = with_match_categorical_physical_type!(fcats.physical(), |$C| {
272                    CategoricalChunked::<$C>::from_str_iter(
273                        self.name().clone(),
274                        dtype.clone(),
275                        self.iter_str()
276                    )?.into_series()
277                });
278
279                if options.is_strict() && self.null_count() != ret.null_count() {
280                    handle_casting_failures(&self.clone().into_series(), &ret)?;
281                }
282
283                Ok(ret)
284            },
285
286            DataType::Categorical(cats, _mapping) => {
287                // TODO @ cat-rework: if len >= self.mapping().upper_bound(), remap categories then index into array.
288                Ok(
289                    with_match_categorical_physical_type!(cats.physical(), |$C| {
290                        CategoricalChunked::<$C>::from_str_iter(
291                            self.name().clone(),
292                            dtype.clone(),
293                            self.iter_str()
294                        )?.into_series()
295                    }),
296                )
297            },
298
299            // LEGACY
300            // TODO @ cat-rework: remove after exposing to/from physical functions.
301            dt if dt.is_integer() => self.phys.clone().cast_with_options(dtype, options),
302
303            _ => polars_bail!(ComputeError: "cannot cast categorical types to {dtype:?}"),
304        }
305    }
306}
307
308impl<T: PolarsCategoricalType> VecHash for CategoricalChunked<T>
309where
310    ChunkedArray<<T as PolarsCategoricalType>::PolarsPhysical>: VecHash,
311{
312    fn vec_hash(
313        &self,
314        random_state: PlSeedableRandomStateQuality,
315        buf: &mut Vec<u64>,
316    ) -> PolarsResult<()> {
317        if self.is_enum() {
318            self.phys.vec_hash(random_state, buf)
319        } else {
320            buf.clear();
321            buf.reserve(self.phys.len());
322            let mult = random_state.hash_one(0);
323            let null = get_null_hash_value(&random_state);
324
325            let mapping = self.get_mapping();
326            for opt_cat in self.phys.iter() {
327                if let Some(cat) = opt_cat {
328                    let base_h = unsafe { mapping.cat_to_hash_unchecked(cat.as_cat()) };
329                    buf.push(folded_multiply(base_h, mult));
330                } else {
331                    buf.push(null);
332                }
333            }
334            Ok(())
335        }
336    }
337
338    fn vec_hash_combine(
339        &self,
340        random_state: PlSeedableRandomStateQuality,
341        hashes: &mut [u64],
342    ) -> PolarsResult<()> {
343        if self.is_enum() {
344            self.phys.vec_hash_combine(random_state, hashes)
345        } else {
346            let mult = random_state.hash_one(0);
347            let null = get_null_hash_value(&random_state);
348
349            let mapping = self.get_mapping();
350            assert!(self.phys.len() == hashes.len());
351            for (opt_cat, h) in self.phys.iter().zip(hashes.iter_mut()) {
352                let our_h = if let Some(cat) = opt_cat {
353                    let base_h = unsafe { mapping.cat_to_hash_unchecked(cat.as_cat()) };
354                    folded_multiply(base_h, mult)
355                } else {
356                    null
357                };
358                *h = _boost_hash_combine(our_h, *h);
359            }
360            Ok(())
361        }
362    }
363}