polars_core/chunked_array/logical/categorical/
from.rs

1use arrow::datatypes::IntegerType;
2use polars_compute::cast::{CastOptionsImpl, cast, utf8view_to_utf8};
3
4use super::*;
5
6fn convert_values(arr: &Utf8ViewArray, compat_level: CompatLevel) -> ArrayRef {
7    if compat_level.0 >= 1 {
8        arr.clone().boxed()
9    } else {
10        utf8view_to_utf8::<i64>(arr).boxed()
11    }
12}
13
14impl CategoricalChunked {
15    pub fn to_arrow(&self, compat_level: CompatLevel, as_i64: bool) -> ArrayRef {
16        if as_i64 {
17            self.to_i64(compat_level).boxed()
18        } else {
19            self.to_u32(compat_level).boxed()
20        }
21    }
22
23    fn to_u32(&self, compat_level: CompatLevel) -> DictionaryArray<u32> {
24        let values_dtype = if compat_level.0 >= 1 {
25            ArrowDataType::Utf8View
26        } else {
27            ArrowDataType::LargeUtf8
28        };
29        let keys = self.physical().rechunk();
30        let keys = keys.downcast_as_array();
31        let map = &**self.get_rev_map();
32        let dtype = ArrowDataType::Dictionary(IntegerType::UInt32, Box::new(values_dtype), false);
33        match map {
34            RevMapping::Local(arr, _) => {
35                let values = convert_values(arr, compat_level);
36
37                // SAFETY:
38                // the keys are in bounds
39                unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() }
40            },
41            RevMapping::Global(reverse_map, values, _uuid) => {
42                let iter = keys
43                    .iter()
44                    .map(|opt_k| opt_k.map(|k| *reverse_map.get(k).unwrap()));
45                let keys = PrimitiveArray::from_trusted_len_iter(iter);
46
47                let values = convert_values(values, compat_level);
48
49                // SAFETY:
50                // the keys are in bounds
51                unsafe { DictionaryArray::try_new_unchecked(dtype, keys, values).unwrap() }
52            },
53        }
54    }
55
56    fn to_i64(&self, compat_level: CompatLevel) -> DictionaryArray<i64> {
57        let values_dtype = if compat_level.0 >= 1 {
58            ArrowDataType::Utf8View
59        } else {
60            ArrowDataType::LargeUtf8
61        };
62        let keys = self.physical().rechunk();
63        let keys = keys.downcast_as_array();
64        let map = &**self.get_rev_map();
65        let dtype = ArrowDataType::Dictionary(IntegerType::Int64, Box::new(values_dtype), false);
66        match map {
67            RevMapping::Local(arr, _) => {
68                let values = convert_values(arr, compat_level);
69
70                // SAFETY:
71                // the keys are in bounds
72                unsafe {
73                    DictionaryArray::try_new_unchecked(
74                        dtype,
75                        cast(keys, &ArrowDataType::Int64, CastOptionsImpl::unchecked())
76                            .unwrap()
77                            .as_any()
78                            .downcast_ref::<PrimitiveArray<i64>>()
79                            .unwrap()
80                            .clone(),
81                        values,
82                    )
83                    .unwrap()
84                }
85            },
86            RevMapping::Global(reverse_map, values, _uuid) => {
87                let iter = keys
88                    .iter()
89                    .map(|opt_k| opt_k.map(|k| *reverse_map.get(k).unwrap() as i64));
90                let keys = PrimitiveArray::from_trusted_len_iter(iter);
91
92                let values = convert_values(values, compat_level);
93
94                // SAFETY:
95                // the keys are in bounds
96                unsafe { DictionaryArray::try_new_unchecked(dtype, keys, values).unwrap() }
97            },
98        }
99    }
100}