1use std::any::Any;
2
3use arrow::array::builder::ArrayBuilder;
4use arrow::types::NativeType;
5use num_traits::AsPrimitive;
6use polars_compute::cast::cast_unchecked;
7
8use crate::prelude::*;
9
/// Holds the per-mapping state needed to convert Categorical / Enum key arrays to arrow.
pub struct CategoricalToArrowConverter {
    /// Converters keyed by the address of the `Arc<CategoricalMapping>` they were created for
    /// (the key is computed with `Arc::as_ptr(mapping) as *const () as usize`).
    ///
    /// Every stored converter keeps its own clone of that `Arc`.
    pub converters: PlIndexMap<usize, CategoricalArrayToArrowConverter>,
    /// Forwarded to the per-mapping converter on every conversion. When `false`, a Categorical
    /// converter clears its key remap at the start of each call; when `true`, the remap
    /// accumulates across calls.
    pub persist_remap: bool,
    /// Forwarded to the per-mapping converter on every conversion. When `true`, the converter
    /// returns just the (possibly remapped) keys array instead of a dictionary array.
    pub output_keys_only: bool,
}
19
20impl CategoricalToArrowConverter {
21 pub fn array_to_arrow(
27 &mut self,
28 keys_arr: &dyn Array,
29 dtype: &DataType,
30 compat_level: CompatLevel,
31 ) -> Box<dyn Array> {
32 let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = dtype else {
33 unreachable!()
34 };
35
36 let key = Arc::as_ptr(mapping) as *const () as usize;
37 let converter = self.converters.get_mut(&key).unwrap();
38
39 with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| {
40 let keys_arr: &PrimitiveArray<<$C as PolarsCategoricalType>::Native> = keys_arr.as_any().downcast_ref().unwrap();
41
42 converter.array_to_arrow(
43 keys_arr,
44 dtype,
45 self.persist_remap,
46 self.output_keys_only,
47 compat_level
48 )
49 })
50 }
51
52 pub fn initialize(&mut self, dtype: &DataType) {
54 use DataType::*;
55
56 match dtype {
57 Categorical(_categories, mapping) => {
58 let key = Arc::as_ptr(mapping) as *const () as usize;
59
60 if !self.converters.contains_key(&key) {
61 with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| {
62 self.converters.insert(
63 key,
64 CategoricalArrayToArrowConverter::Categorical {
65 mapping: mapping.clone(),
66 key_remap: CategoricalKeyRemap::from(
67 PlIndexSet::<<$C as PolarsCategoricalType>::Native>::with_capacity(
68 mapping.num_cats_upper_bound()
69 )
70 ),
71 },
72 );
73 })
74 }
75 },
76 Enum(categories, mapping) => {
77 let key = Arc::as_ptr(mapping) as *const () as usize;
78
79 if !self.converters.contains_key(&key) {
80 self.converters.insert(
81 key,
82 CategoricalArrayToArrowConverter::Enum {
83 frozen: categories.clone(),
84 mapping: mapping.clone(),
85 },
86 );
87 }
88 },
89 List(inner) => self.initialize(inner),
90 #[cfg(feature = "dtype-array")]
91 Array(inner, _width) => self.initialize(inner),
92 #[cfg(feature = "dtype-struct")]
93 Struct(fields) => {
94 for field in fields {
95 self.initialize(field.dtype())
96 }
97 },
98 _ => assert!(!dtype.is_nested()),
99 }
100 }
101}
102
/// Conversion state for a single `CategoricalMapping`.
pub enum CategoricalArrayToArrowConverter {
    /// State for a `Categorical` dtype: input keys are remapped to their insertion index in
    /// `key_remap`, and the dictionary values are built from the keys stored in `key_remap`.
    Categorical {
        /// The mapping this converter was created for; used to resolve keys to strings and to
        /// check that incoming dtypes refer to the same mapping.
        mapping: Arc<CategoricalMapping>,
        /// Set of physical keys seen so far; the position of a key in the set is its output key.
        key_remap: CategoricalKeyRemap,
    },
    /// State for an `Enum` dtype: keys are passed through unchanged and the dictionary values
    /// are taken from `frozen`.
    Enum {
        /// The mapping this converter was created for; only used to check that incoming dtypes
        /// refer to the same mapping.
        mapping: Arc<CategoricalMapping>,
        /// The categories that become the dictionary values.
        frozen: Arc<FrozenCategories>,
    },
}
115
116impl CategoricalArrayToArrowConverter {
117 fn array_to_arrow<T>(
118 &mut self,
119 keys_arr: &PrimitiveArray<T>,
120 dtype: &DataType,
121 persist_remap: bool,
122 output_keys_only: bool,
123 compat_level: CompatLevel,
124 ) -> Box<dyn Array>
125 where
126 T: DictionaryKey + NativeType + std::hash::Hash + Eq,
127 usize: AsPrimitive<T>,
128 {
129 let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = dtype else {
130 unreachable!()
131 };
132
133 let input_mapping_ptr: *const CategoricalMapping = Arc::as_ptr(mapping);
134
135 let keys_arr: PrimitiveArray<T> = match self {
136 Self::Categorical { mapping, key_remap } => {
137 assert_eq!(input_mapping_ptr, Arc::as_ptr(mapping));
138
139 let key_remap: &mut PlIndexSet<T> = key_remap.as_any_mut().downcast_mut().unwrap();
140
141 if !persist_remap {
142 key_remap.clear()
143 }
144
145 keys_arr
146 .iter()
147 .map(|x| {
148 x.map(|x: &T| {
149 let idx: usize = key_remap.insert_full(*x).0;
150 let out: T = idx.as_();
152 out
153 })
154 })
155 .collect()
156 },
157 Self::Enum { mapping, .. } => {
158 assert_eq!(input_mapping_ptr, Arc::as_ptr(mapping));
159 keys_arr.clone()
160 },
161 };
162
163 if output_keys_only {
164 return keys_arr.boxed();
165 }
166
167 let values = self.build_values_array(compat_level);
168
169 let dictionary_dtype = ArrowDataType::Dictionary(
170 <T as DictionaryKey>::KEY_TYPE,
171 Box::new(values.dtype().clone()),
172 false, );
174
175 unsafe {
176 DictionaryArray::<T>::try_new_unchecked(dictionary_dtype, keys_arr, values)
177 .unwrap()
178 .boxed()
179 }
180 }
181
182 pub fn build_values_array(&self, compat_level: CompatLevel) -> Box<dyn Array> {
188 match self {
189 Self::Categorical { mapping, key_remap } => match key_remap {
190 CategoricalKeyRemap::U8(keys) => self.build_values_array_from_keys(
191 keys.iter().map(|x: &u8| *x as CatSize),
192 mapping,
193 compat_level,
194 ),
195 CategoricalKeyRemap::U16(keys) => self.build_values_array_from_keys(
196 keys.iter().map(|x: &u16| *x as CatSize),
197 mapping,
198 compat_level,
199 ),
200 CategoricalKeyRemap::U32(keys) => self.build_values_array_from_keys(
201 keys.iter().map(|x: &u32| *x as CatSize),
202 mapping,
203 compat_level,
204 ),
205 },
206
207 Self::Enum { frozen, .. } => {
208 let array: &Utf8ViewArray = frozen.categories();
209
210 if compat_level != CompatLevel::oldest() {
211 array.to_boxed()
212 } else {
213 cast_unchecked(array, &ArrowDataType::LargeUtf8).unwrap()
216 }
217 },
218 }
219 }
220
221 fn build_values_array_from_keys<I>(
222 &self,
223 keys_iter: I,
224 mapping: &CategoricalMapping,
225 compat_level: CompatLevel,
226 ) -> Box<dyn Array>
227 where
228 I: ExactSizeIterator<Item = CatSize>,
229 {
230 if compat_level != CompatLevel::oldest() {
231 let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8View);
232 builder.reserve(keys_iter.len());
233
234 for x in keys_iter {
235 builder.push_value_ignore_validity(mapping.cat_to_str(x).unwrap())
236 }
237
238 builder.freeze().to_boxed()
239 } else {
240 let mut builder: MutableUtf8Array<i64> = MutableUtf8Array::new();
241 builder.reserve(keys_iter.len(), 0);
242
243 for x in keys_iter {
244 builder.push(Some(mapping.cat_to_str(x).unwrap()));
245 }
246
247 let out: Utf8Array<i64> = builder.into();
248 out.boxed()
249 }
250 }
251}
252
/// Insertion-ordered set of physical category keys, with one variant per physical key width.
///
/// The converter uses the position of a key in the set as the key written to the output array.
pub enum CategoricalKeyRemap {
    /// Keys stored as `u8`.
    U8(PlIndexSet<u8>),
    /// Keys stored as `u16`.
    U16(PlIndexSet<u16>),
    /// Keys stored as `u32`.
    U32(PlIndexSet<u32>),
}
258
259impl CategoricalKeyRemap {
260 fn as_any_mut(&mut self) -> &mut dyn Any {
261 match self {
262 Self::U8(v) => v as _,
263 Self::U16(v) => v as _,
264 Self::U32(v) => v as _,
265 }
266 }
267}
268
269impl From<PlIndexSet<u8>> for CategoricalKeyRemap {
270 fn from(value: PlIndexSet<u8>) -> Self {
271 Self::U8(value)
272 }
273}
274
275impl From<PlIndexSet<u16>> for CategoricalKeyRemap {
276 fn from(value: PlIndexSet<u16>) -> Self {
277 Self::U16(value)
278 }
279}
280
281impl From<PlIndexSet<u32>> for CategoricalKeyRemap {
282 fn from(value: PlIndexSet<u32>) -> Self {
283 Self::U32(value)
284 }
285}