polars_core/chunked_array/logical/
categorical.rs1use std::hash::BuildHasher;
2use std::marker::PhantomData;
3
4use arrow::bitmap::BitmapBuilder;
5use num_traits::Zero;
6use polars_utils::hashing::{_boost_hash_combine, folded_multiply};
7
8use crate::chunked_array::cast::CastOptions;
9use crate::chunked_array::flags::StatisticsFlags;
10use crate::chunked_array::ops::ChunkFullNull;
11use crate::hashing::get_null_hash_value;
12use crate::prelude::*;
13use crate::series::IsSorted;
14use crate::utils::handle_casting_failures;
15
16pub type CategoricalChunked<T> = Logical<T, <T as PolarsCategoricalType>::PolarsPhysical>;
17pub type Categorical8Chunked = CategoricalChunked<Categorical8Type>;
18pub type Categorical16Chunked = CategoricalChunked<Categorical16Type>;
19pub type Categorical32Chunked = CategoricalChunked<Categorical32Type>;
20
21pub trait CategoricalPhysicalDtypeExt {
22 fn dtype(&self) -> DataType;
23}
24
25impl CategoricalPhysicalDtypeExt for CategoricalPhysical {
26 fn dtype(&self) -> DataType {
27 match self {
28 Self::U8 => DataType::UInt8,
29 Self::U16 => DataType::UInt16,
30 Self::U32 => DataType::UInt32,
31 }
32 }
33}
34
35impl<T: PolarsCategoricalType> CategoricalChunked<T> {
36 pub fn is_enum(&self) -> bool {
37 matches!(self.dtype(), DataType::Enum(_, _))
38 }
39
40 pub(crate) fn get_flags(&self) -> StatisticsFlags {
41 let mut flags = self.phys.get_flags();
44 if self.uses_lexical_ordering() {
45 flags.set_sorted(IsSorted::Not);
46 }
47 flags
48 }
49
50 pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) {
52 if self.uses_lexical_ordering() {
54 flags.set_sorted(IsSorted::Not)
55 }
56 self.physical_mut().set_flags(flags)
57 }
58
59 pub fn uses_lexical_ordering(&self) -> bool {
62 !self.is_enum()
63 }
64
65 pub fn full_null_with_dtype(name: PlSmallStr, length: usize, dtype: DataType) -> Self {
66 let phys =
67 ChunkedArray::<<T as PolarsCategoricalType>::PolarsPhysical>::full_null(name, length);
68 unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) }
69 }
70
71 pub fn from_cats_and_dtype(
75 mut cat_ids: ChunkedArray<T::PolarsPhysical>,
76 dtype: DataType,
77 ) -> Self {
78 let (DataType::Enum(_, mapping) | DataType::Categorical(_, mapping)) = &dtype else {
79 panic!("from_cats_and_dtype called on non-categorical type")
80 };
81 assert!(dtype.cat_physical().ok() == Some(T::physical()));
82
83 unsafe {
84 let mut invariants_violated = false;
85 let mut validity = BitmapBuilder::new();
86 for arr in cat_ids.downcast_iter_mut() {
87 validity.reserve(arr.len());
88 if arr.has_nulls() {
89 for opt_cat_id in arr.iter() {
90 if let Some(cat_id) = opt_cat_id {
91 validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
92 } else {
93 validity.push_unchecked(false);
94 }
95 }
96 } else {
97 for cat_id in arr.values_iter() {
98 validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
99 }
100 }
101
102 if arr.null_count() != validity.unset_bits() {
103 invariants_violated = true;
104 arr.set_validity(core::mem::take(&mut validity).into_opt_validity());
105 } else {
106 validity.clear();
107 }
108 }
109
110 if invariants_violated {
111 cat_ids.set_flags(StatisticsFlags::empty());
112 cat_ids.compute_len();
113 }
114 }
115
116 Self {
117 phys: cat_ids,
118 dtype,
119 _phantom: PhantomData,
120 }
121 }
122
123 pub unsafe fn from_cats_and_dtype_unchecked(
128 cat_ids: ChunkedArray<T::PolarsPhysical>,
129 dtype: DataType,
130 ) -> Self {
131 debug_assert!(dtype.cat_physical().ok() == Some(T::physical()));
132
133 Self {
134 phys: cat_ids,
135 dtype,
136 _phantom: PhantomData,
137 }
138 }
139
140 pub fn get_mapping(&self) -> &Arc<CategoricalMapping> {
142 let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = self.dtype() else {
143 unreachable!()
144 };
145 mapping
146 }
147
148 pub fn iter_str(&self) -> impl PolarsIterator<Item = Option<&str>> {
150 let mapping = self.get_mapping();
151 self.phys
152 .iter()
153 .map(|cat| unsafe { Some(mapping.cat_to_str_unchecked(cat?.as_cat())) })
154 }
155
156 pub fn from_str_iter<'a, I: IntoIterator<Item = Option<&'a str>>>(
160 name: PlSmallStr,
161 dtype: DataType,
162 strings: I,
163 ) -> PolarsResult<Self> {
164 let strings = strings.into_iter();
165
166 let hint = strings.size_hint().0;
167 let mut cat_ids = Vec::with_capacity(hint);
168 let mut validity = BitmapBuilder::with_capacity(hint);
169
170 match &dtype {
171 DataType::Categorical(cats, mapping) => {
172 assert!(cats.physical() == T::physical());
173 for opt_s in strings {
174 cat_ids.push(if let Some(s) = opt_s {
175 T::Native::from_cat(mapping.insert_cat(s)?)
176 } else {
177 T::Native::zero()
178 });
179 validity.push(opt_s.is_some());
180 }
181 },
182 DataType::Enum(fcats, mapping) => {
183 assert!(fcats.physical() == T::physical());
184 for opt_s in strings {
185 cat_ids.push(if let Some(cat) = opt_s.and_then(|s| mapping.get_cat(s)) {
186 validity.push(true);
187 T::Native::from_cat(cat)
188 } else {
189 validity.push(false);
190 T::Native::zero()
191 });
192 }
193 },
194 _ => panic!("from_strings_and_dtype_strict called on non-categorical type"),
195 }
196
197 let arr = <T::PolarsPhysical as PolarsDataType>::Array::from_vec(cat_ids)
198 .with_validity(validity.into_opt_validity());
199 let phys = ChunkedArray::<T::PolarsPhysical>::with_chunk(name, arr);
200 Ok(unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) })
201 }
202
203 pub fn to_arrow(&self, compat_level: CompatLevel) -> DictionaryArray<T::Native> {
204 let keys = self.physical().rechunk();
205 let keys = keys.downcast_as_array();
206 let values = self
207 .get_mapping()
208 .to_arrow(compat_level != CompatLevel::oldest());
209 let values_dtype = Box::new(values.dtype().clone());
210 let dtype =
211 ArrowDataType::Dictionary(<T::Native as DictionaryKey>::KEY_TYPE, values_dtype, false);
212 unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() }
213 }
214}
215
216impl<T: PolarsCategoricalType> LogicalType for CategoricalChunked<T> {
217 fn dtype(&self) -> &DataType {
218 &self.dtype
219 }
220
221 fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
222 polars_ensure!(i < self.len(), oob = i, self.len());
223 Ok(unsafe { self.get_any_value_unchecked(i) })
224 }
225
226 unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
227 match self.phys.get_unchecked(i) {
228 Some(i) => match &self.dtype {
229 DataType::Enum(_, mapping) => AnyValue::Enum(i.as_cat(), mapping),
230 DataType::Categorical(_, mapping) => AnyValue::Categorical(i.as_cat(), mapping),
231 _ => unreachable!(),
232 },
233 None => AnyValue::Null,
234 }
235 }
236
237 fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
238 if &self.dtype == dtype {
239 return Ok(self.clone().into_series());
240 }
241
242 match dtype {
243 DataType::String => {
244 let mapping = self.get_mapping();
245
246 let mut builder = StringChunkedBuilder::new(self.phys.name().clone(), self.len());
249 let to_str = |cat_id: CatSize| unsafe { mapping.cat_to_str_unchecked(cat_id) };
250 if !self.phys.has_nulls() {
251 for cat_id in self.phys.into_no_null_iter() {
252 builder.append_value(to_str(cat_id.as_cat()));
253 }
254 } else {
255 for opt_cat_id in self.phys.into_iter() {
256 let opt_cat_id: Option<_> = opt_cat_id;
257 builder.append_option(opt_cat_id.map(|c| to_str(c.as_cat())));
258 }
259 }
260
261 let ca = builder.finish();
262 Ok(ca.into_series())
263 },
264
265 DataType::Enum(fcats, _mapping) => {
266 let ret = with_match_categorical_physical_type!(fcats.physical(), |$C| {
268 CategoricalChunked::<$C>::from_str_iter(
269 self.name().clone(),
270 dtype.clone(),
271 self.iter_str()
272 )?.into_series()
273 });
274
275 if options.is_strict() && self.null_count() != ret.null_count() {
276 handle_casting_failures(&self.clone().into_series(), &ret)?;
277 }
278
279 Ok(ret)
280 },
281
282 DataType::Categorical(cats, _mapping) => {
283 Ok(
285 with_match_categorical_physical_type!(cats.physical(), |$C| {
286 CategoricalChunked::<$C>::from_str_iter(
287 self.name().clone(),
288 dtype.clone(),
289 self.iter_str()
290 )?.into_series()
291 }),
292 )
293 },
294
295 dt if dt.is_integer() => self.phys.clone().cast_with_options(dtype, options),
298
299 _ => polars_bail!(ComputeError: "cannot cast categorical types to {dtype:?}"),
300 }
301 }
302}
303
304impl<T: PolarsCategoricalType> VecHash for CategoricalChunked<T>
305where
306 ChunkedArray<<T as PolarsCategoricalType>::PolarsPhysical>: VecHash,
307{
308 fn vec_hash(
309 &self,
310 random_state: PlSeedableRandomStateQuality,
311 buf: &mut Vec<u64>,
312 ) -> PolarsResult<()> {
313 if self.is_enum() {
314 self.phys.vec_hash(random_state, buf)
315 } else {
316 buf.clear();
317 buf.reserve(self.phys.len());
318 let mult = random_state.hash_one(0);
319 let null = get_null_hash_value(&random_state);
320
321 let mapping = self.get_mapping();
322 for opt_cat in self.phys.iter() {
323 if let Some(cat) = opt_cat {
324 let base_h = unsafe { mapping.cat_to_hash_unchecked(cat.as_cat()) };
325 buf.push(folded_multiply(base_h, mult));
326 } else {
327 buf.push(null);
328 }
329 }
330 Ok(())
331 }
332 }
333
334 fn vec_hash_combine(
335 &self,
336 random_state: PlSeedableRandomStateQuality,
337 hashes: &mut [u64],
338 ) -> PolarsResult<()> {
339 if self.is_enum() {
340 self.phys.vec_hash_combine(random_state, hashes)
341 } else {
342 let mult = random_state.hash_one(0);
343 let null = get_null_hash_value(&random_state);
344
345 let mapping = self.get_mapping();
346 assert!(self.phys.len() == hashes.len());
347 for (opt_cat, h) in self.phys.iter().zip(hashes.iter_mut()) {
348 let our_h = if let Some(cat) = opt_cat {
349 let base_h = unsafe { mapping.cat_to_hash_unchecked(cat.as_cat()) };
350 folded_multiply(base_h, mult)
351 } else {
352 null
353 };
354 *h = _boost_hash_combine(our_h, *h);
355 }
356 Ok(())
357 }
358 }
359}