1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::array::*;
3use arrow::legacy::trusted_len::TrustedLenPush;
4use hashbrown::hash_map::Entry;
5use hashbrown::hash_table::{Entry as HTEntry, HashTable};
6use polars_utils::itertools::Itertools;
7
8use crate::hashing::_HASHMAP_INIT_SIZE;
9use crate::prelude::*;
10use crate::{POOL, StringCache, using_string_cache};
11
12pub struct CategoricalChunkedBuilder {
13 cat_builder: UInt32Vec,
14 name: PlSmallStr,
15 ordering: CategoricalOrdering,
16 categories: MutablePlString,
17 local_mapping: HashTable<u32>,
18 local_hasher: PlRandomState,
19}
20
21impl CategoricalChunkedBuilder {
22 pub fn new(name: PlSmallStr, capacity: usize, ordering: CategoricalOrdering) -> Self {
23 Self {
24 cat_builder: UInt32Vec::with_capacity(capacity),
25 name,
26 ordering,
27 categories: MutablePlString::with_capacity(_HASHMAP_INIT_SIZE),
28 local_mapping: HashTable::with_capacity(capacity / 10),
29 local_hasher: StringCache::get_hash_builder(),
30 }
31 }
32
33 fn get_cat_idx(&mut self, s: &str, h: u64) -> (u32, bool) {
34 let len = self.local_mapping.len() as u32;
35
36 unsafe {
38 let r = self.local_mapping.entry(
39 h,
40 |k| self.categories.value_unchecked(*k as usize) == s,
41 |k| {
42 self.local_hasher
43 .hash_one(self.categories.value_unchecked(*k as usize))
44 },
45 );
46
47 match r {
48 HTEntry::Occupied(v) => (*v.get(), false),
49 HTEntry::Vacant(slot) => {
50 self.categories.push(Some(s));
51 slot.insert(len);
52 (len, true)
53 },
54 }
55 }
56 }
57
58 fn try_get_cat_idx(&mut self, s: &str, h: u64) -> Option<u32> {
59 unsafe {
61 let r = self.local_mapping.entry(
62 h,
63 |k| self.categories.value_unchecked(*k as usize) == s,
64 |k| {
65 self.local_hasher
66 .hash_one(self.categories.value_unchecked(*k as usize))
67 },
68 );
69
70 match r {
71 HTEntry::Occupied(v) => Some(*v.get()),
72 HTEntry::Vacant(_) => None,
73 }
74 }
75 }
76
77 #[inline]
80 pub fn try_append_value(&mut self, s: &str) -> PolarsResult<()> {
81 let h = self.local_hasher.hash_one(s);
82 let idx = self.try_get_cat_idx(s, h).ok_or_else(
83 || polars_err!(ComputeError: "category {} doesn't exist in Enum dtype", s),
84 )?;
85 self.cat_builder.push(Some(idx));
86 Ok(())
87 }
88
89 #[inline]
92 pub fn try_append(&mut self, opt_s: Option<&str>) -> PolarsResult<()> {
93 match opt_s {
94 None => self.append_null(),
95 Some(s) => self.try_append_value(s)?,
96 }
97 Ok(())
98 }
99
100 #[inline]
103 pub fn register_value(&mut self, s: &str) -> (u32, bool) {
104 let h = self.local_hasher.hash_one(s);
105 self.get_cat_idx(s, h)
106 }
107
108 #[inline]
109 pub fn append_value(&mut self, s: &str) {
110 let h = self.local_hasher.hash_one(s);
111 let idx = self.get_cat_idx(s, h).0;
112 self.cat_builder.push(Some(idx));
113 }
114
115 #[inline]
116 pub fn append_null(&mut self) {
117 self.cat_builder.push(None)
118 }
119
120 #[inline]
121 pub fn append(&mut self, opt_s: Option<&str>) {
122 match opt_s {
123 None => self.append_null(),
124 Some(s) => self.append_value(s),
125 }
126 }
127
128 fn drain_iter<'a, I>(&mut self, i: I)
129 where
130 I: IntoIterator<Item = Option<&'a str>>,
131 {
132 for opt_s in i.into_iter() {
133 self.append(opt_s);
134 }
135 }
136
137 fn drain_iter_global_and_finish<'a, I>(&mut self, i: I) -> CategoricalChunked
140 where
141 I: IntoIterator<Item = Option<&'a str>>,
142 {
143 let iter = i.into_iter();
144 let mut hashes = Vec::with_capacity(_HASHMAP_INIT_SIZE);
146 for s in self.categories.values_iter() {
147 hashes.push(self.local_hasher.hash_one(s));
148 }
149
150 for opt_s in iter {
151 match opt_s {
152 None => self.append_null(),
153 Some(s) => {
154 let hash = self.local_hasher.hash_one(s);
155 let (cat_idx, new) = self.get_cat_idx(s, hash);
156 self.cat_builder.push(Some(cat_idx));
157 if new {
158 hashes.push(hash);
160 }
161 },
162 }
163 }
164
165 let categories = std::mem::take(&mut self.categories).freeze();
166
167 let mut local_to_global: Vec<u32> = Vec::with_capacity(categories.len());
170 let (id, local_to_global) = crate::STRING_CACHE.apply(|cache| {
171 for (s, h) in categories.values_iter().zip(hashes) {
172 unsafe { local_to_global.push_unchecked(cache.insert_from_hash(h, s)) }
174 }
175 local_to_global
176 });
177
178 let update_cats = || {
180 if !local_to_global.is_empty() {
181 self.cat_builder.apply_values(|cats| {
183 for cat in cats {
184 debug_assert!((*cat as usize) < local_to_global.len());
185 *cat = *unsafe { local_to_global.get_unchecked(*cat as usize) };
186 }
187 })
188 }
189 };
190
191 let mut global_to_local = PlHashMap::with_capacity(local_to_global.len());
192 POOL.join(
193 || fill_global_to_local(&local_to_global, &mut global_to_local),
194 update_cats,
195 );
196
197 let indices = std::mem::take(&mut self.cat_builder).into();
198 let indices = UInt32Chunked::with_chunk(self.name.clone(), indices);
199
200 unsafe {
202 CategoricalChunked::from_cats_and_rev_map_unchecked(
203 indices,
204 Arc::new(RevMapping::Global(global_to_local, categories, id)),
205 false,
206 self.ordering,
207 )
208 .with_fast_unique(true)
209 }
210 }
211
212 pub fn drain_iter_and_finish<'a, I>(mut self, i: I) -> CategoricalChunked
213 where
214 I: IntoIterator<Item = Option<&'a str>>,
215 {
216 if using_string_cache() {
217 self.drain_iter_global_and_finish(i)
218 } else {
219 self.drain_iter(i);
220 self.finish()
221 }
222 }
223
224 pub fn finish(self) -> CategoricalChunked {
225 unsafe {
227 CategoricalChunked::from_keys_and_values(
228 self.name.clone(),
229 &self.cat_builder.into(),
230 &self.categories.into(),
231 self.ordering,
232 )
233 .with_fast_unique(true)
234 }
235 }
236}
237
238fn fill_global_to_local(local_to_global: &[u32], global_to_local: &mut PlHashMap<u32, u32>) {
239 let mut local_idx = 0;
240 #[allow(clippy::explicit_counter_loop)]
241 for global_idx in local_to_global {
242 unsafe {
244 global_to_local.insert_unique_unchecked(*global_idx, local_idx);
245 }
246 local_idx += 1;
247 }
248}
249
250impl CategoricalChunked {
251 pub(crate) fn from_global_indices(
254 cats: UInt32Chunked,
255 ordering: CategoricalOrdering,
256 ) -> PolarsResult<CategoricalChunked> {
257 let len = crate::STRING_CACHE.read_map().len() as u32;
258 let oob = cats.into_iter().flatten().any(|cat| cat >= len);
259 polars_ensure!(
260 !oob,
261 ComputeError:
262 "cannot construct Categorical from these categories; at least one of them is out of bounds"
263 );
264 Ok(unsafe { Self::from_global_indices_unchecked(cats, ordering) })
265 }
266
267 pub unsafe fn from_global_indices_unchecked(
273 cats: UInt32Chunked,
274 ordering: CategoricalOrdering,
275 ) -> CategoricalChunked {
276 let cache = crate::STRING_CACHE.read_map();
277
278 let cap = std::cmp::min(std::cmp::min(cats.len(), cache.len()), _HASHMAP_INIT_SIZE);
279 let mut rev_map = PlHashMap::with_capacity(cap);
280 let mut str_values = MutablePlString::with_capacity(cap);
281
282 for arr in cats.downcast_iter() {
283 for cat in arr.into_iter().flatten().copied() {
284 let offset = str_values.len() as u32;
285
286 if let Entry::Vacant(entry) = rev_map.entry(cat) {
287 entry.insert(offset);
288 let str_val = cache.get_unchecked(cat);
289 str_values.push(Some(str_val))
290 }
291 }
292 }
293
294 let rev_map = RevMapping::Global(rev_map, str_values.into(), cache.uuid);
295
296 CategoricalChunked::from_cats_and_rev_map_unchecked(
297 cats,
298 Arc::new(rev_map),
299 false,
300 ordering,
301 )
302 }
303
304 pub(crate) unsafe fn from_keys_and_values_global(
305 name: PlSmallStr,
306 keys: impl IntoIterator<Item = Option<u32>> + Send,
307 capacity: usize,
308 values: &Utf8ViewArray,
309 ordering: CategoricalOrdering,
310 ) -> Self {
311 let mut local_to_global: Vec<u32> = Vec::with_capacity(values.len());
313 let (id, local_to_global) = crate::STRING_CACHE.apply(|cache| {
314 for s in values.values_iter() {
317 unsafe { local_to_global.push_unchecked(cache.insert(s)) }
319 }
320 local_to_global
321 });
322
323 let compute_cats = || {
324 let mut result = UInt32Vec::with_capacity(capacity);
325
326 for opt_value in keys.into_iter() {
327 result.push(opt_value.map(|cat| {
328 debug_assert!((cat as usize) < local_to_global.len());
329 *unsafe { local_to_global.get_unchecked(cat as usize) }
330 }));
331 }
332 result
333 };
334
335 let mut global_to_local = PlHashMap::with_capacity(local_to_global.len());
336 let (_, cats) = POOL.join(
337 || fill_global_to_local(&local_to_global, &mut global_to_local),
338 compute_cats,
339 );
340 unsafe {
341 CategoricalChunked::from_cats_and_rev_map_unchecked(
342 UInt32Chunked::with_chunk(name, cats.into()),
343 Arc::new(RevMapping::Global(global_to_local, values.clone(), id)),
344 false,
345 ordering,
346 )
347 }
348 }
349
350 pub(crate) unsafe fn from_keys_and_values_local(
351 name: PlSmallStr,
352 keys: &PrimitiveArray<u32>,
353 values: &Utf8ViewArray,
354 ordering: CategoricalOrdering,
355 ) -> CategoricalChunked {
356 CategoricalChunked::from_cats_and_rev_map_unchecked(
357 UInt32Chunked::with_chunk(name, keys.clone()),
358 Arc::new(RevMapping::build_local(values.clone())),
359 false,
360 ordering,
361 )
362 }
363
364 pub(crate) unsafe fn from_keys_and_values(
367 name: PlSmallStr,
368 keys: &PrimitiveArray<u32>,
369 values: &Utf8ViewArray,
370 ordering: CategoricalOrdering,
371 ) -> Self {
372 if !using_string_cache() {
373 CategoricalChunked::from_keys_and_values_local(name, keys, values, ordering)
374 } else {
375 CategoricalChunked::from_keys_and_values_global(
376 name,
377 keys.into_iter().map(|c| c.copied()),
378 keys.len(),
379 values,
380 ordering,
381 )
382 }
383 }
384
385 pub fn from_string_to_enum(
388 values: &StringChunked,
389 categories: &Utf8ViewArray,
390 ordering: CategoricalOrdering,
391 ) -> PolarsResult<CategoricalChunked> {
392 polars_ensure!(categories.null_count() == 0, ComputeError: "categories can not contain null values");
393
394 let mut map = PlHashMap::with_capacity(categories.len());
396 for (idx, cat) in categories.values_iter().enumerate_idx() {
397 #[allow(clippy::unnecessary_cast)]
398 map.insert(cat, idx as u32);
399 }
400 let iter = values.downcast_iter().map(|arr| {
402 arr.iter()
403 .map(|opt_s: Option<&str>| opt_s.and_then(|s| map.get(s).copied()))
404 .collect_arr()
405 });
406 let mut keys: UInt32Chunked = ChunkedArray::from_chunk_iter(values.name().clone(), iter);
407 keys.rename(values.name().clone());
408 let rev_map = RevMapping::build_local(categories.clone());
409 unsafe {
410 Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
411 keys,
412 Arc::new(rev_map),
413 true,
414 ordering,
415 )
416 .with_fast_unique(false))
417 }
418 }
419}
420
#[cfg(test)]
mod test {
    use crate::prelude::*;
    use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache};

    /// Casting strings to categorical yields one category per distinct string, with
    /// and without the string cache, and appending two categoricals built while the
    /// string cache is enabled does not fail.
    #[test]
    fn test_categorical_rev() -> PolarsResult<()> {
        let _lock = SINGLE_LOCK.lock();
        disable_string_cache();
        let slice = &[
            Some("foo"),
            None,
            Some("bar"),
            Some("foo"),
            Some("foo"),
            Some("bar"),
        ];
        let cat_dtype = DataType::Categorical(None, Default::default());
        let ca = StringChunked::new(PlSmallStr::from_static("a"), slice);

        // String cache disabled.
        let out = ca.cast(&cat_dtype)?;
        let out = out.categorical().unwrap().clone();
        assert_eq!(out.get_rev_map().len(), 2);

        // String cache enabled: first cast.
        enable_string_cache();
        let out = ca.cast(&cat_dtype)?;
        let out = out.categorical().unwrap().clone();
        assert_eq!(out.get_rev_map().len(), 2);

        // String cache enabled: second cast of the same strings.
        let out = ca.cast(&cat_dtype)?;
        let out = out.categorical().unwrap().clone();
        assert_eq!(out.get_rev_map().len(), 2);

        // Appending two categoricals that were built under the same string cache
        // must succeed.
        let ca1 = StringChunked::new(PlSmallStr::from_static("a"), slice).cast(&cat_dtype)?;
        let mut ca1 = ca1.categorical().unwrap().clone();
        let ca2 = StringChunked::new(PlSmallStr::from_static("a"), slice).cast(&cat_dtype)?;
        let ca2 = ca2.categorical().unwrap();
        ca1.append(ca2).unwrap();

        Ok(())
    }

    /// Two independent builders give back the strings they were fed, with and without
    /// the string cache.
    #[test]
    fn test_categorical_builder() {
        let _lock = SINGLE_LOCK.lock();
        for use_string_cache in [false, true] {
            disable_string_cache();
            if use_string_cache {
                enable_string_cache();
            }

            // Both builders are created before either one is drained.
            let builder1 = CategoricalChunkedBuilder::new(
                PlSmallStr::from_static("foo"),
                10,
                Default::default(),
            );
            let builder2 = CategoricalChunkedBuilder::new(
                PlSmallStr::from_static("foo"),
                10,
                Default::default(),
            );

            let s = builder1
                .drain_iter_and_finish(vec![None, Some("hello"), Some("vietnam")])
                .into_series();
            for (row, expected) in ["null", "hello", "vietnam"].into_iter().enumerate() {
                assert_eq!(s.str_value(row).unwrap(), expected);
            }

            let s = builder2
                .drain_iter_and_finish(vec![Some("hello"), None, Some("world")])
                .into_series();
            for (row, expected) in ["hello", "null", "world"].into_iter().enumerate() {
                assert_eq!(s.str_value(row).unwrap(), expected);
            }
        }
    }
}