polars_core/chunked_array/logical/categorical/builder.rs

#![allow(unsafe_op_in_unsafe_fn)]
use arrow::array::*;
use arrow::legacy::trusted_len::TrustedLenPush;
use hashbrown::hash_map::Entry;
use hashbrown::hash_table::{Entry as HTEntry, HashTable};
use polars_utils::itertools::Itertools;

use crate::hashing::_HASHMAP_INIT_SIZE;
use crate::prelude::*;
use crate::{POOL, StringCache, using_string_cache};

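/// Incrementally builds a [`CategoricalChunked`] by interning each appended string
/// into a local category table.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// let mut builder = CategoricalChunkedBuilder::new(
///     PlSmallStr::from_static("col"),
///     3,
///     Default::default(),
/// );
/// builder.append_value("a");
/// builder.append_null();
/// builder.append_value("a"); // reuses the existing category index
/// let ca: CategoricalChunked = builder.finish();
/// ```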
pub struct CategoricalChunkedBuilder {
    // Physical keys (indices into `categories`) built so far.
    cat_builder: UInt32Vec,
    name: PlSmallStr,
    ordering: CategoricalOrdering,
    // The unique categories seen so far, in insertion order.
    categories: MutablePlString,
    // Interning table: holds indices into `categories`, probed by value hash.
    local_mapping: HashTable<u32>,
    // Hash builder shared with the global string cache.
    local_hasher: PlRandomState,
}

impl CategoricalChunkedBuilder {
    pub fn new(name: PlSmallStr, capacity: usize, ordering: CategoricalOrdering) -> Self {
        Self {
            cat_builder: UInt32Vec::with_capacity(capacity),
            name,
            ordering,
            categories: MutablePlString::with_capacity(_HASHMAP_INIT_SIZE),
            local_mapping: HashTable::with_capacity(capacity / 10),
            local_hasher: StringCache::get_hash_builder(),
        }
    }

    fn get_cat_idx(&mut self, s: &str, h: u64) -> (u32, bool) {
        let len = self.local_mapping.len() as u32;

        // SAFETY: indices in the hashmap are within bounds of `categories`.
        unsafe {
            let r = self.local_mapping.entry(
                h,
                |k| self.categories.value_unchecked(*k as usize) == s,
                |k| {
                    self.local_hasher
                        .hash_one(self.categories.value_unchecked(*k as usize))
                },
            );

            match r {
                HTEntry::Occupied(v) => (*v.get(), false),
                HTEntry::Vacant(slot) => {
                    self.categories.push(Some(s));
                    slot.insert(len);
                    (len, true)
                },
            }
        }
    }

    fn try_get_cat_idx(&mut self, s: &str, h: u64) -> Option<u32> {
        // SAFETY: indices in the hashmap are within bounds of `categories`.
        unsafe {
            let r = self.local_mapping.entry(
                h,
                |k| self.categories.value_unchecked(*k as usize) == s,
                |k| {
                    self.local_hasher
                        .hash_one(self.categories.value_unchecked(*k as usize))
                },
            );

            match r {
                HTEntry::Occupied(v) => Some(*v.get()),
                HTEntry::Vacant(_) => None,
            }
        }
    }

    /// Append a value, but fail if its category doesn't exist yet in the category state.
    /// You can register categories up front with `register_value`, or via `append`.
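    ///
    /// A hedged sketch of the intended flow (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// builder.register_value("a");                     // seed the category state
    /// builder.try_append_value("a")?;                  // ok: "a" is known
    /// assert!(builder.try_append_value("b").is_err()); // "b" was never registered
    /// ```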
    #[inline]
    pub fn try_append_value(&mut self, s: &str) -> PolarsResult<()> {
        let h = self.local_hasher.hash_one(s);
        let idx = self.try_get_cat_idx(s, h).ok_or_else(
            || polars_err!(ComputeError: "category {} doesn't exist in Enum dtype", s),
        )?;
        self.cat_builder.push(Some(idx));
        Ok(())
    }

    /// Append an optional value, but fail if its category doesn't exist yet in the category state.
    /// You can register categories up front with `register_value`, or via `append`.
    #[inline]
    pub fn try_append(&mut self, opt_s: Option<&str>) -> PolarsResult<()> {
        match opt_s {
            None => self.append_null(),
            Some(s) => self.try_append_value(s)?,
        }
        Ok(())
    }

    /// Registers a value to a categorical index without pushing it.
    /// Returns the index and whether the value was new.
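    ///
    /// E.g. (illustrative) the first `register_value("a")` returns `(0, true)`; registering
    /// `"a"` again returns `(0, false)`.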
    #[inline]
    pub fn register_value(&mut self, s: &str) -> (u32, bool) {
        let h = self.local_hasher.hash_one(s);
        self.get_cat_idx(s, h)
    }

    #[inline]
    pub fn append_value(&mut self, s: &str) {
        let h = self.local_hasher.hash_one(s);
        let idx = self.get_cat_idx(s, h).0;
        self.cat_builder.push(Some(idx));
    }

    #[inline]
    pub fn append_null(&mut self) {
        self.cat_builder.push(None)
    }

    #[inline]
    pub fn append(&mut self, opt_s: Option<&str>) {
        match opt_s {
            None => self.append_null(),
            Some(s) => self.append_value(s),
        }
    }

    fn drain_iter<'a, I>(&mut self, i: I)
    where
        I: IntoIterator<Item = Option<&'a str>>,
    {
        for opt_s in i.into_iter() {
            self.append(opt_s);
        }
    }

    /// Fast path for global categoricals which preserves hashes and saves an allocation by
    /// altering the keys in place.
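    ///
    /// For example (illustrative): if the local categories are `["a", "b"]` and the
    /// global string cache assigns them indices `17` and `4`, then `local_to_global`
    /// is `[17, 4]` and every physical key `0`/`1` already pushed is rewritten in
    /// place to `17`/`4`.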
    fn drain_iter_global_and_finish<'a, I>(&mut self, i: I) -> CategoricalChunked
    where
        I: IntoIterator<Item = Option<&'a str>>,
    {
        let iter = i.into_iter();
        // Save hashes for later when inserting into the global hashmap.
        let mut hashes = Vec::with_capacity(_HASHMAP_INIT_SIZE);
        for s in self.categories.values_iter() {
            hashes.push(self.local_hasher.hash_one(s));
        }

        for opt_s in iter {
            match opt_s {
                None => self.append_null(),
                Some(s) => {
                    let hash = self.local_hasher.hash_one(s);
                    let (cat_idx, new) = self.get_cat_idx(s, hash);
                    self.cat_builder.push(Some(cat_idx));
                    if new {
                        // We appended a value to the map.
                        hashes.push(hash);
                    }
                },
            }
        }

        let categories = std::mem::take(&mut self.categories).freeze();

        // We will create a mapping from our local categoricals to global categoricals
        // and a mapping from global categoricals to our local categoricals.
        let mut local_to_global: Vec<u32> = Vec::with_capacity(categories.len());
        let (id, local_to_global) = crate::STRING_CACHE.apply(|cache| {
            for (s, h) in categories.values_iter().zip(hashes) {
                // SAFETY: we allocated enough.
                unsafe { local_to_global.push_unchecked(cache.insert_from_hash(h, s)) }
            }
            local_to_global
        });

        // Change local indices in place to their global counterparts.
        let update_cats = || {
            if !local_to_global.is_empty() {
                // When all categoricals are null, `local_to_global` is empty and all physical values are 0.
                self.cat_builder.apply_values(|cats| {
                    for cat in cats {
                        debug_assert!((*cat as usize) < local_to_global.len());
                        *cat = *unsafe { local_to_global.get_unchecked(*cat as usize) };
                    }
                })
            }
        };

        let mut global_to_local = PlHashMap::with_capacity(local_to_global.len());
        POOL.join(
            || fill_global_to_local(&local_to_global, &mut global_to_local),
            update_cats,
        );

        let indices = std::mem::take(&mut self.cat_builder).into();
        let indices = UInt32Chunked::with_chunk(self.name.clone(), indices);

        // SAFETY: indices are in bounds of new rev_map
        unsafe {
            CategoricalChunked::from_cats_and_rev_map_unchecked(
                indices,
                Arc::new(RevMapping::Global(global_to_local, categories, id)),
                false,
                self.ordering,
            )
            .with_fast_unique(true)
        }
    }

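    /// Appends every item from the iterator and finishes the builder, taking the
    /// global fast path when the string cache is enabled.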
    pub fn drain_iter_and_finish<'a, I>(mut self, i: I) -> CategoricalChunked
    where
        I: IntoIterator<Item = Option<&'a str>>,
    {
        if using_string_cache() {
            self.drain_iter_global_and_finish(i)
        } else {
            self.drain_iter(i);
            self.finish()
        }
    }

    pub fn finish(self) -> CategoricalChunked {
        // SAFETY: keys and values are in bounds
        unsafe {
            CategoricalChunked::from_keys_and_values(
                self.name.clone(),
                &self.cat_builder.into(),
                &self.categories.into(),
                self.ordering,
            )
            .with_fast_unique(true)
        }
    }
}

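/// Builds the inverse of `local_to_global`: a map from each global index back to its
/// local index, e.g. (illustrative) `[17, 4]` becomes `{17: 0, 4: 1}`. Relies on the
/// global indices being unique.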
fn fill_global_to_local(local_to_global: &[u32], global_to_local: &mut PlHashMap<u32, u32>) {
    let mut local_idx = 0;
    #[allow(clippy::explicit_counter_loop)]
    for global_idx in local_to_global {
        // We know the keys are unique, so this is much faster.
        unsafe {
            global_to_local.insert_unique_unchecked(*global_idx, local_idx);
        }
        local_idx += 1;
    }
}

impl CategoricalChunked {
    /// Create a [`CategoricalChunked`] from categorical indices. The indices will
    /// probe the global string cache.
    pub(crate) fn from_global_indices(
        cats: UInt32Chunked,
        ordering: CategoricalOrdering,
    ) -> PolarsResult<CategoricalChunked> {
        let len = crate::STRING_CACHE.read_map().len() as u32;
        let oob = cats.into_iter().flatten().any(|cat| cat >= len);
        polars_ensure!(
            !oob,
            ComputeError:
            "cannot construct Categorical from these categories; at least one of them is out of bounds"
        );
        Ok(unsafe { Self::from_global_indices_unchecked(cats, ordering) })
    }

    /// Create a [`CategoricalChunked`] from categorical indices. The indices will
    /// probe the global string cache.
    ///
    /// # Safety
    /// This does not do any bounds checks.
    pub unsafe fn from_global_indices_unchecked(
        cats: UInt32Chunked,
        ordering: CategoricalOrdering,
    ) -> CategoricalChunked {
        let cache = crate::STRING_CACHE.read_map();

        let cap = std::cmp::min(std::cmp::min(cats.len(), cache.len()), _HASHMAP_INIT_SIZE);
        let mut rev_map = PlHashMap::with_capacity(cap);
        let mut str_values = MutablePlString::with_capacity(cap);

        for arr in cats.downcast_iter() {
            for cat in arr.into_iter().flatten().copied() {
                let offset = str_values.len() as u32;

                if let Entry::Vacant(entry) = rev_map.entry(cat) {
                    entry.insert(offset);
                    let str_val = cache.get_unchecked(cat);
                    str_values.push(Some(str_val))
                }
            }
        }

        let rev_map = RevMapping::Global(rev_map, str_values.into(), cache.uuid);

        CategoricalChunked::from_cats_and_rev_map_unchecked(
            cats,
            Arc::new(rev_map),
            false,
            ordering,
        )
    }

    pub(crate) unsafe fn from_keys_and_values_global(
        name: PlSmallStr,
        keys: impl IntoIterator<Item = Option<u32>> + Send,
        capacity: usize,
        values: &Utf8ViewArray,
        ordering: CategoricalOrdering,
    ) -> Self {
        // Vec<u32> where the index is local and the value is the global index.
        let mut local_to_global: Vec<u32> = Vec::with_capacity(values.len());
        let (id, local_to_global) = crate::STRING_CACHE.apply(|cache| {
            // Locally we don't need a hashmap because all categories are 1 integer apart,
            // so the index is local and the value is global.
            for s in values.values_iter() {
                // SAFETY: we allocated enough.
                unsafe { local_to_global.push_unchecked(cache.insert(s)) }
            }
            local_to_global
        });

        let compute_cats = || {
            let mut result = UInt32Vec::with_capacity(capacity);

            for opt_value in keys.into_iter() {
                result.push(opt_value.map(|cat| {
                    debug_assert!((cat as usize) < local_to_global.len());
                    *unsafe { local_to_global.get_unchecked(cat as usize) }
                }));
            }
            result
        };

        let mut global_to_local = PlHashMap::with_capacity(local_to_global.len());
        let (_, cats) = POOL.join(
            || fill_global_to_local(&local_to_global, &mut global_to_local),
            compute_cats,
        );
        unsafe {
            CategoricalChunked::from_cats_and_rev_map_unchecked(
                UInt32Chunked::with_chunk(name, cats.into()),
                Arc::new(RevMapping::Global(global_to_local, values.clone(), id)),
                false,
                ordering,
            )
        }
    }

    pub(crate) unsafe fn from_keys_and_values_local(
        name: PlSmallStr,
        keys: &PrimitiveArray<u32>,
        values: &Utf8ViewArray,
        ordering: CategoricalOrdering,
    ) -> CategoricalChunked {
        CategoricalChunked::from_cats_and_rev_map_unchecked(
            UInt32Chunked::with_chunk(name, keys.clone()),
            Arc::new(RevMapping::build_local(values.clone())),
            false,
            ordering,
        )
    }

    /// # Safety
    /// The caller must ensure that the index values in `keys` are within bounds of the `values` length.
    pub(crate) unsafe fn from_keys_and_values(
        name: PlSmallStr,
        keys: &PrimitiveArray<u32>,
        values: &Utf8ViewArray,
        ordering: CategoricalOrdering,
    ) -> Self {
        if !using_string_cache() {
            CategoricalChunked::from_keys_and_values_local(name, keys, values, ordering)
        } else {
            CategoricalChunked::from_keys_and_values_global(
                name,
                keys.into_iter().map(|c| c.copied()),
                keys.len(),
                values,
                ordering,
            )
        }
    }

    /// Create a [`CategoricalChunked`] from a fixed list of categories and a list of strings.
    /// This will error if a string is not in the fixed list of categories.
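    ///
    /// A minimal sketch (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// // Fixed category list: "a", "b".
    /// let mut cats = MutablePlString::with_capacity(2);
    /// cats.push(Some("a"));
    /// cats.push(Some("b"));
    /// let categories: Utf8ViewArray = cats.freeze();
    ///
    /// let values =
    ///     StringChunked::new(PlSmallStr::from_static("s"), &[Some("b"), None, Some("a")]);
    /// let enum_ca =
    ///     CategoricalChunked::from_string_to_enum(&values, &categories, Default::default())?;
    /// ```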
    pub fn from_string_to_enum(
        values: &StringChunked,
        categories: &Utf8ViewArray,
        ordering: CategoricalOrdering,
    ) -> PolarsResult<CategoricalChunked> {
        polars_ensure!(categories.null_count() == 0, ComputeError: "categories cannot contain null values");

        // Build a mapping string -> idx
        let mut map = PlHashMap::with_capacity(categories.len());
        for (idx, cat) in categories.values_iter().enumerate_idx() {
            #[allow(clippy::unnecessary_cast)]
            map.insert(cat, idx as u32);
        }
        // Find idx of every value in the map
        let iter = values.downcast_iter().map(|arr| {
            arr.iter()
                .map(|opt_s: Option<&str>| opt_s.and_then(|s| map.get(s).copied()))
                .collect_arr()
        });
        let keys: UInt32Chunked = ChunkedArray::from_chunk_iter(values.name().clone(), iter);
        let rev_map = RevMapping::build_local(categories.clone());
        unsafe {
            Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
                keys,
                Arc::new(rev_map),
                true,
                ordering,
            )
            .with_fast_unique(false))
        }
    }
}

#[cfg(test)]
mod test {
    use crate::prelude::*;
    use crate::{SINGLE_LOCK, disable_string_cache, enable_string_cache};

    #[test]
    fn test_categorical_rev() -> PolarsResult<()> {
        let _lock = SINGLE_LOCK.lock();
        disable_string_cache();
        let slice = &[
            Some("foo"),
            None,
            Some("bar"),
            Some("foo"),
            Some("foo"),
            Some("bar"),
        ];
        let ca = StringChunked::new(PlSmallStr::from_static("a"), slice);
        let out = ca.cast(&DataType::Categorical(None, Default::default()))?;
        let out = out.categorical().unwrap().clone();
        assert_eq!(out.get_rev_map().len(), 2);

        // test the global branch
        enable_string_cache();
        // empty global cache
        let out = ca.cast(&DataType::Categorical(None, Default::default()))?;
        let out = out.categorical().unwrap().clone();
        assert_eq!(out.get_rev_map().len(), 2);
        // full global cache
        let out = ca.cast(&DataType::Categorical(None, Default::default()))?;
        let out = out.categorical().unwrap().clone();
        assert_eq!(out.get_rev_map().len(), 2);

        // Check that we don't panic if we append two categorical arrays
        // built under the same string cache
        // https://github.com/pola-rs/polars/issues/1115
        let ca1 = StringChunked::new(PlSmallStr::from_static("a"), slice)
            .cast(&DataType::Categorical(None, Default::default()))?;
        let mut ca1 = ca1.categorical().unwrap().clone();
        let ca2 = StringChunked::new(PlSmallStr::from_static("a"), slice)
            .cast(&DataType::Categorical(None, Default::default()))?;
        let ca2 = ca2.categorical().unwrap();
        ca1.append(ca2).unwrap();

        Ok(())
    }

    #[test]
    fn test_categorical_builder() {
        use crate::{disable_string_cache, enable_string_cache};
        let _lock = crate::SINGLE_LOCK.lock();
        for use_string_cache in [false, true] {
            disable_string_cache();
            if use_string_cache {
                enable_string_cache();
            }

            // Use 2 builders to check that the global string cache
            // does not interfere with the index mapping
            let builder1 = CategoricalChunkedBuilder::new(
                PlSmallStr::from_static("foo"),
                10,
                Default::default(),
            );
            let builder2 = CategoricalChunkedBuilder::new(
                PlSmallStr::from_static("foo"),
                10,
                Default::default(),
            );
            let s = builder1
                .drain_iter_and_finish(vec![None, Some("hello"), Some("vietnam")])
                .into_series();
            assert_eq!(s.str_value(0).unwrap(), "null");
            assert_eq!(s.str_value(1).unwrap(), "hello");
            assert_eq!(s.str_value(2).unwrap(), "vietnam");

            let s = builder2
                .drain_iter_and_finish(vec![Some("hello"), None, Some("world")])
                .into_series();
            assert_eq!(s.str_value(0).unwrap(), "hello");
            assert_eq!(s.str_value(1).unwrap(), "null");
            assert_eq!(s.str_value(2).unwrap(), "world");
        }
    }
}