polars_core/chunked_array/logical/
categorical.rs1use std::marker::PhantomData;
2
3use arrow::bitmap::BitmapBuilder;
4use num_traits::Zero;
5
6use crate::chunked_array::cast::CastOptions;
7use crate::chunked_array::flags::StatisticsFlags;
8use crate::chunked_array::ops::ChunkFullNull;
9use crate::prelude::*;
10use crate::series::IsSorted;
11use crate::utils::handle_casting_failures;
12
/// A categorical or enum column: a [`Logical`] wrapper whose physical storage is
/// a chunked array of category ids of type `T::PolarsPhysical`.
pub type CategoricalChunked<T> = Logical<T, <T as PolarsCategoricalType>::PolarsPhysical>;
/// [`CategoricalChunked`] parameterized by [`Categorical8Type`].
pub type Categorical8Chunked = CategoricalChunked<Categorical8Type>;
/// [`CategoricalChunked`] parameterized by [`Categorical16Type`].
pub type Categorical16Chunked = CategoricalChunked<Categorical16Type>;
/// [`CategoricalChunked`] parameterized by [`Categorical32Type`].
pub type Categorical32Chunked = CategoricalChunked<Categorical32Type>;
17
/// Extension trait that maps a categorical physical representation to the
/// [`DataType`] used to store its category ids.
pub trait CategoricalPhysicalDtypeExt {
    /// Returns the physical [`DataType`] of the category ids.
    fn dtype(&self) -> DataType;
}
21
22impl CategoricalPhysicalDtypeExt for CategoricalPhysical {
23 fn dtype(&self) -> DataType {
24 match self {
25 Self::U8 => DataType::UInt8,
26 Self::U16 => DataType::UInt16,
27 Self::U32 => DataType::UInt32,
28 }
29 }
30}
31
32impl<T: PolarsCategoricalType> CategoricalChunked<T> {
33 pub fn is_enum(&self) -> bool {
34 matches!(self.dtype(), DataType::Enum(_, _))
35 }
36
37 pub(crate) fn get_flags(&self) -> StatisticsFlags {
38 let mut flags = self.phys.get_flags();
41 if self.uses_lexical_ordering() {
42 flags.set_sorted(IsSorted::Not);
43 }
44 flags
45 }
46
47 pub(crate) fn set_flags(&mut self, mut flags: StatisticsFlags) {
49 if self.uses_lexical_ordering() {
51 flags.set_sorted(IsSorted::Not)
52 }
53 self.physical_mut().set_flags(flags)
54 }
55
56 pub fn uses_lexical_ordering(&self) -> bool {
59 !self.is_enum()
60 }
61
62 pub fn full_null_with_dtype(name: PlSmallStr, length: usize, dtype: DataType) -> Self {
63 let phys =
64 ChunkedArray::<<T as PolarsCategoricalType>::PolarsPhysical>::full_null(name, length);
65 unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) }
66 }
67
68 pub fn from_cats_and_dtype(
72 mut cat_ids: ChunkedArray<T::PolarsPhysical>,
73 dtype: DataType,
74 ) -> Self {
75 let (DataType::Enum(_, mapping) | DataType::Categorical(_, mapping)) = &dtype else {
76 panic!("from_cats_and_dtype called on non-categorical type")
77 };
78 assert!(dtype.cat_physical().ok() == Some(T::physical()));
79
80 unsafe {
81 let mut invariants_violated = false;
82 let mut validity = BitmapBuilder::new();
83 for arr in cat_ids.downcast_iter_mut() {
84 validity.reserve(arr.len());
85 if arr.has_nulls() {
86 for opt_cat_id in arr.iter() {
87 if let Some(cat_id) = opt_cat_id {
88 validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
89 } else {
90 validity.push_unchecked(false);
91 }
92 }
93 } else {
94 for cat_id in arr.values_iter() {
95 validity.push_unchecked(mapping.cat_to_str(cat_id.as_cat()).is_some());
96 }
97 }
98
99 if arr.null_count() != validity.unset_bits() {
100 invariants_violated = true;
101 arr.set_validity(core::mem::take(&mut validity).into_opt_validity());
102 } else {
103 validity.clear();
104 }
105 }
106
107 if invariants_violated {
108 cat_ids.set_flags(StatisticsFlags::empty());
109 cat_ids.compute_len();
110 }
111 }
112
113 Self {
114 phys: cat_ids,
115 dtype,
116 _phantom: PhantomData,
117 }
118 }
119
120 pub unsafe fn from_cats_and_dtype_unchecked(
125 cat_ids: ChunkedArray<T::PolarsPhysical>,
126 dtype: DataType,
127 ) -> Self {
128 debug_assert!(dtype.cat_physical().ok() == Some(T::physical()));
129
130 Self {
131 phys: cat_ids,
132 dtype,
133 _phantom: PhantomData,
134 }
135 }
136
137 pub fn get_mapping(&self) -> &Arc<CategoricalMapping> {
139 let (DataType::Categorical(_, mapping) | DataType::Enum(_, mapping)) = self.dtype() else {
140 unreachable!()
141 };
142 mapping
143 }
144
145 pub fn iter_str(&self) -> impl PolarsIterator<Item = Option<&str>> {
147 let mapping = self.get_mapping();
148 self.phys
149 .iter()
150 .map(|cat| unsafe { Some(mapping.cat_to_str_unchecked(cat?.as_cat())) })
151 }
152
153 pub fn from_str_iter<'a, I: IntoIterator<Item = Option<&'a str>>>(
157 name: PlSmallStr,
158 dtype: DataType,
159 strings: I,
160 ) -> PolarsResult<Self> {
161 let strings = strings.into_iter();
162
163 let hint = strings.size_hint().0;
164 let mut cat_ids = Vec::with_capacity(hint);
165 let mut validity = BitmapBuilder::with_capacity(hint);
166
167 match &dtype {
168 DataType::Categorical(cats, mapping) => {
169 assert!(cats.physical() == T::physical());
170 for opt_s in strings {
171 cat_ids.push(if let Some(s) = opt_s {
172 T::Native::from_cat(mapping.insert_cat(s)?)
173 } else {
174 T::Native::zero()
175 });
176 validity.push(opt_s.is_some());
177 }
178 },
179 DataType::Enum(fcats, mapping) => {
180 assert!(fcats.physical() == T::physical());
181 for opt_s in strings {
182 cat_ids.push(if let Some(cat) = opt_s.and_then(|s| mapping.get_cat(s)) {
183 validity.push(true);
184 T::Native::from_cat(cat)
185 } else {
186 validity.push(false);
187 T::Native::zero()
188 });
189 }
190 },
191 _ => panic!("from_strings_and_dtype_strict called on non-categorical type"),
192 }
193
194 let arr = <T::PolarsPhysical as PolarsDataType>::Array::from_vec(cat_ids)
195 .with_validity(validity.into_opt_validity());
196 let phys = ChunkedArray::<T::PolarsPhysical>::with_chunk(name, arr);
197 Ok(unsafe { Self::from_cats_and_dtype_unchecked(phys, dtype) })
198 }
199
200 pub fn to_arrow(&self, compat_level: CompatLevel) -> DictionaryArray<T::Native> {
201 let keys = self.physical().rechunk();
202 let keys = keys.downcast_as_array();
203 let values = self
204 .get_mapping()
205 .to_arrow(compat_level != CompatLevel::oldest());
206 let values_dtype = Box::new(values.dtype().clone());
207 let dtype =
208 ArrowDataType::Dictionary(<T::Native as DictionaryKey>::KEY_TYPE, values_dtype, false);
209 unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() }
210 }
211}
212
213impl<T: PolarsCategoricalType> LogicalType for CategoricalChunked<T> {
214 fn dtype(&self) -> &DataType {
215 &self.dtype
216 }
217
218 fn get_any_value(&self, i: usize) -> PolarsResult<AnyValue<'_>> {
219 polars_ensure!(i < self.len(), oob = i, self.len());
220 Ok(unsafe { self.get_any_value_unchecked(i) })
221 }
222
223 unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
224 match self.phys.get_unchecked(i) {
225 Some(i) => match &self.dtype {
226 DataType::Enum(_, mapping) => AnyValue::Enum(i.as_cat(), mapping),
227 DataType::Categorical(_, mapping) => AnyValue::Categorical(i.as_cat(), mapping),
228 _ => unreachable!(),
229 },
230 None => AnyValue::Null,
231 }
232 }
233
234 fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
235 if &self.dtype == dtype {
236 return Ok(self.clone().into_series());
237 }
238
239 match dtype {
240 DataType::String => {
241 let mapping = self.get_mapping();
242
243 let mut builder = StringChunkedBuilder::new(self.phys.name().clone(), self.len());
246 let to_str = |cat_id: CatSize| unsafe { mapping.cat_to_str_unchecked(cat_id) };
247 if !self.phys.has_nulls() {
248 for cat_id in self.phys.into_no_null_iter() {
249 builder.append_value(to_str(cat_id.as_cat()));
250 }
251 } else {
252 for opt_cat_id in self.phys.into_iter() {
253 let opt_cat_id: Option<_> = opt_cat_id;
254 builder.append_option(opt_cat_id.map(|c| to_str(c.as_cat())));
255 }
256 }
257
258 let ca = builder.finish();
259 Ok(ca.into_series())
260 },
261
262 DataType::Enum(fcats, _mapping) => {
263 let ret = with_match_categorical_physical_type!(fcats.physical(), |$C| {
265 CategoricalChunked::<$C>::from_str_iter(
266 self.name().clone(),
267 dtype.clone(),
268 self.iter_str()
269 )?.into_series()
270 });
271
272 if options.is_strict() && self.null_count() != ret.null_count() {
273 handle_casting_failures(&self.clone().into_series(), &ret)?;
274 }
275
276 Ok(ret)
277 },
278
279 DataType::Categorical(cats, _mapping) => {
280 Ok(
282 with_match_categorical_physical_type!(cats.physical(), |$C| {
283 CategoricalChunked::<$C>::from_str_iter(
284 self.name().clone(),
285 dtype.clone(),
286 self.iter_str()
287 )?.into_series()
288 }),
289 )
290 },
291
292 dt if dt.is_integer() => self.phys.clone().cast_with_options(dtype, options),
295
296 _ => polars_bail!(ComputeError: "cannot cast categorical types to {dtype:?}"),
297 }
298 }
299}