Skip to main content

// polars_core/series/implementations/categorical.rs

1use super::*;
2use crate::chunked_array::comparison::*;
3use crate::prelude::*;
4
5unsafe impl<T: PolarsCategoricalType> IntoSeries for CategoricalChunked<T> {
6    fn into_series(self) -> Series {
7        // We do this hack to go from generic T to concrete T to avoid adding bounds on IntoSeries.
8        with_match_categorical_physical_type!(T::physical(), |$C| {
9            unsafe {
10                Series(Arc::new(SeriesWrap(core::mem::transmute::<Self, CategoricalChunked<$C>>(self))))
11            }
12        })
13    }
14}
15
16impl<T: PolarsCategoricalType> SeriesWrap<CategoricalChunked<T>> {
17    unsafe fn apply_on_phys<F>(&self, apply: F) -> CategoricalChunked<T>
18    where
19        F: FnOnce(&ChunkedArray<T::PolarsPhysical>) -> ChunkedArray<T::PolarsPhysical>,
20    {
21        let cats = apply(self.0.physical());
22        unsafe { CategoricalChunked::from_cats_and_dtype_unchecked(cats, self.0.dtype().clone()) }
23    }
24
25    unsafe fn try_apply_on_phys<F>(&self, apply: F) -> PolarsResult<CategoricalChunked<T>>
26    where
27        F: FnOnce(
28            &ChunkedArray<T::PolarsPhysical>,
29        ) -> PolarsResult<ChunkedArray<T::PolarsPhysical>>,
30    {
31        let cats = apply(self.0.physical())?;
32        unsafe {
33            Ok(CategoricalChunked::from_cats_and_dtype_unchecked(
34                cats,
35                self.0.dtype().clone(),
36            ))
37        }
38    }
39}
40
// Implements `PrivateSeries`, `SeriesTrait` and `PrivateSeriesNumeric` for `SeriesWrap<$ca>`.
//
// * `$ca`    - concrete categorical chunked type held by the wrapper.
// * `$pdt`   - categorical type parameter handed to `Series::cat::<$pdt>()` when downcasting.
// * `$ca_fn` - `Series` accessor used to downcast aggregated physical values.
//
// Most methods either forward to the physical array (`self.0.physical()`) or transform the
// physical array and re-wrap it with the unchanged dtype through `apply_on_phys` /
// `try_apply_on_phys`.
macro_rules! impl_cat_series {
    ($ca: ident, $pdt:ty, $ca_fn:ident) => {
        impl private::PrivateSeries for SeriesWrap<$ca> {
            // Length is recomputed on the physical array.
            fn compute_len(&mut self) {
                self.0.physical_mut().compute_len();
            }

            fn _field(&self) -> Cow<'_, Field> {
                let field = self.0.field();
                Cow::Owned(field)
            }

            fn _dtype(&self) -> &DataType {
                self.0.dtype()
            }

            fn _get_flags(&self) -> StatisticsFlags {
                self.0.get_flags()
            }

            fn _set_flags(&mut self, flags: StatisticsFlags) {
                self.0.set_flags(flags);
            }

            // Element equality is decided on the physical values.
            unsafe fn equal_element(
                &self,
                idx_self: usize,
                idx_other: usize,
                other: &Series,
            ) -> bool {
                self.0.physical().equal_element(idx_self, idx_other, other)
            }

            // Zips the physical representations of both sides after checking the dtypes match.
            #[cfg(feature = "zip_with")]
            fn zip_with_same_type(
                &self,
                mask: &BooleanChunked,
                other: &Series,
            ) -> PolarsResult<Series> {
                polars_ensure!(self.dtype() == other.dtype(), SchemaMismatch: "expected '{}' found '{}'", self.dtype(), other.dtype());
                let other_phys = other.to_physical_repr().into_owned();
                // SAFETY (assumed): both inputs share the dtype checked above, so the zipped
                // physical values should remain valid for it.
                let zipped = unsafe {
                    self.try_apply_on_phys(|phys| {
                        phys.zip_with(mask, other_phys.as_ref().as_ref())
                    })
                }?;
                Ok(zipped.into_series())
            }

            // Lexical ordering compares through the categorical array itself, otherwise the
            // physical values are compared.
            fn into_total_ord_inner<'a>(&'a self) -> Box<dyn TotalOrdInner + 'a> {
                let lexical = self.0.uses_lexical_ordering();
                if lexical {
                    (&self.0).into_total_ord_inner()
                } else {
                    self.0.physical().into_total_ord_inner()
                }
            }

            // Not supported for this wrapper: always panics via `invalid_operation_panic!`.
            fn into_total_eq_inner<'a>(&'a self) -> Box<dyn TotalEqInner + 'a> {
                invalid_operation_panic!(into_total_eq_inner, self)
            }

            fn vec_hash(
                &self,
                random_state: PlSeedableRandomStateQuality,
                buf: &mut Vec<u64>,
            ) -> PolarsResult<()> {
                self.0.vec_hash(random_state, buf)
            }

            fn vec_hash_combine(
                &self,
                build_hasher: PlSeedableRandomStateQuality,
                hashes: &mut [u64],
            ) -> PolarsResult<()> {
                self.0.vec_hash_combine(build_hasher, hashes)
            }

            // Group-wise minimum: lexical ordering aggregates on the categorical array,
            // otherwise the physical values are aggregated and re-wrapped.
            #[cfg(feature = "algorithm_group_by")]
            unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
                if self.0.uses_lexical_ordering() {
                    return unsafe { self.0.agg_min(groups) };
                }
                self.apply_on_phys(|phys| phys.agg_min(groups).$ca_fn().unwrap().clone())
                    .into_series()
            }

            // Group-wise maximum; mirrors `agg_min`.
            #[cfg(feature = "algorithm_group_by")]
            unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
                if self.0.uses_lexical_ordering() {
                    return unsafe { self.0.agg_max(groups) };
                }
                self.apply_on_phys(|phys| phys.agg_max(groups).$ca_fn().unwrap().clone())
                    .into_series()
            }

            // Group-wise arg-min: the result is returned as-is, no categorical re-wrapping.
            #[cfg(feature = "algorithm_group_by")]
            unsafe fn agg_arg_min(&self, groups: &GroupsType) -> Series {
                if self.0.uses_lexical_ordering() {
                    return unsafe { self.0.agg_arg_min(groups) };
                }
                self.0.physical().agg_arg_min(groups)
            }

            // Group-wise arg-max; mirrors `agg_arg_min`.
            #[cfg(feature = "algorithm_group_by")]
            unsafe fn agg_arg_max(&self, groups: &GroupsType) -> Series {
                if self.0.uses_lexical_ordering() {
                    return unsafe { self.0.agg_arg_max(groups) };
                }
                self.0.physical().agg_arg_max(groups)
            }

            #[cfg(feature = "algorithm_group_by")]
            unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
                // Casting and dispatching is not an option here: the inner type of the
                // list would end up incorrect. Aggregate the physical values and then
                // switch the list to the logical dtype in place.
                let phys_lists = self.0.physical().agg_list(groups);
                let mut logical = phys_lists.list().unwrap().clone();
                unsafe { logical.to_logical(self.dtype().clone()) };
                logical.into_series()
            }

            // Grouping is done on the physical values.
            #[cfg(feature = "algorithm_group_by")]
            fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
                self.0.physical().group_tuples(multithreaded, sorted)
            }

            fn arg_sort_multiple(
                &self,
                by: &[Column],
                options: &SortMultipleOptions,
            ) -> PolarsResult<IdxCa> {
                self.0.arg_sort_multiple(by, options)
            }
        }

        impl SeriesTrait for SeriesWrap<$ca> {
            // The name is stored on the physical array.
            fn rename(&mut self, name: PlSmallStr) {
                self.0.physical_mut().rename(name);
            }

            fn chunk_lengths(&self) -> ChunkLenIter<'_> {
                self.0.physical().chunk_lengths()
            }

            fn name(&self) -> &PlSmallStr {
                self.0.physical().name()
            }

            fn chunks(&self) -> &Vec<ArrayRef> {
                self.0.physical().chunks()
            }

            unsafe fn chunks_mut(&mut self) -> &mut Vec<ArrayRef> {
                self.0.physical_mut().chunks_mut()
            }

            fn shrink_to_fit(&mut self) {
                self.0.physical_mut().shrink_to_fit();
            }

            fn slice(&self, offset: i64, length: usize) -> Series {
                // SAFETY (assumed): slicing only keeps values already present in `self`.
                let sliced = unsafe { self.apply_on_phys(|phys| phys.slice(offset, length)) };
                sliced.into_series()
            }

            // Splits the physical array and re-wraps both halves with the same dtype.
            fn split_at(&self, offset: i64) -> (Series, Series) {
                // SAFETY (assumed): both halves only hold values already present in `self`.
                unsafe {
                    let (head, tail) = self.0.physical().split_at(offset);
                    let head =
                        <$ca>::from_cats_and_dtype_unchecked(head, self.0.dtype().clone());
                    let head = head.into_series();
                    let tail =
                        <$ca>::from_cats_and_dtype_unchecked(tail, self.0.dtype().clone());
                    let tail = tail.into_series();
                    (head, tail)
                }
            }

            fn append(&mut self, other: &Series) -> PolarsResult<()> {
                polars_ensure!(self.0.dtype() == other.dtype(), append);
                let rhs = other.cat::<$pdt>().unwrap();
                self.0.append(rhs)
            }

            // Moves the physical values out of `other` (leaving its default in place) and
            // appends them to this array's physical values.
            fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> {
                polars_ensure!(self.0.dtype() == other.dtype(), append);
                let dst = self.0.physical_mut();
                let rhs = other
                    ._get_inner_mut()
                    .as_any_mut()
                    .downcast_mut::<$ca>()
                    .unwrap();
                dst.append_owned(std::mem::take(rhs.physical_mut()))
            }

            fn extend(&mut self, other: &Series) -> PolarsResult<()> {
                polars_ensure!(self.0.dtype() == other.dtype(), extend);
                let rhs = other.cat::<$pdt>().unwrap();
                self.0.extend(rhs)
            }

            fn filter(&self, filter: &BooleanChunked) -> PolarsResult<Series> {
                // SAFETY (assumed): filtering only keeps values already present in `self`.
                let kept = unsafe { self.try_apply_on_phys(|phys| phys.filter(filter)) }?;
                Ok(kept.into_series())
            }

            fn take(&self, indices: &IdxCa) -> PolarsResult<Series> {
                // SAFETY (assumed): gathering only selects values already present in `self`.
                let gathered = unsafe { self.try_apply_on_phys(|phys| phys.take(indices)) }?;
                Ok(gathered.into_series())
            }

            unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series {
                // SAFETY: the unchecked gather's requirements are forwarded to the caller
                // (presumably in-bounds indices — confirm against the trait contract).
                let gathered =
                    unsafe { self.apply_on_phys(|phys| phys.take_unchecked(indices)) };
                gathered.into_series()
            }

            fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult<Series> {
                // SAFETY (assumed): gathering only selects values already present in `self`.
                let gathered = unsafe { self.try_apply_on_phys(|phys| phys.take(indices)) }?;
                Ok(gathered.into_series())
            }

            unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series {
                // SAFETY: the unchecked gather's requirements are forwarded to the caller
                // (presumably in-bounds indices — confirm against the trait contract).
                let gathered =
                    unsafe { self.apply_on_phys(|phys| phys.take_unchecked(indices)) };
                gathered.into_series()
            }

            fn deposit(&self, validity: &Bitmap) -> Series {
                // SAFETY (assumed): `deposit` introduces no new non-null values — confirm.
                let deposited = unsafe { self.apply_on_phys(|phys| phys.deposit(validity)) };
                deposited.into_series()
            }

            fn len(&self) -> usize {
                self.0.len()
            }

            fn rechunk(&self) -> Series {
                // SAFETY (assumed): rechunking changes the chunk layout only.
                let merged = unsafe { self.apply_on_phys(|phys| phys.rechunk().into_owned()) };
                merged.into_series()
            }

            fn with_validity(&self, validity: Option<Bitmap>) -> Series {
                // SAFETY (assumed): only the validity mask is replaced; confirm that slots
                // which become valid hold values that are valid for the dtype.
                let masked = unsafe {
                    self.apply_on_phys(move |phys| phys.clone().with_validity(validity))
                };
                masked.into_series()
            }

            fn new_from_index(&self, index: usize, length: usize) -> Series {
                // SAFETY (assumed): the result repeats a value already present in `self`.
                let repeated =
                    unsafe { self.apply_on_phys(|phys| phys.new_from_index(index, length)) };
                repeated.into_series()
            }

            fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
                self.0.cast_with_options(dtype, options)
            }

            #[inline]
            unsafe fn get_unchecked(&self, index: usize) -> AnyValue<'_> {
                self.0.get_any_value_unchecked(index)
            }

            fn sort_with(&self, options: SortOptions) -> PolarsResult<Series> {
                let sorted = self.0.sort_with(options);
                Ok(sorted.into_series())
            }

            fn arg_sort(&self, options: SortOptions) -> IdxCa {
                self.0.arg_sort(options)
            }

            fn null_count(&self) -> usize {
                self.0.physical().null_count()
            }

            fn has_nulls(&self) -> bool {
                self.0.physical().has_nulls()
            }

            #[cfg(feature = "algorithm_group_by")]
            fn unique(&self) -> PolarsResult<Series> {
                // SAFETY (assumed): the unique values are a subset of the existing ones.
                let distinct = unsafe { self.try_apply_on_phys(|phys| phys.unique()) }?;
                Ok(distinct.into_series())
            }

            #[cfg(feature = "algorithm_group_by")]
            fn n_unique(&self) -> PolarsResult<usize> {
                self.0.physical().n_unique()
            }

            #[cfg(feature = "approx_unique")]
            fn approx_n_unique(&self) -> PolarsResult<IdxSize> {
                let estimate = self.0.physical().approx_n_unique();
                Ok(estimate)
            }

            #[cfg(feature = "algorithm_group_by")]
            fn arg_unique(&self) -> PolarsResult<IdxCa> {
                self.0.physical().arg_unique()
            }

            fn unique_id(&self) -> PolarsResult<(IdxSize, Vec<IdxSize>)> {
                ChunkUnique::unique_id(self.0.physical())
            }

            fn is_null(&self) -> BooleanChunked {
                self.0.physical().is_null()
            }

            fn is_not_null(&self) -> BooleanChunked {
                self.0.physical().is_not_null()
            }

            fn reverse(&self) -> Series {
                // SAFETY (assumed): reversing only reorders values already present in `self`.
                let reversed = unsafe { self.apply_on_phys(|phys| phys.reverse()) };
                reversed.into_series()
            }

            fn as_single_ptr(&mut self) -> PolarsResult<usize> {
                self.0.physical_mut().as_single_ptr()
            }

            fn shift(&self, periods: i64) -> Series {
                // SAFETY (assumed): shifting only moves existing values — confirm how the
                // vacated slots are filled.
                let shifted = unsafe { self.apply_on_phys(|phys| phys.shift(periods)) };
                shifted.into_series()
            }

            fn clone_inner(&self) -> Arc<dyn SeriesTrait> {
                let wrapped = SeriesWrap(Clone::clone(&self.0));
                Arc::new(wrapped)
            }

            fn min_reduce(&self) -> PolarsResult<Scalar> {
                let min = ChunkAggSeries::min_reduce(&self.0);
                Ok(min)
            }

            fn max_reduce(&self) -> PolarsResult<Scalar> {
                let max = ChunkAggSeries::max_reduce(&self.0);
                Ok(max)
            }

            fn find_validity_mismatch(&self, other: &Series, idxs: &mut Vec<IdxSize>) {
                self.0.physical().find_validity_mismatch(other, idxs)
            }

            // Exposes the categorical array itself for downcasting.
            fn as_any(&self) -> &dyn Any {
                &self.0
            }

            fn as_any_mut(&mut self) -> &mut dyn Any {
                &mut self.0
            }

            // Exposes the physical array for downcasting.
            fn as_phys_any(&self) -> &dyn Any {
                self.0.physical()
            }

            fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
                self as _
            }
        }

        impl private::PrivateSeriesNumeric for SeriesWrap<$ca> {
            // Bit representation of the physical values; always `Some`.
            fn bit_repr(&self) -> Option<BitRepr> {
                let repr = self.0.physical().to_bit_repr();
                Some(repr)
            }
        }
    };
}
374
375impl_cat_series!(Categorical8Chunked, Categorical8Type, u8);
376impl_cat_series!(Categorical16Chunked, Categorical16Type, u16);
377impl_cat_series!(Categorical32Chunked, Categorical32Type, u32);