1mod builder;
2mod from;
3mod merge;
4mod ops;
5pub mod revmap;
6pub mod string_cache;
7
8use bitflags::bitflags;
9pub use builder::*;
10pub use merge::*;
11use polars_utils::itertools::Itertools;
12use polars_utils::sync::SyncPtr;
13pub use revmap::*;
14
15use super::*;
16use crate::chunked_array::cast::CastOptions;
17use crate::chunked_array::flags::StatisticsFlags;
18use crate::prelude::*;
19use crate::series::IsSorted;
20use crate::using_string_cache;
21
22bitflags! {
23 #[derive(Default, Clone)]
24 struct BitSettings: u8 {
25 const ORIGINAL = 0x01;
26 }
27}
28
29#[derive(Clone)]
30pub struct CategoricalChunked {
31 physical: Logical<CategoricalType, UInt32Type>,
32 bit_settings: BitSettings,
35}
36
37impl CategoricalChunked {
38 pub(crate) fn field(&self) -> Field {
39 let name = self.physical().name();
40 Field::new(name.clone(), self.dtype().clone())
41 }
42
43 pub fn is_empty(&self) -> bool {
44 self.len() == 0
45 }
46
47 #[inline]
48 pub fn len(&self) -> usize {
49 self.physical.len()
50 }
51
52 #[inline]
53 pub fn null_count(&self) -> usize {
54 self.physical.null_count()
55 }
56
57 pub fn name(&self) -> &PlSmallStr {
58 self.physical.name()
59 }
60
61 pub fn into_physical(self) -> UInt32Chunked {
63 self.physical.phys
64 }
65
66 pub fn physical(&self) -> &UInt32Chunked {
69 &self.physical
70 }
71
72 pub(crate) fn physical_mut(&mut self) -> &mut UInt32Chunked {
74 &mut self.physical
75 }
76
77 pub fn is_enum(&self) -> bool {
78 matches!(self.dtype(), DataType::Enum(_, _))
79 }
80
81 pub fn to_local(&self) -> Self {
83 let rev_map = self.get_rev_map();
84 let (physical_map, categories) = match rev_map.as_ref() {
85 RevMapping::Global(m, c, _) => (m, c),
86 RevMapping::Local(_, _) if !self.is_enum() => return self.clone(),
87 RevMapping::Local(_, _) => {
88 let mut local = self.clone();
90 local.physical.dtype =
91 DataType::Categorical(Some(rev_map.clone()), self.get_ordering());
92 return local;
93 },
94 };
95
96 let local_rev_map = RevMapping::build_local(categories.clone());
97 let local_ca = self
101 .physical()
102 .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap()));
103
104 let mut out = unsafe {
105 Self::from_cats_and_rev_map_unchecked(
106 local_ca,
107 local_rev_map.into(),
108 false,
109 self.get_ordering(),
110 )
111 };
112 out.set_fast_unique(self._can_fast_unique());
113
114 out
115 }
116
117 pub fn to_global(&self) -> PolarsResult<Self> {
118 polars_ensure!(using_string_cache(), string_cache_mismatch);
119 let categories = match &**self.get_rev_map() {
121 RevMapping::Global(_, _, _) => return Ok(self.clone()),
122 RevMapping::Local(categories, _) => categories,
123 };
124
125 unsafe {
127 Ok(CategoricalChunked::from_keys_and_values_global(
128 self.name().clone(),
129 self.physical(),
130 self.len(),
131 categories,
132 self.get_ordering(),
133 ))
134 }
135 }
136
137 pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> Self {
139 match self.get_rev_map().as_ref() {
141 RevMapping::Local(_, cur_hash) if hash == *cur_hash => {
142 return unsafe {
143 CategoricalChunked::from_cats_and_rev_map_unchecked(
144 self.physical().clone(),
145 self.get_rev_map().clone(),
146 true,
147 self.get_ordering(),
148 )
149 };
150 },
151 _ => (),
152 };
153 let old_rev_map = self.get_rev_map();
155
156 let old_categories = old_rev_map.get_categories();
158 let old_idx_map: PlHashMap<&str, u32> = old_categories
159 .values_iter()
160 .zip(0..old_categories.len() as u32)
161 .collect();
162
163 #[allow(clippy::unnecessary_cast)]
164 let idx_map: PlHashMap<u32, u32> = categories
165 .values_iter()
166 .enumerate_idx()
167 .filter_map(|(new_idx, s)| old_idx_map.get(s).map(|old_idx| (*old_idx, new_idx as u32)))
168 .collect();
169
170 let new_phys: UInt32Chunked = self
172 .physical()
173 .into_iter()
174 .map(|opt_v: Option<u32>| opt_v.and_then(|v| idx_map.get(&v).copied()))
175 .collect();
176
177 unsafe {
179 CategoricalChunked::from_cats_and_rev_map_unchecked(
180 new_phys,
181 Arc::new(RevMapping::Local(categories.clone(), hash)),
182 true,
183 self.get_ordering(),
184 )
185 }
186 }
187
188 pub(crate) fn get_flags(&self) -> StatisticsFlags {
189 self.physical().get_flags()
190 }
191
192 pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) {
194 if self.uses_lexical_ordering() {
196 flags.set_sorted(IsSorted::Not)
197 }
198 self.physical_mut().set_flags(flags)
199 }
200
201 pub fn uses_lexical_ordering(&self) -> bool {
204 self.get_ordering() == CategoricalOrdering::Lexical
205 }
206
207 pub fn get_ordering(&self) -> CategoricalOrdering {
208 if let DataType::Categorical(_, ordering) | DataType::Enum(_, ordering) =
209 &self.physical.dtype
210 {
211 *ordering
212 } else {
213 panic!("implementation error")
214 }
215 }
216
217 pub unsafe fn from_cats_and_dtype_unchecked(idx: UInt32Chunked, dtype: DataType) -> Self {
223 debug_assert!(matches!(
224 dtype,
225 DataType::Enum { .. } | DataType::Categorical { .. }
226 ));
227 Self {
228 physical: Logical::new_logical(idx, dtype),
229 bit_settings: Default::default(),
230 }
231 }
232
233 pub unsafe fn from_cats_and_rev_map_unchecked(
238 idx: UInt32Chunked,
239 rev_map: Arc<RevMapping>,
240 is_enum: bool,
241 ordering: CategoricalOrdering,
242 ) -> Self {
243 let dtype = if is_enum {
244 DataType::Enum(Some(rev_map), ordering)
245 } else {
246 DataType::Categorical(Some(rev_map), ordering)
247 };
248 Self {
249 physical: Logical::new_logical(idx, dtype),
250 bit_settings: Default::default(),
251 }
252 }
253
254 pub(crate) fn set_ordering(
255 mut self,
256 ordering: CategoricalOrdering,
257 keep_fast_unique: bool,
258 ) -> Self {
259 self.physical.dtype = match self.dtype() {
260 DataType::Enum(_, _) => DataType::Enum(Some(self.get_rev_map().clone()), ordering),
261 DataType::Categorical(_, _) => {
262 DataType::Categorical(Some(self.get_rev_map().clone()), ordering)
263 },
264 _ => panic!("implementation error"),
265 };
266
267 if !keep_fast_unique {
268 self.set_fast_unique(false)
269 }
270 self
271 }
272
273 pub(crate) unsafe fn set_rev_map(&mut self, rev_map: Arc<RevMapping>, keep_fast_unique: bool) {
276 self.physical.dtype = match self.dtype() {
277 DataType::Enum(_, _) => DataType::Enum(Some(rev_map), self.get_ordering()),
278 DataType::Categorical(_, _) => {
279 DataType::Categorical(Some(rev_map), self.get_ordering())
280 },
281 _ => panic!("implementation error"),
282 };
283
284 if !keep_fast_unique {
285 self.set_fast_unique(false)
286 }
287 }
288
289 pub fn _can_fast_unique(&self) -> bool {
292 self.bit_settings.contains(BitSettings::ORIGINAL)
293 && self.physical.chunks.len() == 1
294 && self.null_count() == 0
295 }
296
297 pub(crate) fn set_fast_unique(&mut self, toggle: bool) {
298 if toggle {
299 self.bit_settings.insert(BitSettings::ORIGINAL);
300 } else {
301 self.bit_settings.remove(BitSettings::ORIGINAL);
302 }
303 }
304
305 pub(crate) unsafe fn with_fast_unique(mut self, toggle: bool) -> Self {
309 self.set_fast_unique(toggle);
310 self
311 }
312
313 pub unsafe fn _with_fast_unique(self, toggle: bool) -> Self {
317 self.with_fast_unique(toggle)
318 }
319
320 pub fn get_rev_map(&self) -> &Arc<RevMapping> {
322 if let DataType::Categorical(Some(rev_map), _) | DataType::Enum(Some(rev_map), _) =
323 &self.physical.dtype
324 {
325 rev_map
326 } else {
327 panic!("implementation error")
328 }
329 }
330
331 pub fn iter_str(&self) -> CatIter<'_> {
333 let iter = self.physical().into_iter();
334 CatIter {
335 rev: self.get_rev_map(),
336 iter,
337 }
338 }
339}
340
341impl LogicalType for CategoricalChunked {
342 fn dtype(&self) -> &DataType {
343 &self.physical.dtype
344 }
345
346 fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
347 polars_ensure!(i < self.len(), oob = i, self.len());
348 Ok(unsafe { self.get_any_value_unchecked(i) })
349 }
350
351 unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
352 match self.physical.phys.get_unchecked(i) {
353 Some(i) => match self.dtype() {
354 DataType::Enum(_, _) => AnyValue::Enum(i, self.get_rev_map(), SyncPtr::new_null()),
355 DataType::Categorical(_, _) => {
356 AnyValue::Categorical(i, self.get_rev_map(), SyncPtr::new_null())
357 },
358 _ => unimplemented!(),
359 },
360 None => AnyValue::Null,
361 }
362 }
363
364 fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
365 match dtype {
366 DataType::String => {
367 let mapping = &**self.get_rev_map();
368
369 let mut builder =
370 StringChunkedBuilder::new(self.physical.name().clone(), self.len());
371
372 let f = |idx: u32| mapping.get(idx);
373
374 if !self.physical.has_nulls() {
375 self.physical
376 .into_no_null_iter()
377 .for_each(|idx| builder.append_value(f(idx)));
378 } else {
379 self.physical.into_iter().for_each(|opt_idx| {
380 builder.append_option(opt_idx.map(f));
381 });
382 }
383
384 let ca = builder.finish();
385 Ok(ca.into_series())
386 },
387 DataType::UInt32 => {
388 let ca = unsafe {
389 UInt32Chunked::from_chunks(
390 self.physical.name().clone(),
391 self.physical.chunks.clone(),
392 )
393 };
394 Ok(ca.into_series())
395 },
396 #[cfg(feature = "dtype-categorical")]
397 DataType::Enum(Some(rev_map), ordering) => {
398 let RevMapping::Local(categories, hash) = &**rev_map else {
399 polars_bail!(ComputeError: "can not cast to enum with global mapping")
400 };
401 Ok(self
402 .to_enum(categories, *hash)
403 .set_ordering(*ordering, true)
404 .into_series()
405 .with_name(self.name().clone()))
406 },
407 DataType::Enum(None, _) => {
408 polars_bail!(ComputeError: "can not cast to enum without categories present")
409 },
410 #[cfg(feature = "dtype-categorical")]
411 DataType::Categorical(rev_map, ordering) => {
412 if matches!(self.dtype(), DataType::Enum(_, _)) && rev_map.is_none() {
414 if using_string_cache() {
415 return Ok(self
416 .to_global()?
417 .set_ordering(*ordering, true)
418 .into_series());
419 } else {
420 return Ok(self.to_local().set_ordering(*ordering, true).into_series());
421 }
422 }
423 let mut ca = self.clone().set_ordering(*ordering, true);
426 if ca.uses_lexical_ordering() {
427 ca.physical.set_sorted_flag(IsSorted::Not);
428 }
429 Ok(ca.into_series())
430 },
431 dt if dt.is_primitive_numeric() => {
432 let slf = self.to_local();
435 let categories = StringChunked::with_chunk(
436 slf.physical.name().clone(),
437 slf.get_rev_map().get_categories().clone(),
438 );
439 let casted_series = categories.cast_with_options(dtype, options)?;
440
441 #[cfg(feature = "bigidx")]
442 {
443 let s = slf.physical.cast_with_options(&DataType::UInt64, options)?;
444 Ok(unsafe { casted_series.take_unchecked(s.u64()?) })
445 }
446 #[cfg(not(feature = "bigidx"))]
447 {
448 Ok(unsafe { casted_series.take_unchecked(&slf.physical) })
450 }
451 },
452 _ => self.physical.cast_with_options(dtype, options),
453 }
454 }
455}
456
457pub struct CatIter<'a> {
458 rev: &'a RevMapping,
459 iter: Box<dyn PolarsIterator<Item = Option<u32>> + 'a>,
460}
461
462unsafe impl TrustedLen for CatIter<'_> {}
463
464impl<'a> Iterator for CatIter<'a> {
465 type Item = Option<&'a str>;
466
467 fn next(&mut self) -> Option<Self::Item> {
468 self.iter.next().map(|item| {
469 item.map(|idx| {
470 unsafe { self.rev.get_unchecked(idx) }
473 })
474 })
475 }
476
477 fn size_hint(&self) -> (usize, Option<usize>) {
478 self.iter.size_hint()
479 }
480}
481
482impl DoubleEndedIterator for CatIter<'_> {
483 fn next_back(&mut self) -> Option<Self::Item> {
484 self.iter.next_back().map(|item| {
485 item.map(|idx| {
486 unsafe { self.rev.get_unchecked(idx) }
489 })
490 })
491 }
492}
493
494impl ExactSizeIterator for CatIter<'_> {}
495
496#[cfg(test)]
497mod test {
498 use super::*;
499 use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache};
500
/// A local categorical survives a round trip through its arrow representation.
#[test]
fn test_categorical_round_trip() -> PolarsResult<()> {
    let _lock = SINGLE_LOCK.lock();
    disable_string_cache();

    let values = &[
        Some("foo"),
        None,
        Some("bar"),
        Some("foo"),
        Some("foo"),
        Some("bar"),
    ];
    let strings = StringChunked::new(PlSmallStr::from_static("a"), values);
    let cat_series = strings.cast(&DataType::Categorical(None, Default::default()))?;
    let cat = cat_series.categorical().unwrap();

    let arrow_array = cat.to_arrow(CompatLevel::newest(), false);
    let round_tripped = Series::try_from((PlSmallStr::from_static("foo"), arrow_array))?;

    assert!(matches!(
        round_tripped.dtype(),
        &DataType::Categorical(_, _)
    ));
    assert_eq!(round_tripped.null_count(), 1);
    assert_eq!(round_tripped.len(), 6);

    Ok(())
}
525
/// Appending two categoricals built under the string cache keeps every value resolvable.
#[test]
fn test_append_categorical() {
    let _lock = SINGLE_LOCK.lock();
    // Restart the string cache before building the inputs.
    disable_string_cache();
    enable_string_cache();

    let cat_dtype = DataType::Categorical(None, Default::default());
    let mut left = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"])
        .cast(&cat_dtype)
        .unwrap();
    let right = Series::new(PlSmallStr::from_static("2"), vec!["a", "x", "y"])
        .cast(&cat_dtype)
        .unwrap();

    let appended = left.append(&right).unwrap();
    assert_eq!(appended.str_value(0).unwrap(), "a");
    assert_eq!(appended.str_value(1).unwrap(), "b");
    assert_eq!(appended.str_value(4).unwrap(), "x");
    assert_eq!(appended.str_value(5).unwrap(), "y");
}
544
/// `n_unique` stays correct after `take` and `slice` drop some of the categories.
#[test]
fn test_fast_unique() {
    let _lock = SINGLE_LOCK.lock();
    let cat_dtype = DataType::Categorical(None, Default::default());
    let s = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"])
        .cast(&cat_dtype)
        .unwrap();
    assert_eq!(s.n_unique().unwrap(), 3);

    let taken = s.take(&IdxCa::new(PlSmallStr::EMPTY, [1, 2])).unwrap();
    assert_eq!(taken.n_unique().unwrap(), 2);

    let sliced = s.slice(1, 2);
    assert_eq!(sliced.n_unique().unwrap(), 2);
}
559
/// End-to-end flow on a local categorical: field, value access, group/aggregate, explode.
#[test]
fn test_categorical_flow() -> PolarsResult<()> {
    let _lock = SINGLE_LOCK.lock();
    disable_string_cache();

    let s = Series::new(PlSmallStr::from_static("a"), vec!["a", "b", "c"])
        .cast(&DataType::Categorical(None, Default::default()))?;

    let expected_field = Field::new(
        PlSmallStr::from_static("a"),
        DataType::Categorical(None, Default::default()),
    );
    assert_eq!(s.field().into_owned(), expected_field);
    assert!(matches!(
        s.get(0)?,
        AnyValue::Categorical(0, RevMapping::Local(_, _), _)
    ));

    let groups = s.group_tuples(false, true)?;
    let aggregated = unsafe { s.agg_list(&groups) };
    match aggregated.get(0)? {
        AnyValue::List(inner) => {
            assert!(matches!(inner.dtype(), DataType::Categorical(_, _)));
            let as_strings = inner.cast(&DataType::String).unwrap();
            assert_eq!(as_strings.get(0)?, AnyValue::String("a"));
            assert_eq!(inner.len(), 1);
        },
        _ => panic!(),
    }

    let flat = aggregated.explode(false)?;
    let cat = flat.categorical().unwrap();
    let vals = cat.iter_str().map(|v| v.unwrap()).collect::<Vec<_>>();
    assert_eq!(vals, &["a", "b", "c"]);
    Ok(())
}
598}