polars_core/chunked_array/logical/
categorical.rs1use std::hash::BuildHasher;
2use std::marker::PhantomData;
3
4use arrow::bitmap::BitmapBuilder;
5use num_traits::Zero;
6use polars_utils::hashing::{_boost_hash_combine, folded_multiply};
7
8use crate::chunked_array::cast::CastOptions;
9use crate::chunked_array::flags::StatisticsFlags;
10use crate::chunked_array::iterator::PolarsIterator;
11use crate::chunked_array::ops::ChunkFullNull;
12use crate::hashing::get_null_hash_value;
13use crate::prelude::*;
14use crate::series::IsSorted;
15use crate::utils::handle_casting_failures;
16
17pub type CategoricalChunked<T> = Logical<T, <T as PolarsCategoricalType>::PolarsPhysical>;
18pub type Categorical8Chunked = CategoricalChunked<Categorical8Type>;
19pub type Categorical16Chunked = CategoricalChunked<Categorical16Type>;
20pub type Categorical32Chunked = CategoricalChunked<Categorical32Type>;
21
22pub trait CategoricalPhysicalDtypeExt {
23 fn dtype(&self) -> DataType;
24}
25
26impl CategoricalPhysicalDtypeExt for CategoricalPhysical {
27 fn dtype(&self) -> DataType {
28 match self {
29 Self::U8 => DataType::UInt8,
30 Self::U16 => DataType::UInt16,
31 Self::U32 => DataType::UInt32,
32 }
33 }
34}
35
36impl<T: PolarsCategoricalType> CategoricalChunked<T> {
37 pub fn is_enum(&self) -> bool {
38 matches!(self.dtype(), DataType::Enum(_, _))
39 }
40
41 pub(crate) fn get_flags(&self) -> StatisticsFlags {
42 let mut flags = self.phys.get_flags();
45 if self.uses_lexical_ordering() {
46 flags.set_sorted(IsSorted::Not);
47 }
48 flags
49 }
50
51 pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) {
53 if self.uses_lexical_ordering() {
55 flags.set_sorted(IsSorted::Not)
56 }
57 self.physical_mut().set_flags(flags)
58 }
59
60 pub fn uses_lexical_ordering(&self) -> bool {
63 !self.is_enum()
64 }
65
66 pub fn full_null_with_dtype(name: PlSmallStr, length: usize, dtype: DataType) -> Self {
67 let phys =
68 ChunkedArray::<<T as PolarsCategoricalType>::PolarsPhysical>::full_null(name, length);
69 unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) }
70 }
71
72 pub fn from_cats_and_dtype(
76 mut cat_ids: ChunkedArray<T::PolarsPhysical>,
77 dtype: DataType,
78 ) -> Self {
79 let (DataType::Enum(_, mapping) | DataType::Categorical(_, mapping)) = &dtype else {
80 panic!("from_cats_and_dtype called on non-categorical type")
81 };
82 assert!(dtype.cat_physical().ok() == Some(T::physical()));
83
84 unsafe {
85 let mut invariants_violated = false;
86 let mut validity = BitmapBuilder::new();
87 for arr in cat_ids.downcast_iter_mut() {
88 validity.reserve(arr.len());
89 if arr.has_nulls() {
90 for opt_cat_id in arr.iter() {
91 if let Some(cat_id) = opt_cat_id {
92 validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
93 } else {
94 validity.push_unchecked(false);
95 }
96 }
97 } else {
98 for cat_id in arr.values_iter() {
99 validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
100 }
101 }
102
103 if arr.null_count() != validity.unset_bits() {
104 invariants_violated = true;
105 arr.set_validity(core::mem::take(&mut validity).into_opt_validity());
106 } else {
107 validity.clear();
108 }
109 }
110
111 if invariants_violated {
112 cat_ids.set_flags(StatisticsFlags::empty());
113 cat_ids.compute_len();
114 }
115 }
116
117 Self {
118 phys: cat_ids,
119 dtype,
120 _phantom: PhantomData,
121 }
122 }
123
124 pub unsafe fn from_cats_and_dtype_unchecked(
129 cat_ids: ChunkedArray<T::PolarsPhysical>,
130 dtype: DataType,
131 ) -> Self {
132 debug_assert!(dtype.cat_physical().ok() == Some(T::physical()));
133
134 Self {
135 phys: cat_ids,
136 dtype,
137 _phantom: PhantomData,
138 }
139 }
140
141 pub fn get_mapping(&self) -> &Arc<CategoricalMapping> {
143 let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = self.dtype() else {
144 unreachable!()
145 };
146 mapping
147 }
148
149 pub fn iter_str(&self) -> impl PolarsIterator<Item = Option<&str>> {
151 let mapping = self.get_mapping();
152 self.phys
153 .iter()
154 .map(|cat| unsafe { Some(mapping.cat_to_str_unchecked(cat?.as_cat())) })
155 }
156
157 pub fn from_str_iter<'a, I: IntoIterator<Item = Option<&'a str>>>(
161 name: PlSmallStr,
162 dtype: DataType,
163 strings: I,
164 ) -> PolarsResult<Self> {
165 let strings = strings.into_iter();
166
167 let hint = strings.size_hint().0;
168 let mut cat_ids = Vec::with_capacity(hint);
169 let mut validity = BitmapBuilder::with_capacity(hint);
170
171 match &dtype {
172 DataType::Categorical(cats, mapping) => {
173 assert!(cats.physical() == T::physical());
174 for opt_s in strings {
175 cat_ids.push(if let Some(s) = opt_s {
176 T::Native::from_cat(mapping.insert_cat(s)?)
177 } else {
178 T::Native::zero()
179 });
180 validity.push(opt_s.is_some());
181 }
182 },
183 DataType::Enum(fcats, mapping) => {
184 assert!(fcats.physical() == T::physical());
185 for opt_s in strings {
186 cat_ids.push(if let Some(cat) = opt_s.and_then(|s| mapping.get_cat(s)) {
187 validity.push(true);
188 T::Native::from_cat(cat)
189 } else {
190 validity.push(false);
191 T::Native::zero()
192 });
193 }
194 },
195 _ => panic!("from_strings_and_dtype_strict called on non-categorical type"),
196 }
197
198 let arr = <T::PolarsPhysical as PolarsDataType>::Array::from_vec(cat_ids)
199 .with_validity(validity.into_opt_validity());
200 let phys = ChunkedArray::<T::PolarsPhysical>::with_chunk(name, arr);
201 Ok(unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) })
202 }
203
204 pub fn to_arrow(&self, compat_level: CompatLevel) -> DictionaryArray<T::Native> {
205 let keys = self.physical().rechunk();
206 let keys = keys.downcast_as_array();
207 let values = self
208 .get_mapping()
209 .to_arrow(compat_level.uses_binview_types());
210 let values_dtype = Box::new(values.dtype().clone());
211 let dtype = ArrowDataType::Dictionary(
212 <T::Native as DictionaryKey>::KEY_TYPE,
213 values_dtype,
214 self.is_enum(),
215 );
216 unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() }
217 }
218}
219
220impl<T: PolarsCategoricalType> LogicalType for CategoricalChunked<T> {
221 fn dtype(&self) -> &DataType {
222 &self.dtype
223 }
224
225 fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
226 polars_ensure!(i < self.len(), oob = i, self.len());
227 Ok(unsafe { self.get_any_value_unchecked(i) })
228 }
229
230 unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
231 match self.phys.get_unchecked(i) {
232 Some(i) => match &self.dtype {
233 DataType::Enum(_, mapping) => AnyValue::Enum(i.as_cat(), mapping),
234 DataType::Categorical(_, mapping) => AnyValue::Categorical(i.as_cat(), mapping),
235 _ => unreachable!(),
236 },
237 None => AnyValue::Null,
238 }
239 }
240
241 fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
242 if &self.dtype == dtype {
243 return Ok(self.clone().into_series());
244 }
245
246 match dtype {
247 DataType::String => {
248 let mapping = self.get_mapping();
249
250 let mut builder = StringChunkedBuilder::new(self.phys.name().clone(), self.len());
253 let to_str = |cat_id: CatSize| unsafe { mapping.cat_to_str_unchecked(cat_id) };
254 if !self.phys.has_nulls() {
255 for cat_id in self.phys.into_no_null_iter() {
256 builder.append_value(to_str(cat_id.as_cat()));
257 }
258 } else {
259 for opt_cat_id in self.phys.iter() {
260 let opt_cat_id: Option<_> = opt_cat_id;
261 builder.append_option(opt_cat_id.map(|c| to_str(c.as_cat())));
262 }
263 }
264
265 let ca = builder.finish();
266 Ok(ca.into_series())
267 },
268
269 DataType::Enum(fcats, _mapping) => {
270 let ret = with_match_categorical_physical_type!(fcats.physical(), |$C| {
272 CategoricalChunked::<$C>::from_str_iter(
273 self.name().clone(),
274 dtype.clone(),
275 self.iter_str()
276 )?.into_series()
277 });
278
279 if options.is_strict() && self.null_count() != ret.null_count() {
280 handle_casting_failures(&self.clone().into_series(), &ret)?;
281 }
282
283 Ok(ret)
284 },
285
286 DataType::Categorical(cats, _mapping) => {
287 Ok(
289 with_match_categorical_physical_type!(cats.physical(), |$C| {
290 CategoricalChunked::<$C>::from_str_iter(
291 self.name().clone(),
292 dtype.clone(),
293 self.iter_str()
294 )?.into_series()
295 }),
296 )
297 },
298
299 dt if dt.is_integer() => self.phys.clone().cast_with_options(dtype, options),
302
303 _ => polars_bail!(ComputeError: "cannot cast categorical types to {dtype:?}"),
304 }
305 }
306}
307
308impl<T: PolarsCategoricalType> VecHash for CategoricalChunked<T>
309where
310 ChunkedArray<<T as PolarsCategoricalType>::PolarsPhysical>: VecHash,
311{
312 fn vec_hash(
313 &self,
314 random_state: PlSeedableRandomStateQuality,
315 buf: &mut Vec<u64>,
316 ) -> PolarsResult<()> {
317 if self.is_enum() {
318 self.phys.vec_hash(random_state, buf)
319 } else {
320 buf.clear();
321 buf.reserve(self.phys.len());
322 let mult = random_state.hash_one(0);
323 let null = get_null_hash_value(&random_state);
324
325 let mapping = self.get_mapping();
326 for opt_cat in self.phys.iter() {
327 if let Some(cat) = opt_cat {
328 let base_h = unsafe { mapping.cat_to_hash_unchecked(cat.as_cat()) };
329 buf.push(folded_multiply(base_h, mult));
330 } else {
331 buf.push(null);
332 }
333 }
334 Ok(())
335 }
336 }
337
338 fn vec_hash_combine(
339 &self,
340 random_state: PlSeedableRandomStateQuality,
341 hashes: &mut [u64],
342 ) -> PolarsResult<()> {
343 if self.is_enum() {
344 self.phys.vec_hash_combine(random_state, hashes)
345 } else {
346 let mult = random_state.hash_one(0);
347 let null = get_null_hash_value(&random_state);
348
349 let mapping = self.get_mapping();
350 assert!(self.phys.len() == hashes.len());
351 for (opt_cat, h) in self.phys.iter().zip(hashes.iter_mut()) {
352 let our_h = if let Some(cat) = opt_cat {
353 let base_h = unsafe { mapping.cat_to_hash_unchecked(cat.as_cat()) };
354 folded_multiply(base_h, mult)
355 } else {
356 null
357 };
358 *h = _boost_hash_combine(our_h, *h);
359 }
360 Ok(())
361 }
362 }
363}