1use std::any::Any;
2
3use arrow::array::builder::ArrayBuilder;
4use arrow::types::NativeType;
5use num_traits::AsPrimitive;
6use polars_compute::cast::cast_unchecked;
7
8use crate::prelude::*;
9
/// Converts Polars categorical / enum physical key arrays into Arrow arrays.
///
/// One [`CategoricalArrayToArrowConverter`] is kept per distinct categorical mapping.
/// Converters are registered by `initialize` and looked up by `array_to_arrow`.
pub struct CategoricalToArrowConverter {
    /// Per-mapping converters, keyed by the address of the `Arc<CategoricalMapping>`
    /// allocation they serve (`Arc::as_ptr(mapping) as *const () as usize`).
    pub converters: PlIndexMap<usize, CategoricalArrayToArrowConverter>,
    /// When `false`, a categorical converter clears its key remap at the start of every
    /// `array_to_arrow` call. When `true`, the remap keeps accumulating across calls.
    pub persist_remap: bool,
    /// When `true`, `array_to_arrow` returns only the (possibly remapped) keys array
    /// instead of a full Arrow dictionary array.
    pub output_keys_only: bool,
}
23
24impl CategoricalToArrowConverter {
25 pub fn array_to_arrow(
31 &mut self,
32 keys_arr: &dyn Array,
33 dtype: &DataType,
34 compat_level: CompatLevel,
35 ) -> Box<dyn Array> {
36 let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = dtype else {
37 unreachable!()
38 };
39
40 let key = Arc::as_ptr(mapping) as *const () as usize;
41 let converter = self.converters.get_mut(&key).unwrap();
42
43 with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| {
44 let keys_arr: &PrimitiveArray<<$C as PolarsCategoricalType>::Native> = keys_arr.as_any().downcast_ref().unwrap();
45
46 converter.array_to_arrow(
47 keys_arr,
48 dtype,
49 self.persist_remap,
50 self.output_keys_only,
51 compat_level
52 )
53 })
54 }
55
56 pub fn initialize(&mut self, dtype: &DataType) {
58 use DataType::*;
59
60 match dtype {
61 Categorical(_categories, mapping) => {
62 let key = Arc::as_ptr(mapping) as *const () as usize;
63
64 if !self.converters.contains_key(&key) {
65 with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| {
66 self.converters.insert(
67 key,
68 CategoricalArrayToArrowConverter::Categorical {
69 mapping: mapping.clone(),
70 key_remap: CategoricalKeyRemap::from(
71 PlIndexSet::<<$C as PolarsCategoricalType>::Native>::with_capacity(
72 mapping.num_cats_upper_bound()
73 )
74 ),
75 },
76 );
77 })
78 }
79 },
80 Enum(categories, mapping) => {
81 let key = Arc::as_ptr(mapping) as *const () as usize;
82
83 if !self.converters.contains_key(&key) {
84 self.converters.insert(
85 key,
86 CategoricalArrayToArrowConverter::Enum {
87 frozen: categories.clone(),
88 mapping: mapping.clone(),
89 },
90 );
91 }
92 },
93 List(inner) => self.initialize(inner),
94 #[cfg(feature = "dtype-array")]
95 Array(inner, _width) => self.initialize(inner),
96 #[cfg(feature = "dtype-struct")]
97 Struct(fields) => {
98 for field in fields {
99 self.initialize(field.dtype())
100 }
101 },
102 _ => assert!(!dtype.is_nested()),
103 }
104 }
105}
106
/// Conversion state for a single categorical mapping.
pub enum CategoricalArrayToArrowConverter {
    /// A non-frozen categorical dtype.
    ///
    /// Physical category ids are remapped to dense output keys through `key_remap`. An
    /// id's output key is its position in the remap set. The dictionary values contain
    /// one string per remap entry, in remap order.
    Categorical {
        /// Mapping this converter was created for. `array_to_arrow` asserts that the
        /// input dtype points at this same allocation.
        mapping: Arc<CategoricalMapping>,
        /// Physical-id -> output-key remap. Its variant matches the physical key width.
        key_remap: CategoricalKeyRemap,
    },
    /// An enum dtype.
    ///
    /// Physical keys are emitted unchanged. The dictionary values are the full list of
    /// frozen categories.
    Enum {
        /// Mapping this converter was created for. `array_to_arrow` asserts that the
        /// input dtype points at this same allocation.
        mapping: Arc<CategoricalMapping>,
        /// Frozen categories whose strings become the dictionary values.
        frozen: Arc<FrozenCategories>,
    },
}
123
124impl CategoricalArrayToArrowConverter {
125 fn array_to_arrow<T>(
126 &mut self,
127 keys_arr: &PrimitiveArray<T>,
128 dtype: &DataType,
129 persist_remap: bool,
130 output_keys_only: bool,
131 compat_level: CompatLevel,
132 ) -> Box<dyn Array>
133 where
134 T: DictionaryKey + NativeType + std::hash::Hash + Eq,
135 usize: AsPrimitive<T>,
136 {
137 let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = dtype else {
138 unreachable!()
139 };
140
141 let input_mapping_ptr: *const CategoricalMapping = Arc::as_ptr(mapping);
142
143 let keys_arr: PrimitiveArray<T> = match self {
144 Self::Categorical { mapping, key_remap } => {
145 assert_eq!(input_mapping_ptr, Arc::as_ptr(mapping));
146
147 let key_remap: &mut PlIndexSet<T> = key_remap.as_any_mut().downcast_mut().unwrap();
148
149 if !persist_remap {
150 key_remap.clear()
151 }
152
153 keys_arr
154 .iter()
155 .map(|x| {
156 x.map(|x: &T| {
157 let idx: usize = key_remap.insert_full(*x).0;
158 let out: T = idx.as_();
160 out
161 })
162 })
163 .collect()
164 },
165 Self::Enum { mapping, .. } => {
166 assert_eq!(input_mapping_ptr, Arc::as_ptr(mapping));
167 keys_arr.clone()
168 },
169 };
170
171 if output_keys_only {
172 return keys_arr.boxed();
173 }
174
175 let values = self.build_values_array(compat_level);
176
177 let dictionary_dtype = ArrowDataType::Dictionary(
178 <T as DictionaryKey>::KEY_TYPE,
179 Box::new(values.dtype().clone()),
180 false, );
182
183 unsafe {
184 DictionaryArray::<T>::try_new_unchecked(dictionary_dtype, keys_arr, values)
185 .unwrap()
186 .boxed()
187 }
188 }
189
190 pub fn build_values_array(&self, compat_level: CompatLevel) -> Box<dyn Array> {
196 match self {
197 Self::Categorical { mapping, key_remap } => match key_remap {
198 CategoricalKeyRemap::U8(keys) => self.build_values_array_from_keys(
199 keys.iter().map(|x: &u8| *x as CatSize),
200 mapping,
201 compat_level,
202 ),
203 CategoricalKeyRemap::U16(keys) => self.build_values_array_from_keys(
204 keys.iter().map(|x: &u16| *x as CatSize),
205 mapping,
206 compat_level,
207 ),
208 CategoricalKeyRemap::U32(keys) => self.build_values_array_from_keys(
209 keys.iter().map(|x: &u32| *x as CatSize),
210 mapping,
211 compat_level,
212 ),
213 },
214
215 Self::Enum { frozen, .. } => {
216 let array: &Utf8ViewArray = frozen.categories();
217
218 if compat_level != CompatLevel::oldest() {
219 array.to_boxed()
220 } else {
221 cast_unchecked(array, &ArrowDataType::LargeUtf8).unwrap()
224 }
225 },
226 }
227 }
228
229 fn build_values_array_from_keys<I>(
230 &self,
231 keys_iter: I,
232 mapping: &CategoricalMapping,
233 compat_level: CompatLevel,
234 ) -> Box<dyn Array>
235 where
236 I: ExactSizeIterator<Item = CatSize>,
237 {
238 if compat_level != CompatLevel::oldest() {
239 let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8View);
240 builder.reserve(keys_iter.len());
241
242 for x in keys_iter {
243 builder.push_value_ignore_validity(mapping.cat_to_str(x).unwrap())
244 }
245
246 builder.freeze().to_boxed()
247 } else {
248 let mut builder: MutableUtf8Array<i64> = MutableUtf8Array::new();
249 builder.reserve(keys_iter.len(), 0);
250
251 for x in keys_iter {
252 builder.push(Some(mapping.cat_to_str(x).unwrap()));
253 }
254
255 let out: Utf8Array<i64> = builder.into();
256 out.boxed()
257 }
258 }
259}
260
/// Insertion-ordered set of physical category ids that have been seen. The position of
/// an id in the set is the dense output key it is remapped to.
///
/// There is one variant per physical key width.
pub enum CategoricalKeyRemap {
    U8(PlIndexSet<u8>),
    U16(PlIndexSet<u16>),
    U32(PlIndexSet<u32>),
}
266
267impl CategoricalKeyRemap {
268 fn as_any_mut(&mut self) -> &mut dyn Any {
269 match self {
270 Self::U8(v) => v as _,
271 Self::U16(v) => v as _,
272 Self::U32(v) => v as _,
273 }
274 }
275}
276
277impl From<PlIndexSet<u8>> for CategoricalKeyRemap {
278 fn from(value: PlIndexSet<u8>) -> Self {
279 Self::U8(value)
280 }
281}
282
283impl From<PlIndexSet<u16>> for CategoricalKeyRemap {
284 fn from(value: PlIndexSet<u16>) -> Self {
285 Self::U16(value)
286 }
287}
288
289impl From<PlIndexSet<u32>> for CategoricalKeyRemap {
290 fn from(value: PlIndexSet<u32>) -> Self {
291 Self::U32(value)
292 }
293}