1use std::any::Any;
2
3use arrow::array::builder::ArrayBuilder;
4use arrow::types::NativeType;
5use num_traits::AsPrimitive;
6use polars_compute::cast::cast_unchecked;
7
8use crate::prelude::*;
9
10pub struct CategoricalToArrowConverter {
12 pub converters: PlIndexMap<usize, CategoricalArrayToArrowConverter>,
18 pub persist_remap: bool,
20 pub output_keys_only: bool,
22}
23
24impl CategoricalToArrowConverter {
25 pub fn array_to_arrow(
31 &mut self,
32 keys_arr: &dyn Array,
33 dtype: &DataType,
34 compat_level: CompatLevel,
35 ) -> Box<dyn Array> {
36 let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = dtype else {
37 unreachable!()
38 };
39
40 let key = Arc::as_ptr(mapping) as *const () as usize;
41 let converter = self.converters.get_mut(&key).unwrap();
42
43 with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| {
44 let keys_arr: &PrimitiveArray<<$C as PolarsCategoricalType>::Native> = keys_arr.as_any().downcast_ref().unwrap();
45
46 converter.array_to_arrow(
47 keys_arr,
48 dtype,
49 self.persist_remap,
50 self.output_keys_only,
51 compat_level
52 )
53 })
54 }
55
56 pub fn initialize(&mut self, dtype: &DataType) {
58 use DataType::*;
59
60 match dtype {
61 Categorical(_categories, mapping) => {
62 let key = Arc::as_ptr(mapping) as *const () as usize;
63
64 if !self.converters.contains_key(&key) {
65 with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| {
66 self.converters.insert(
67 key,
68 CategoricalArrayToArrowConverter::Categorical {
69 mapping: mapping.clone(),
70 key_remap: CategoricalKeyRemap::from(
71 PlIndexSet::<<$C as PolarsCategoricalType>::Native>::with_capacity(
72 mapping.num_cats_upper_bound()
73 )
74 ),
75 },
76 );
77 })
78 }
79 },
80 Enum(categories, mapping) => {
81 let key = Arc::as_ptr(mapping) as *const () as usize;
82
83 if !self.converters.contains_key(&key) {
84 self.converters.insert(
85 key,
86 CategoricalArrayToArrowConverter::Enum {
87 frozen: categories.clone(),
88 mapping: mapping.clone(),
89 },
90 );
91 }
92 },
93 List(inner) => self.initialize(inner),
94 #[cfg(feature = "dtype-array")]
95 Array(inner, _width) => self.initialize(inner),
96 #[cfg(feature = "dtype-struct")]
97 Struct(fields) => {
98 for field in fields {
99 self.initialize(field.dtype())
100 }
101 },
102 #[cfg(feature = "dtype-extension")]
103 Extension(_, inner) => self.initialize(inner),
104 _ => assert!(!dtype.is_nested()),
105 }
106 }
107}
108
109pub enum CategoricalArrayToArrowConverter {
110 Categorical {
115 mapping: Arc<CategoricalMapping>,
116 key_remap: CategoricalKeyRemap,
117 },
118 Enum {
121 mapping: Arc<CategoricalMapping>,
122 frozen: Arc<FrozenCategories>,
123 },
124}
125
126impl CategoricalArrayToArrowConverter {
127 fn array_to_arrow<T>(
128 &mut self,
129 keys_arr: &PrimitiveArray<T>,
130 dtype: &DataType,
131 persist_remap: bool,
132 output_keys_only: bool,
133 compat_level: CompatLevel,
134 ) -> Box<dyn Array>
135 where
136 T: DictionaryKey + NativeType + std::hash::Hash + Eq,
137 usize: AsPrimitive<T>,
138 {
139 let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = dtype else {
140 unreachable!()
141 };
142
143 let input_mapping_ptr: *const CategoricalMapping = Arc::as_ptr(mapping);
144
145 let keys_arr: PrimitiveArray<T> = match self {
146 Self::Categorical { mapping, key_remap } => {
147 assert_eq!(input_mapping_ptr, Arc::as_ptr(mapping));
148
149 let key_remap: &mut PlIndexSet<T> = key_remap.as_any_mut().downcast_mut().unwrap();
150
151 if !persist_remap {
152 key_remap.clear()
153 }
154
155 keys_arr
156 .iter()
157 .map(|x| {
158 x.map(|x: &T| {
159 let idx: usize = key_remap.insert_full(*x).0;
160 let out: T = idx.as_();
162 out
163 })
164 })
165 .collect()
166 },
167 Self::Enum { mapping, .. } => {
168 assert_eq!(input_mapping_ptr, Arc::as_ptr(mapping));
169 keys_arr.clone()
170 },
171 };
172
173 if output_keys_only {
174 return keys_arr.boxed();
175 }
176
177 let values = self.build_values_array(compat_level);
178
179 let dictionary_dtype = ArrowDataType::Dictionary(
180 <T as DictionaryKey>::KEY_TYPE,
181 Box::new(values.dtype().clone()),
182 false, );
184
185 unsafe {
186 DictionaryArray::<T>::try_new_unchecked(dictionary_dtype, keys_arr, values)
187 .unwrap()
188 .boxed()
189 }
190 }
191
192 pub fn build_values_array(&self, compat_level: CompatLevel) -> Box<dyn Array> {
198 match self {
199 Self::Categorical { mapping, key_remap } => match key_remap {
200 CategoricalKeyRemap::U8(keys) => self.build_values_array_from_keys(
201 keys.iter().map(|x: &u8| *x as CatSize),
202 mapping,
203 compat_level,
204 ),
205 CategoricalKeyRemap::U16(keys) => self.build_values_array_from_keys(
206 keys.iter().map(|x: &u16| *x as CatSize),
207 mapping,
208 compat_level,
209 ),
210 CategoricalKeyRemap::U32(keys) => self.build_values_array_from_keys(
211 keys.iter().map(|x: &u32| *x as CatSize),
212 mapping,
213 compat_level,
214 ),
215 },
216
217 Self::Enum { frozen, .. } => {
218 let array: &Utf8ViewArray = frozen.categories();
219
220 if compat_level != CompatLevel::oldest() {
221 array.to_boxed()
222 } else {
223 cast_unchecked(array, &ArrowDataType::LargeUtf8).unwrap()
226 }
227 },
228 }
229 }
230
231 fn build_values_array_from_keys<I>(
232 &self,
233 keys_iter: I,
234 mapping: &CategoricalMapping,
235 compat_level: CompatLevel,
236 ) -> Box<dyn Array>
237 where
238 I: ExactSizeIterator<Item = CatSize>,
239 {
240 if compat_level != CompatLevel::oldest() {
241 let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8View);
242 builder.reserve(keys_iter.len());
243
244 for x in keys_iter {
245 builder.push_value_ignore_validity(mapping.cat_to_str(x).unwrap())
246 }
247
248 builder.freeze().to_boxed()
249 } else {
250 let mut builder: MutableUtf8Array<i64> = MutableUtf8Array::new();
251 builder.reserve(keys_iter.len(), 0);
252
253 for x in keys_iter {
254 builder.push(Some(mapping.cat_to_str(x).unwrap()));
255 }
256
257 let out: Utf8Array<i64> = builder.into();
258 out.boxed()
259 }
260 }
261}
262
263pub enum CategoricalKeyRemap {
264 U8(PlIndexSet<u8>),
265 U16(PlIndexSet<u16>),
266 U32(PlIndexSet<u32>),
267}
268
269impl CategoricalKeyRemap {
270 fn as_any_mut(&mut self) -> &mut dyn Any {
271 match self {
272 Self::U8(v) => v as _,
273 Self::U16(v) => v as _,
274 Self::U32(v) => v as _,
275 }
276 }
277}
278
279impl From<PlIndexSet<u8>> for CategoricalKeyRemap {
280 fn from(value: PlIndexSet<u8>) -> Self {
281 Self::U8(value)
282 }
283}
284
285impl From<PlIndexSet<u16>> for CategoricalKeyRemap {
286 fn from(value: PlIndexSet<u16>) -> Self {
287 Self::U16(value)
288 }
289}
290
291impl From<PlIndexSet<u32>> for CategoricalKeyRemap {
292 fn from(value: PlIndexSet<u32>) -> Self {
293 Self::U32(value)
294 }
295}