polars_core/series/
categorical_to_arrow.rs

1use std::any::Any;
2
3use arrow::array::builder::ArrayBuilder;
4use arrow::types::NativeType;
5use num_traits::AsPrimitive;
6use polars_compute::cast::cast_unchecked;
7
8use crate::prelude::*;
9
10/// Categorical converter that prunes unused categories.
11pub struct CategoricalToArrowConverter {
12    /// Converters keyed by the Arc address of `Arc<CategoricalMapping>`.
13    pub converters: PlIndexMap<usize, CategoricalArrayToArrowConverter>,
14    /// Persist the key remap to ensure consistent mapping across multiple calls.
15    pub persist_remap: bool,
16    /// Return only the keys array when going to arrow.
17    pub output_keys_only: bool,
18}
19
20impl CategoricalToArrowConverter {
21    /// # Panics
22    /// Panics if:
23    /// * `keys_arr` is not of a `Categorical` or `Enum` type
24    /// * The arc address of the `Arc<CategoricalMapping>` is not present within `self.converters`
25    ///   (likely due to forgetting to call `initialize()` on this converter).
26    pub fn array_to_arrow(
27        &mut self,
28        keys_arr: &dyn Array,
29        dtype: &DataType,
30        compat_level: CompatLevel,
31    ) -> Box<dyn Array> {
32        let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = dtype else {
33            unreachable!()
34        };
35
36        let key = Arc::as_ptr(mapping) as *const () as usize;
37        let converter = self.converters.get_mut(&key).unwrap();
38
39        with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| {
40            let keys_arr: &PrimitiveArray<<$C as PolarsCategoricalType>::Native> = keys_arr.as_any().downcast_ref().unwrap();
41
42            converter.array_to_arrow(
43                keys_arr,
44                dtype,
45                self.persist_remap,
46                self.output_keys_only,
47                compat_level
48            )
49        })
50    }
51
52    /// Initializes categorical converters for all categorical mappings present in this dtype.
53    pub fn initialize(&mut self, dtype: &DataType) {
54        use DataType::*;
55
56        match dtype {
57            Categorical(_categories, mapping) => {
58                let key = Arc::as_ptr(mapping) as *const () as usize;
59
60                if !self.converters.contains_key(&key) {
61                    with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| {
62                        self.converters.insert(
63                            key,
64                            CategoricalArrayToArrowConverter::Categorical {
65                                mapping: mapping.clone(),
66                                key_remap: CategoricalKeyRemap::from(
67                                    PlIndexSet::<<$C as PolarsCategoricalType>::Native>::with_capacity(
68                                        mapping.num_cats_upper_bound()
69                                    )
70                                ),
71                            },
72                        );
73                    })
74                }
75            },
76            Enum(categories, mapping) => {
77                let key = Arc::as_ptr(mapping) as *const () as usize;
78
79                if !self.converters.contains_key(&key) {
80                    self.converters.insert(
81                        key,
82                        CategoricalArrayToArrowConverter::Enum {
83                            frozen: categories.clone(),
84                            mapping: mapping.clone(),
85                        },
86                    );
87                }
88            },
89            List(inner) => self.initialize(inner),
90            #[cfg(feature = "dtype-array")]
91            Array(inner, _width) => self.initialize(inner),
92            #[cfg(feature = "dtype-struct")]
93            Struct(fields) => {
94                for field in fields {
95                    self.initialize(field.dtype())
96                }
97            },
98            _ => assert!(!dtype.is_nested()),
99        }
100    }
101}
102
103pub enum CategoricalArrayToArrowConverter {
104    Categorical {
105        mapping: Arc<CategoricalMapping>,
106        key_remap: CategoricalKeyRemap,
107    },
108    /// Enum keys are not remapped, but we still track this variant to support
109    /// the `build_values_array()` function.
110    Enum {
111        mapping: Arc<CategoricalMapping>,
112        frozen: Arc<FrozenCategories>,
113    },
114}
115
116impl CategoricalArrayToArrowConverter {
117    fn array_to_arrow<T>(
118        &mut self,
119        keys_arr: &PrimitiveArray<T>,
120        dtype: &DataType,
121        persist_remap: bool,
122        output_keys_only: bool,
123        compat_level: CompatLevel,
124    ) -> Box<dyn Array>
125    where
126        T: DictionaryKey + NativeType + std::hash::Hash + Eq,
127        usize: AsPrimitive<T>,
128    {
129        let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = dtype else {
130            unreachable!()
131        };
132
133        let input_mapping_ptr: *const CategoricalMapping = Arc::as_ptr(mapping);
134
135        let keys_arr: PrimitiveArray<T> = match self {
136            Self::Categorical { mapping, key_remap } => {
137                assert_eq!(input_mapping_ptr, Arc::as_ptr(mapping));
138
139                let key_remap: &mut PlIndexSet<T> = key_remap.as_any_mut().downcast_mut().unwrap();
140
141                if !persist_remap {
142                    key_remap.clear()
143                }
144
145                keys_arr
146                    .iter()
147                    .map(|x| {
148                        x.map(|x: &T| {
149                            let idx: usize = key_remap.insert_full(*x).0;
150                            // Indexset of T cannot return an index exceeding T::MAX.
151                            let out: T = idx.as_();
152                            out
153                        })
154                    })
155                    .collect()
156            },
157            Self::Enum { mapping, .. } => {
158                assert_eq!(input_mapping_ptr, Arc::as_ptr(mapping));
159                keys_arr.clone()
160            },
161        };
162
163        if output_keys_only {
164            return keys_arr.boxed();
165        }
166
167        let values = self.build_values_array(compat_level);
168
169        let dictionary_dtype = ArrowDataType::Dictionary(
170            <T as DictionaryKey>::KEY_TYPE,
171            Box::new(values.dtype().clone()),
172            false, // is_sorted
173        );
174
175        unsafe {
176            DictionaryArray::<T>::try_new_unchecked(dictionary_dtype, keys_arr, values)
177                .unwrap()
178                .boxed()
179        }
180    }
181
182    /// Build the values array of the dictionary:
183    /// * If `Self` is `::Categorical`, this builds according to the current `key_remap` state:
184    ///   * If `persist_remap` is `true`, this state will hold all the keys this converter has encountered.
185    ///     It will otherwise hold only the keys seen from the last `array_to_arrow()` call.
186    /// * If `Self` is `::Enum`, this returns the full set of values present in the Enum's `FrozenCategories`.
187    pub fn build_values_array(&self, compat_level: CompatLevel) -> Box<dyn Array> {
188        match self {
189            Self::Categorical { mapping, key_remap } => match key_remap {
190                CategoricalKeyRemap::U8(keys) => self.build_values_array_from_keys(
191                    keys.iter().map(|x: &u8| *x as CatSize),
192                    mapping,
193                    compat_level,
194                ),
195                CategoricalKeyRemap::U16(keys) => self.build_values_array_from_keys(
196                    keys.iter().map(|x: &u16| *x as CatSize),
197                    mapping,
198                    compat_level,
199                ),
200                CategoricalKeyRemap::U32(keys) => self.build_values_array_from_keys(
201                    keys.iter().map(|x: &u32| *x as CatSize),
202                    mapping,
203                    compat_level,
204                ),
205            },
206
207            Self::Enum { frozen, .. } => {
208                let array: &Utf8ViewArray = frozen.categories();
209
210                if compat_level != CompatLevel::oldest() {
211                    array.to_boxed()
212                } else {
213                    // Note: Could store a once-init Utf8Array on the frozen categories to avoid
214                    // building this multiple times for the oldest compat level.
215                    cast_unchecked(array, &ArrowDataType::LargeUtf8).unwrap()
216                }
217            },
218        }
219    }
220
221    fn build_values_array_from_keys<I>(
222        &self,
223        keys_iter: I,
224        mapping: &CategoricalMapping,
225        compat_level: CompatLevel,
226    ) -> Box<dyn Array>
227    where
228        I: ExactSizeIterator<Item = CatSize>,
229    {
230        if compat_level != CompatLevel::oldest() {
231            let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8View);
232            builder.reserve(keys_iter.len());
233
234            for x in keys_iter {
235                builder.push_value_ignore_validity(mapping.cat_to_str(x).unwrap())
236            }
237
238            builder.freeze().to_boxed()
239        } else {
240            let mut builder: MutableUtf8Array<i64> = MutableUtf8Array::new();
241            builder.reserve(keys_iter.len(), 0);
242
243            for x in keys_iter {
244                builder.push(Some(mapping.cat_to_str(x).unwrap()));
245            }
246
247            let out: Utf8Array<i64> = builder.into();
248            out.boxed()
249        }
250    }
251}
252
253pub enum CategoricalKeyRemap {
254    U8(PlIndexSet<u8>),
255    U16(PlIndexSet<u16>),
256    U32(PlIndexSet<u32>),
257}
258
259impl CategoricalKeyRemap {
260    fn as_any_mut(&mut self) -> &mut dyn Any {
261        match self {
262            Self::U8(v) => v as _,
263            Self::U16(v) => v as _,
264            Self::U32(v) => v as _,
265        }
266    }
267}
268
269impl From<PlIndexSet<u8>> for CategoricalKeyRemap {
270    fn from(value: PlIndexSet<u8>) -> Self {
271        Self::U8(value)
272    }
273}
274
275impl From<PlIndexSet<u16>> for CategoricalKeyRemap {
276    fn from(value: PlIndexSet<u16>) -> Self {
277        Self::U16(value)
278    }
279}
280
281impl From<PlIndexSet<u32>> for CategoricalKeyRemap {
282    fn from(value: PlIndexSet<u32>) -> Self {
283        Self::U32(value)
284    }
285}