// polars_core/series/categorical_to_arrow.rs

1use std::any::Any;
2
3use arrow::array::builder::ArrayBuilder;
4use arrow::types::NativeType;
5use num_traits::AsPrimitive;
6use polars_compute::cast::cast_unchecked;
7
8use crate::prelude::*;
9
10/// Categorical converter that prunes unused categories.
11pub struct CategoricalToArrowConverter {
12    /// Converters keyed by the Arc address of `Arc<CategoricalMapping>`.
13    ///
14    /// # Safety
15    /// The `usize` key remains valid as `CategoricalArrayToArrowConverter` holds a ref-count to the
16    /// `Arc<>` that the key is derived from.
17    pub converters: PlIndexMap<usize, CategoricalArrayToArrowConverter>,
18    /// Persist the key remap to ensure consistent mapping across multiple calls.
19    pub persist_remap: bool,
20    /// Return only the keys array when going to arrow.
21    pub output_keys_only: bool,
22}
23
24impl CategoricalToArrowConverter {
25    /// # Panics
26    /// Panics if:
27    /// * `keys_arr` is not of a `Categorical` or `Enum` type
28    /// * The arc address of the `Arc<CategoricalMapping>` is not present within `self.converters`
29    ///   (likely due to forgetting to call `initialize()` on this converter).
30    pub fn array_to_arrow(
31        &mut self,
32        keys_arr: &dyn Array,
33        dtype: &DataType,
34        compat_level: CompatLevel,
35    ) -> Box<dyn Array> {
36        let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = dtype else {
37            unreachable!()
38        };
39
40        let key = Arc::as_ptr(mapping) as *const () as usize;
41        let converter = self.converters.get_mut(&key).unwrap();
42
43        with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| {
44            let keys_arr: &PrimitiveArray<<$C as PolarsCategoricalType>::Native> = keys_arr.as_any().downcast_ref().unwrap();
45
46            converter.array_to_arrow(
47                keys_arr,
48                dtype,
49                self.persist_remap,
50                self.output_keys_only,
51                compat_level
52            )
53        })
54    }
55
56    /// Initializes categorical converters for all categorical mappings present in this dtype.
57    pub fn initialize(&mut self, dtype: &DataType) {
58        use DataType::*;
59
60        match dtype {
61            Categorical(_categories, mapping) => {
62                let key = Arc::as_ptr(mapping) as *const () as usize;
63
64                if !self.converters.contains_key(&key) {
65                    with_match_categorical_physical_type!(dtype.cat_physical().unwrap(), |$C| {
66                        self.converters.insert(
67                            key,
68                            CategoricalArrayToArrowConverter::Categorical {
69                                mapping: mapping.clone(),
70                                key_remap: CategoricalKeyRemap::from(
71                                    PlIndexSet::<<$C as PolarsCategoricalType>::Native>::with_capacity(
72                                        mapping.num_cats_upper_bound()
73                                    )
74                                ),
75                            },
76                        );
77                    })
78                }
79            },
80            Enum(categories, mapping) => {
81                let key = Arc::as_ptr(mapping) as *const () as usize;
82
83                if !self.converters.contains_key(&key) {
84                    self.converters.insert(
85                        key,
86                        CategoricalArrayToArrowConverter::Enum {
87                            frozen: categories.clone(),
88                            mapping: mapping.clone(),
89                        },
90                    );
91                }
92            },
93            List(inner) => self.initialize(inner),
94            #[cfg(feature = "dtype-array")]
95            Array(inner, _width) => self.initialize(inner),
96            #[cfg(feature = "dtype-struct")]
97            Struct(fields) => {
98                for field in fields {
99                    self.initialize(field.dtype())
100                }
101            },
102            #[cfg(feature = "dtype-extension")]
103            Extension(_, inner) => self.initialize(inner),
104            _ => assert!(!dtype.is_nested()),
105        }
106    }
107}
108
109pub enum CategoricalArrayToArrowConverter {
110    // # Safety
111    // All enum variants must hold a ref to the `Arc<CategoricalMapping>`, as this enum is inserted
112    // into a hashmap keyed by the address of the `Arc`ed value.
113    //
114    Categorical {
115        mapping: Arc<CategoricalMapping>,
116        key_remap: CategoricalKeyRemap,
117    },
118    /// Enum keys are not remapped, but we still track this variant to support
119    /// the `build_values_array()` function.
120    Enum {
121        mapping: Arc<CategoricalMapping>,
122        frozen: Arc<FrozenCategories>,
123    },
124}
125
126impl CategoricalArrayToArrowConverter {
127    fn array_to_arrow<T>(
128        &mut self,
129        keys_arr: &PrimitiveArray<T>,
130        dtype: &DataType,
131        persist_remap: bool,
132        output_keys_only: bool,
133        compat_level: CompatLevel,
134    ) -> Box<dyn Array>
135    where
136        T: DictionaryKey + NativeType + std::hash::Hash + Eq,
137        usize: AsPrimitive<T>,
138    {
139        let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = dtype else {
140            unreachable!()
141        };
142
143        let input_mapping_ptr: *const CategoricalMapping = Arc::as_ptr(mapping);
144
145        let keys_arr: PrimitiveArray<T> = match self {
146            Self::Categorical { mapping, key_remap } => {
147                assert_eq!(input_mapping_ptr, Arc::as_ptr(mapping));
148
149                let key_remap: &mut PlIndexSet<T> = key_remap.as_any_mut().downcast_mut().unwrap();
150
151                if !persist_remap {
152                    key_remap.clear()
153                }
154
155                keys_arr
156                    .iter()
157                    .map(|x| {
158                        x.map(|x: &T| {
159                            let idx: usize = key_remap.insert_full(*x).0;
160                            // Indexset of T cannot return an index exceeding T::MAX.
161                            let out: T = idx.as_();
162                            out
163                        })
164                    })
165                    .collect()
166            },
167            Self::Enum { mapping, .. } => {
168                assert_eq!(input_mapping_ptr, Arc::as_ptr(mapping));
169                keys_arr.clone()
170            },
171        };
172
173        if output_keys_only {
174            return keys_arr.boxed();
175        }
176
177        let values = self.build_values_array(compat_level);
178
179        let dictionary_dtype = ArrowDataType::Dictionary(
180            <T as DictionaryKey>::KEY_TYPE,
181            Box::new(values.dtype().clone()),
182            false, // is_sorted
183        );
184
185        unsafe {
186            DictionaryArray::<T>::try_new_unchecked(dictionary_dtype, keys_arr, values)
187                .unwrap()
188                .boxed()
189        }
190    }
191
192    /// Build the values array of the dictionary:
193    /// * If `Self` is `::Categorical`, this builds according to the current `key_remap` state:
194    ///   * If `persist_remap` is `true`, this state will hold all the keys this converter has encountered.
195    ///     It will otherwise hold only the keys seen from the last `array_to_arrow()` call.
196    /// * If `Self` is `::Enum`, this returns the full set of values present in the Enum's `FrozenCategories`.
197    pub fn build_values_array(&self, compat_level: CompatLevel) -> Box<dyn Array> {
198        match self {
199            Self::Categorical { mapping, key_remap } => match key_remap {
200                CategoricalKeyRemap::U8(keys) => self.build_values_array_from_keys(
201                    keys.iter().map(|x: &u8| *x as CatSize),
202                    mapping,
203                    compat_level,
204                ),
205                CategoricalKeyRemap::U16(keys) => self.build_values_array_from_keys(
206                    keys.iter().map(|x: &u16| *x as CatSize),
207                    mapping,
208                    compat_level,
209                ),
210                CategoricalKeyRemap::U32(keys) => self.build_values_array_from_keys(
211                    keys.iter().map(|x: &u32| *x as CatSize),
212                    mapping,
213                    compat_level,
214                ),
215            },
216
217            Self::Enum { frozen, .. } => {
218                let array: &Utf8ViewArray = frozen.categories();
219
220                if compat_level != CompatLevel::oldest() {
221                    array.to_boxed()
222                } else {
223                    // Note: Could store a once-init Utf8Array on the frozen categories to avoid
224                    // building this multiple times for the oldest compat level.
225                    cast_unchecked(array, &ArrowDataType::LargeUtf8).unwrap()
226                }
227            },
228        }
229    }
230
231    fn build_values_array_from_keys<I>(
232        &self,
233        keys_iter: I,
234        mapping: &CategoricalMapping,
235        compat_level: CompatLevel,
236    ) -> Box<dyn Array>
237    where
238        I: ExactSizeIterator<Item = CatSize>,
239    {
240        if compat_level != CompatLevel::oldest() {
241            let mut builder = Utf8ViewArrayBuilder::new(ArrowDataType::Utf8View);
242            builder.reserve(keys_iter.len());
243
244            for x in keys_iter {
245                builder.push_value_ignore_validity(mapping.cat_to_str(x).unwrap())
246            }
247
248            builder.freeze().to_boxed()
249        } else {
250            let mut builder: MutableUtf8Array<i64> = MutableUtf8Array::new();
251            builder.reserve(keys_iter.len(), 0);
252
253            for x in keys_iter {
254                builder.push(Some(mapping.cat_to_str(x).unwrap()));
255            }
256
257            let out: Utf8Array<i64> = builder.into();
258            out.boxed()
259        }
260    }
261}
262
263pub enum CategoricalKeyRemap {
264    U8(PlIndexSet<u8>),
265    U16(PlIndexSet<u16>),
266    U32(PlIndexSet<u32>),
267}
268
269impl CategoricalKeyRemap {
270    fn as_any_mut(&mut self) -> &mut dyn Any {
271        match self {
272            Self::U8(v) => v as _,
273            Self::U16(v) => v as _,
274            Self::U32(v) => v as _,
275        }
276    }
277}
278
279impl From<PlIndexSet<u8>> for CategoricalKeyRemap {
280    fn from(value: PlIndexSet<u8>) -> Self {
281        Self::U8(value)
282    }
283}
284
285impl From<PlIndexSet<u16>> for CategoricalKeyRemap {
286    fn from(value: PlIndexSet<u16>) -> Self {
287        Self::U16(value)
288    }
289}
290
291impl From<PlIndexSet<u32>> for CategoricalKeyRemap {
292    fn from(value: PlIndexSet<u32>) -> Self {
293        Self::U32(value)
294    }
295}