1mod builder;
2mod from;
3mod merge;
4mod ops;
5pub mod revmap;
6pub mod string_cache;
7
8use bitflags::bitflags;
9pub use builder::*;
10pub use merge::*;
11use polars_utils::itertools::Itertools;
12use polars_utils::sync::SyncPtr;
13pub use revmap::*;
14
15use super::*;
16use crate::chunked_array::cast::CastOptions;
17use crate::chunked_array::flags::StatisticsFlags;
18use crate::prelude::*;
19use crate::series::IsSorted;
20use crate::using_string_cache;
21
22bitflags! {
23 #[derive(Default, Clone)]
24 struct BitSettings: u8 {
25 const ORIGINAL = 0x01;
26 }
27}
28
29#[derive(Default, Clone)]
30pub struct CategoricalChunked {
31 physical: Logical<CategoricalType, UInt32Type>,
32 bit_settings: BitSettings,
35}
36
37impl CategoricalChunked {
38 pub(crate) fn field(&self) -> Field {
39 let name = self.physical().name();
40 Field::new(name.clone(), self.dtype().clone())
41 }
42
43 pub fn is_empty(&self) -> bool {
44 self.len() == 0
45 }
46
47 #[inline]
48 pub fn len(&self) -> usize {
49 self.physical.len()
50 }
51
52 #[inline]
53 pub fn null_count(&self) -> usize {
54 self.physical.null_count()
55 }
56
57 pub fn name(&self) -> &PlSmallStr {
58 self.physical.name()
59 }
60
61 pub fn physical(&self) -> &UInt32Chunked {
64 &self.physical
65 }
66
67 pub(crate) fn physical_mut(&mut self) -> &mut UInt32Chunked {
69 &mut self.physical
70 }
71
72 pub fn is_enum(&self) -> bool {
73 matches!(self.dtype(), DataType::Enum(_, _))
74 }
75
76 pub fn to_local(&self) -> Self {
78 let rev_map = self.get_rev_map();
79 let (physical_map, categories) = match rev_map.as_ref() {
80 RevMapping::Global(m, c, _) => (m, c),
81 RevMapping::Local(_, _) if !self.is_enum() => return self.clone(),
82 RevMapping::Local(_, _) => {
83 let mut local = self.clone();
85 local.physical.2 = Some(DataType::Categorical(
86 Some(rev_map.clone()),
87 self.get_ordering(),
88 ));
89 return local;
90 },
91 };
92
93 let local_rev_map = RevMapping::build_local(categories.clone());
94 let local_ca = self
98 .physical()
99 .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap()));
100
101 let mut out = unsafe {
102 Self::from_cats_and_rev_map_unchecked(
103 local_ca,
104 local_rev_map.into(),
105 false,
106 self.get_ordering(),
107 )
108 };
109 out.set_fast_unique(self._can_fast_unique());
110
111 out
112 }
113
114 pub fn to_global(&self) -> PolarsResult<Self> {
115 polars_ensure!(using_string_cache(), string_cache_mismatch);
116 let categories = match &**self.get_rev_map() {
118 RevMapping::Global(_, _, _) => return Ok(self.clone()),
119 RevMapping::Local(categories, _) => categories,
120 };
121
122 unsafe {
124 Ok(CategoricalChunked::from_keys_and_values_global(
125 self.name().clone(),
126 self.physical(),
127 self.len(),
128 categories,
129 self.get_ordering(),
130 ))
131 }
132 }
133
134 pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> Self {
136 match self.get_rev_map().as_ref() {
138 RevMapping::Local(_, cur_hash) if hash == *cur_hash => {
139 return unsafe {
140 CategoricalChunked::from_cats_and_rev_map_unchecked(
141 self.physical().clone(),
142 self.get_rev_map().clone(),
143 true,
144 self.get_ordering(),
145 )
146 };
147 },
148 _ => (),
149 };
150 let old_rev_map = self.get_rev_map();
152 #[allow(clippy::unnecessary_cast)]
153 let idx_map: PlHashMap<u32, u32> = categories
154 .values_iter()
155 .enumerate_idx()
156 .filter_map(|(new_idx, s)| old_rev_map.find(s).map(|old_idx| (old_idx, new_idx as u32)))
157 .collect();
158
159 let new_phys: UInt32Chunked = self
161 .physical()
162 .into_iter()
163 .map(|opt_v: Option<u32>| opt_v.and_then(|v| idx_map.get(&v).copied()))
164 .collect();
165
166 unsafe {
168 CategoricalChunked::from_cats_and_rev_map_unchecked(
169 new_phys,
170 Arc::new(RevMapping::Local(categories.clone(), hash)),
171 true,
172 self.get_ordering(),
173 )
174 }
175 }
176
177 pub(crate) fn get_flags(&self) -> StatisticsFlags {
178 self.physical().get_flags()
179 }
180
181 pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) {
183 if self.uses_lexical_ordering() {
185 flags.set_sorted(IsSorted::Not)
186 }
187 self.physical_mut().set_flags(flags)
188 }
189
190 pub fn uses_lexical_ordering(&self) -> bool {
193 self.get_ordering() == CategoricalOrdering::Lexical
194 }
195
196 pub(crate) fn get_ordering(&self) -> CategoricalOrdering {
197 if let DataType::Categorical(_, ordering) | DataType::Enum(_, ordering) =
198 &self.physical.2.as_ref().unwrap()
199 {
200 *ordering
201 } else {
202 panic!("implementation error")
203 }
204 }
205
206 pub unsafe fn from_cats_and_dtype_unchecked(idx: UInt32Chunked, dtype: DataType) -> Self {
212 debug_assert!(matches!(
213 dtype,
214 DataType::Enum { .. } | DataType::Categorical { .. }
215 ));
216 let mut logical = Logical::<UInt32Type, _>::new_logical::<CategoricalType>(idx);
217 logical.2 = Some(dtype);
218 Self {
219 physical: logical,
220 bit_settings: Default::default(),
221 }
222 }
223
224 pub unsafe fn from_cats_and_rev_map_unchecked(
229 idx: UInt32Chunked,
230 rev_map: Arc<RevMapping>,
231 is_enum: bool,
232 ordering: CategoricalOrdering,
233 ) -> Self {
234 let mut logical = Logical::<UInt32Type, _>::new_logical::<CategoricalType>(idx);
235 if is_enum {
236 logical.2 = Some(DataType::Enum(Some(rev_map), ordering));
237 } else {
238 logical.2 = Some(DataType::Categorical(Some(rev_map), ordering));
239 }
240 Self {
241 physical: logical,
242 bit_settings: Default::default(),
243 }
244 }
245
246 pub(crate) fn set_ordering(
247 mut self,
248 ordering: CategoricalOrdering,
249 keep_fast_unique: bool,
250 ) -> Self {
251 self.physical.2 = match self.dtype() {
252 DataType::Enum(_, _) => {
253 Some(DataType::Enum(Some(self.get_rev_map().clone()), ordering))
254 },
255 DataType::Categorical(_, _) => Some(DataType::Categorical(
256 Some(self.get_rev_map().clone()),
257 ordering,
258 )),
259 _ => panic!("implementation error"),
260 };
261
262 if !keep_fast_unique {
263 self.set_fast_unique(false)
264 }
265 self
266 }
267
268 pub(crate) unsafe fn set_rev_map(&mut self, rev_map: Arc<RevMapping>, keep_fast_unique: bool) {
271 self.physical.2 = match self.dtype() {
272 DataType::Enum(_, _) => Some(DataType::Enum(Some(rev_map), self.get_ordering())),
273 DataType::Categorical(_, _) => {
274 Some(DataType::Categorical(Some(rev_map), self.get_ordering()))
275 },
276 _ => panic!("implementation error"),
277 };
278
279 if !keep_fast_unique {
280 self.set_fast_unique(false)
281 }
282 }
283
284 pub fn _can_fast_unique(&self) -> bool {
287 self.bit_settings.contains(BitSettings::ORIGINAL)
288 && self.physical.chunks.len() == 1
289 && self.null_count() == 0
290 }
291
292 pub(crate) fn set_fast_unique(&mut self, toggle: bool) {
293 if toggle {
294 self.bit_settings.insert(BitSettings::ORIGINAL);
295 } else {
296 self.bit_settings.remove(BitSettings::ORIGINAL);
297 }
298 }
299
300 pub(crate) unsafe fn with_fast_unique(mut self, toggle: bool) -> Self {
304 self.set_fast_unique(toggle);
305 self
306 }
307
308 pub unsafe fn _with_fast_unique(self, toggle: bool) -> Self {
312 self.with_fast_unique(toggle)
313 }
314
315 pub fn get_rev_map(&self) -> &Arc<RevMapping> {
317 if let DataType::Categorical(Some(rev_map), _) | DataType::Enum(Some(rev_map), _) =
318 &self.physical.2.as_ref().unwrap()
319 {
320 rev_map
321 } else {
322 panic!("implementation error")
323 }
324 }
325
326 pub fn iter_str(&self) -> CatIter<'_> {
328 let iter = self.physical().into_iter();
329 CatIter {
330 rev: self.get_rev_map(),
331 iter,
332 }
333 }
334}
335
336impl LogicalType for CategoricalChunked {
337 fn dtype(&self) -> &DataType {
338 self.physical.2.as_ref().unwrap()
339 }
340
341 fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
342 polars_ensure!(i < self.len(), oob = i, self.len());
343 Ok(unsafe { self.get_any_value_unchecked(i) })
344 }
345
346 unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
347 match self.physical.0.get_unchecked(i) {
348 Some(i) => match self.dtype() {
349 DataType::Enum(_, _) => AnyValue::Enum(i, self.get_rev_map(), SyncPtr::new_null()),
350 DataType::Categorical(_, _) => {
351 AnyValue::Categorical(i, self.get_rev_map(), SyncPtr::new_null())
352 },
353 _ => unimplemented!(),
354 },
355 None => AnyValue::Null,
356 }
357 }
358
359 fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
360 match dtype {
361 DataType::String => {
362 let mapping = &**self.get_rev_map();
363
364 let mut builder =
365 StringChunkedBuilder::new(self.physical.name().clone(), self.len());
366
367 let f = |idx: u32| mapping.get(idx);
368
369 if !self.physical.has_nulls() {
370 self.physical
371 .into_no_null_iter()
372 .for_each(|idx| builder.append_value(f(idx)));
373 } else {
374 self.physical.into_iter().for_each(|opt_idx| {
375 builder.append_option(opt_idx.map(f));
376 });
377 }
378
379 let ca = builder.finish();
380 Ok(ca.into_series())
381 },
382 DataType::UInt32 => {
383 let ca = unsafe {
384 UInt32Chunked::from_chunks(
385 self.physical.name().clone(),
386 self.physical.chunks.clone(),
387 )
388 };
389 Ok(ca.into_series())
390 },
391 #[cfg(feature = "dtype-categorical")]
392 DataType::Enum(Some(rev_map), ordering) => {
393 let RevMapping::Local(categories, hash) = &**rev_map else {
394 polars_bail!(ComputeError: "can not cast to enum with global mapping")
395 };
396 Ok(self
397 .to_enum(categories, *hash)
398 .set_ordering(*ordering, true)
399 .into_series()
400 .with_name(self.name().clone()))
401 },
402 DataType::Enum(None, _) => {
403 polars_bail!(ComputeError: "can not cast to enum without categories present")
404 },
405 #[cfg(feature = "dtype-categorical")]
406 DataType::Categorical(rev_map, ordering) => {
407 if matches!(self.dtype(), DataType::Enum(_, _)) && rev_map.is_none() {
409 if using_string_cache() {
410 return Ok(self
411 .to_global()?
412 .set_ordering(*ordering, true)
413 .into_series());
414 } else {
415 return Ok(self.to_local().set_ordering(*ordering, true).into_series());
416 }
417 }
418 let mut ca = self.clone().set_ordering(*ordering, true);
421 if ca.uses_lexical_ordering() {
422 ca.physical.set_sorted_flag(IsSorted::Not);
423 }
424 Ok(ca.into_series())
425 },
426 dt if dt.is_primitive_numeric() => {
427 let slf = self.to_local();
430 let categories = StringChunked::with_chunk(
431 slf.physical.name().clone(),
432 slf.get_rev_map().get_categories().clone(),
433 );
434 let casted_series = categories.cast_with_options(dtype, options)?;
435
436 #[cfg(feature = "bigidx")]
437 {
438 let s = slf.physical.cast_with_options(&DataType::UInt64, options)?;
439 Ok(unsafe { casted_series.take_unchecked(s.u64()?) })
440 }
441 #[cfg(not(feature = "bigidx"))]
442 {
443 Ok(unsafe { casted_series.take_unchecked(&slf.physical) })
445 }
446 },
447 _ => self.physical.cast_with_options(dtype, options),
448 }
449 }
450}
451
452pub struct CatIter<'a> {
453 rev: &'a RevMapping,
454 iter: Box<dyn PolarsIterator<Item = Option<u32>> + 'a>,
455}
456
457unsafe impl TrustedLen for CatIter<'_> {}
458
459impl<'a> Iterator for CatIter<'a> {
460 type Item = Option<&'a str>;
461
462 fn next(&mut self) -> Option<Self::Item> {
463 self.iter.next().map(|item| {
464 item.map(|idx| {
465 unsafe { self.rev.get_unchecked(idx) }
468 })
469 })
470 }
471
472 fn size_hint(&self) -> (usize, Option<usize>) {
473 self.iter.size_hint()
474 }
475}
476
477impl DoubleEndedIterator for CatIter<'_> {
478 fn next_back(&mut self) -> Option<Self::Item> {
479 self.iter.next_back().map(|item| {
480 item.map(|idx| {
481 unsafe { self.rev.get_unchecked(idx) }
484 })
485 })
486 }
487}
488
489impl ExactSizeIterator for CatIter<'_> {}
490
#[cfg(test)]
mod test {
    use super::*;
    use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache};

    /// A categorical survives a round trip through its arrow representation.
    #[test]
    fn test_categorical_round_trip() -> PolarsResult<()> {
        let _lock = SINGLE_LOCK.lock();
        disable_string_cache();

        let values = &[
            Some("foo"),
            None,
            Some("bar"),
            Some("foo"),
            Some("foo"),
            Some("bar"),
        ];
        let strings = StringChunked::new(PlSmallStr::from_static("a"), values);
        let casted = strings.cast(&DataType::Categorical(None, Default::default()))?;
        let cat = casted.categorical().unwrap();

        let arrow = cat.to_arrow(CompatLevel::newest(), false);
        let round_tripped = Series::try_from((PlSmallStr::from_static("foo"), arrow))?;
        assert!(matches!(
            round_tripped.dtype(),
            &DataType::Categorical(_, _)
        ));
        assert_eq!(round_tripped.null_count(), 1);
        assert_eq!(round_tripped.len(), 6);

        Ok(())
    }

    /// Appending two categoricals built under the string cache keeps the
    /// string values of both inputs.
    #[test]
    fn test_append_categorical() {
        let _lock = SINGLE_LOCK.lock();
        disable_string_cache();
        enable_string_cache();

        let cat_dtype = DataType::Categorical(None, Default::default());
        let mut s1 = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"])
            .cast(&cat_dtype)
            .unwrap();
        let s2 = Series::new(PlSmallStr::from_static("2"), vec!["a", "x", "y"])
            .cast(&cat_dtype)
            .unwrap();

        let appended = s1.append(&s2).unwrap();
        for (idx, expected) in [(0, "a"), (1, "b"), (4, "x"), (5, "y")] {
            assert_eq!(appended.str_value(idx).unwrap(), expected);
        }
    }

    /// `n_unique` stays correct after `take` and `slice`.
    #[test]
    fn test_fast_unique() {
        let _lock = SINGLE_LOCK.lock();
        let s = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"])
            .cast(&DataType::Categorical(None, Default::default()))
            .unwrap();
        assert_eq!(s.n_unique().unwrap(), 3);

        let taken = s.take(&IdxCa::new(PlSmallStr::EMPTY, [1, 2])).unwrap();
        assert_eq!(taken.n_unique().unwrap(), 2);

        let sliced = s.slice(1, 2);
        assert_eq!(sliced.n_unique().unwrap(), 2);
    }

    /// The categorical dtype is preserved through group-by, list
    /// aggregation and explode.
    #[test]
    fn test_categorical_flow() -> PolarsResult<()> {
        let _lock = SINGLE_LOCK.lock();
        disable_string_cache();

        let s = Series::new(PlSmallStr::from_static("a"), vec!["a", "b", "c"])
            .cast(&DataType::Categorical(None, Default::default()))?;

        let expected_field = Field::new(
            PlSmallStr::from_static("a"),
            DataType::Categorical(None, Default::default()),
        );
        assert_eq!(s.field().into_owned(), expected_field);
        assert!(matches!(
            s.get(0)?,
            AnyValue::Categorical(0, RevMapping::Local(_, _), _)
        ));

        let groups = s.group_tuples(false, true)?;
        let aggregated = unsafe { s.agg_list(&groups) };
        let AnyValue::List(first_group) = aggregated.get(0)? else {
            panic!()
        };
        assert!(matches!(
            first_group.dtype(),
            DataType::Categorical(_, _)
        ));
        let as_str = first_group.cast(&DataType::String).unwrap();
        assert_eq!(as_str.get(0)?, AnyValue::String("a"));
        assert_eq!(first_group.len(), 1);

        let flat = aggregated.explode()?;
        let cat = flat.categorical().unwrap();
        let vals = cat.iter_str().map(|v| v.unwrap()).collect::<Vec<_>>();
        assert_eq!(vals, &["a", "b", "c"]);
        Ok(())
    }
}