Skip to main content

polars_core/series/implementations/categorical.rs

1use super::*;
2use crate::prelude::*;
3
4unsafe impl<T: PolarsCategoricalType> IntoSeries for CategoricalChunked<T> {
5    fn into_series(self) -> Series {
6        // We do this hack to go from generic T to concrete T to avoid adding bounds on IntoSeries.
7        with_match_categorical_physical_type!(T::physical(), |$C| {
8            unsafe {
9                Series(Arc::new(SeriesWrap(core::mem::transmute::<Self, CategoricalChunked<$C>>(self))))
10            }
11        })
12    }
13}
14
15impl<T: PolarsCategoricalType> SeriesWrap<CategoricalChunked<T>> {
16    unsafe fn apply_on_phys<F>(&self, apply: F) -> CategoricalChunked<T>
17    where
18        F: FnOnce(&ChunkedArray<T::PolarsPhysical>) -> ChunkedArray<T::PolarsPhysical>,
19    {
20        let cats = apply(self.0.physical());
21        unsafe { CategoricalChunked::from_cats_and_dtype_unchecked(cats, self.0.dtype().clone()) }
22    }
23
24    unsafe fn try_apply_on_phys<F>(&self, apply: F) -> PolarsResult<CategoricalChunked<T>>
25    where
26        F: FnOnce(
27            &ChunkedArray<T::PolarsPhysical>,
28        ) -> PolarsResult<ChunkedArray<T::PolarsPhysical>>,
29    {
30        let cats = apply(self.0.physical())?;
31        unsafe {
32            Ok(CategoricalChunked::from_cats_and_dtype_unchecked(
33                cats,
34                self.0.dtype().clone(),
35            ))
36        }
37    }
38}
39
// Implements the `Series` trait family (`PrivateSeries`, `SeriesTrait`, `PrivateSeriesNumeric`)
// for one concrete categorical chunked array.
//
// * `$ca`    - the concrete chunked-array type, e.g. `Categorical8Chunked`.
// * `$pdt`   - the matching categorical type, used for `Series::cat::<$pdt>()` downcasts.
// * `$ca_fn` - the `Series` accessor (e.g. `u8`) used to downcast an aggregated physical
//              `Series` back to its physical `ChunkedArray`.
//
// Most operations delegate to the physical codes. When the result is again a column of codes,
// it is rewrapped with the original dtype through `apply_on_phys` / `try_apply_on_phys`.
macro_rules! impl_cat_series {
    ($ca: ident, $pdt:ty, $ca_fn:ident) => {
        impl private::PrivateSeries for SeriesWrap<$ca> {
            fn compute_len(&mut self) {
                self.0.physical_mut().compute_len()
            }
            fn _field(&self) -> Cow<'_, Field> {
                Cow::Owned(self.0.field())
            }
            fn _dtype(&self) -> &DataType {
                self.0.dtype()
            }
            fn _get_flags(&self) -> StatisticsFlags {
                self.0.get_flags()
            }
            fn _set_flags(&mut self, flags: StatisticsFlags) {
                self.0.set_flags(flags)
            }

            #[cfg(feature = "zip_with")]
            fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult<Series> {
                polars_ensure!(self.dtype() == other.dtype(), SchemaMismatch: "expected '{}' found '{}'", self.dtype(), other.dtype());
                let other = other.to_physical_repr().into_owned();
                // SAFETY: `other` has the same dtype as `self` (checked above), so codes picked
                // from either side are interpreted with the same categorical dtype.
                unsafe {
                    Ok(self.try_apply_on_phys(|ca| {
                        ca.zip_with(mask, other.as_ref().as_ref())
                    })?.into_series())
                }
            }

            fn into_total_ord_inner<'a>(&'a self) -> Box<dyn TotalOrdInner + 'a> {
                // Lexical ordering compares the category strings; otherwise the codes' own
                // order is the ordering.
                if self.0.uses_lexical_ordering() {
                    (&self.0).into_total_ord_inner()
                } else {
                    self.0.physical().into_total_ord_inner()
                }
            }
            fn into_total_eq_inner<'a>(&'a self) -> Box<dyn TotalEqInner + 'a> {
                // Equality does not depend on the ordering mode. Within a single array, two rows
                // hold the same category exactly when they hold the same code (assuming the
                // category mapping is one-to-one), so compare the physical codes directly.
                self.0.physical().into_total_eq_inner()
            }

            fn vec_hash(
                &self,
                random_state: PlSeedableRandomStateQuality,
                buf: &mut Vec<u64>,
            ) -> PolarsResult<()> {
                self.0.vec_hash(random_state, buf)
            }

            fn vec_hash_combine(
                &self,
                build_hasher: PlSeedableRandomStateQuality,
                hashes: &mut [u64],
            ) -> PolarsResult<()> {
                self.0.vec_hash_combine(build_hasher, hashes)
            }

            #[cfg(feature = "algorithm_group_by")]
            unsafe fn agg_min(&self, groups: &GroupsType) -> Series {
                if self.0.uses_lexical_ordering() {
                    unsafe { self.0.agg_min(groups) }
                } else {
                    // With physical ordering, the smallest code per group is the answer. The
                    // aggregated physical `Series` is downcast back to codes and rewrapped.
                    self.apply_on_phys(|phys| phys.agg_min(groups).$ca_fn().unwrap().clone())
                        .into_series()
                }
            }

            #[cfg(feature = "algorithm_group_by")]
            unsafe fn agg_max(&self, groups: &GroupsType) -> Series {
                if self.0.uses_lexical_ordering() {
                    unsafe { self.0.agg_max(groups) }
                } else {
                    // With physical ordering, the largest code per group is the answer.
                    self.apply_on_phys(|phys| phys.agg_max(groups).$ca_fn().unwrap().clone())
                        .into_series()
                }
            }

            #[cfg(feature = "algorithm_group_by")]
            unsafe fn agg_arg_min(&self, groups: &GroupsType) -> Series {
                if self.0.uses_lexical_ordering() {
                    unsafe { self.0.agg_arg_min(groups) }
                } else {
                    self.0.physical().agg_arg_min(groups)
                }
            }

            #[cfg(feature = "algorithm_group_by")]
            unsafe fn agg_arg_max(&self, groups: &GroupsType) -> Series {
                if self.0.uses_lexical_ordering() {
                    unsafe { self.0.agg_arg_max(groups) }
                } else {
                    self.0.physical().agg_arg_max(groups)
                }
            }

            #[cfg(feature = "algorithm_group_by")]
            unsafe fn agg_list(&self, groups: &GroupsType) -> Series {
                // We cannot cast and dispatch, as the inner type of the list would be incorrect.
                // Aggregate the codes, then relabel the list's inner dtype as this categorical.
                let list = self.0.physical().agg_list(groups);
                let mut list = list.list().unwrap().clone();
                unsafe { list.to_logical(self.dtype().clone()) };
                list.into_series()
            }

            #[cfg(feature = "algorithm_group_by")]
            fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsType> {
                self.0.physical().group_tuples(multithreaded, sorted)
            }

            fn arg_sort_multiple(
                &self,
                by: &[Column],
                options: &SortMultipleOptions,
            ) -> PolarsResult<IdxCa> {
                self.0.arg_sort_multiple(by, options)
            }
        }

        impl SeriesTrait for SeriesWrap<$ca> {
            fn rename(&mut self, name: PlSmallStr) {
                self.0.physical_mut().rename(name);
            }

            fn chunk_lengths(&self) -> ChunkLenIter<'_> {
                self.0.physical().chunk_lengths()
            }

            fn name(&self) -> &PlSmallStr {
                self.0.physical().name()
            }

            fn chunks(&self) -> &Vec<ArrayRef> {
                self.0.physical().chunks()
            }

            unsafe fn chunks_mut(&mut self) -> &mut Vec<ArrayRef> {
                self.0.physical_mut().chunks_mut()
            }

            fn shrink_to_fit(&mut self) {
                self.0.physical_mut().shrink_to_fit()
            }

            // SAFETY (for the row-selection methods below: slice, split_at, filter, take*,
            // rechunk, with_validity, new_from_index, unique, reverse, shift): they only
            // select, reorder, repeat or mask codes already present in `self`, so the result
            // stays valid for `self`'s dtype.
            fn slice(&self, offset: i64, length: usize) -> Series {
                unsafe { self.apply_on_phys(|cats| cats.slice(offset, length)).into_series() }
            }

            fn split_at(&self, offset: i64) -> (Series, Series) {
                unsafe {
                    let (a, b) = self.0.physical().split_at(offset);
                    let a = <$ca>::from_cats_and_dtype_unchecked(a, self.0.dtype().clone()).into_series();
                    let b = <$ca>::from_cats_and_dtype_unchecked(b, self.0.dtype().clone()).into_series();
                    (a, b)
                }
            }

            fn append(&mut self, other: &Series) -> PolarsResult<()> {
                polars_ensure!(self.0.dtype() == other.dtype(), append);
                self.0.append(other.cat::<$pdt>().unwrap())
            }

            fn append_owned(&mut self, mut other: Series) -> PolarsResult<()> {
                polars_ensure!(self.0.dtype() == other.dtype(), append);
                // Move the physical codes out of `other` instead of copying them.
                self.0.physical_mut().append_owned(std::mem::take(
                    other
                        ._get_inner_mut()
                        .as_any_mut()
                        .downcast_mut::<$ca>()
                        .unwrap()
                        .physical_mut(),
                ))
            }

            fn extend(&mut self, other: &Series) -> PolarsResult<()> {
                polars_ensure!(self.0.dtype() == other.dtype(), extend);
                self.0.extend(other.cat::<$pdt>().unwrap())
            }

            fn filter(&self, filter: &BooleanChunked) -> PolarsResult<Series> {
                unsafe { Ok(self.try_apply_on_phys(|cats| cats.filter(filter))?.into_series()) }
            }

            fn take(&self, indices: &IdxCa) -> PolarsResult<Series> {
                unsafe { Ok(self.try_apply_on_phys(|cats| cats.take(indices))?.into_series()) }
            }

            unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series {
                unsafe { self.apply_on_phys(|cats| cats.take_unchecked(indices)).into_series() }
            }

            fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult<Series> {
                unsafe { Ok(self.try_apply_on_phys(|cats| cats.take(indices))?.into_series()) }
            }

            unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series {
                unsafe { self.apply_on_phys(|cats| cats.take_unchecked(indices)).into_series() }
            }

            fn deposit(&self, validity: &Bitmap) -> Series {
                unsafe { self.apply_on_phys(|cats| cats.deposit(validity)) }
                    .into_series()
            }

            fn len(&self) -> usize {
                self.0.len()
            }

            fn rechunk(&self) -> Series {
                unsafe { self.apply_on_phys(|cats| cats.rechunk().into_owned()).into_series() }
            }

            fn with_validity(&self, validity: Option<Bitmap>) -> Series {
                unsafe { self.apply_on_phys(move |cats| cats.clone().with_validity(validity)).into_series() }
            }

            fn new_from_index(&self, index: usize, length: usize) -> Series {
                unsafe { self.apply_on_phys(|cats| cats.new_from_index(index, length)).into_series() }
            }

            fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult<Series> {
                self.0.cast_with_options(dtype, options)
            }

            #[inline]
            unsafe fn get_unchecked(&self, index: usize) -> AnyValue<'_> {
                self.0.get_any_value_unchecked(index)
            }

            fn sort_with(&self, options: SortOptions) -> PolarsResult<Series> {
                Ok(self.0.sort_with(options).into_series())
            }

            fn arg_sort(&self, options: SortOptions) -> IdxCa {
                self.0.arg_sort(options)
            }

            fn null_count(&self) -> usize {
                self.0.physical().null_count()
            }

            fn has_nulls(&self) -> bool {
                self.0.physical().has_nulls()
            }

            #[cfg(feature = "algorithm_group_by")]
            fn unique(&self) -> PolarsResult<Series> {
                unsafe { Ok(self.try_apply_on_phys(|cats| cats.unique())?.into_series()) }
            }

            #[cfg(feature = "algorithm_group_by")]
            fn n_unique(&self) -> PolarsResult<usize> {
                self.0.physical().n_unique()
            }

            #[cfg(feature = "approx_unique")]
            fn approx_n_unique(&self) -> PolarsResult<IdxSize> {
                Ok(self.0.physical().approx_n_unique())
            }

            #[cfg(feature = "algorithm_group_by")]
            fn arg_unique(&self) -> PolarsResult<IdxCa> {
                self.0.physical().arg_unique()
            }

            fn unique_id(&self) -> PolarsResult<(IdxSize, Vec<IdxSize>)> {
                ChunkUnique::unique_id(self.0.physical())
            }

            fn is_null(&self) -> BooleanChunked {
                self.0.physical().is_null()
            }

            fn is_not_null(&self) -> BooleanChunked {
                self.0.physical().is_not_null()
            }

            fn reverse(&self) -> Series {
                unsafe { self.apply_on_phys(|cats| cats.reverse()).into_series() }
            }

            fn as_single_ptr(&mut self) -> PolarsResult<usize> {
                self.0.physical_mut().as_single_ptr()
            }

            fn shift(&self, periods: i64) -> Series {
                unsafe { self.apply_on_phys(|ca| ca.shift(periods)).into_series() }
            }

            fn clone_inner(&self) -> Arc<dyn SeriesTrait> {
                Arc::new(SeriesWrap(Clone::clone(&self.0)))
            }

            fn min_reduce(&self) -> PolarsResult<Scalar> {
                Ok(ChunkAggSeries::min_reduce(&self.0))
            }

            fn max_reduce(&self) -> PolarsResult<Scalar> {
                Ok(ChunkAggSeries::max_reduce(&self.0))
            }

            fn find_validity_mismatch(&self, other: &Series, idxs: &mut Vec<IdxSize>) {
                self.0.physical().find_validity_mismatch(other, idxs)
            }

            // `as_any` / `as_any_mut` expose the logical categorical array (what `downcast_*`
            // to `$ca` expects); `as_phys_any` exposes the underlying physical codes.
            fn as_any(&self) -> &dyn Any {
                &self.0
            }

            fn as_any_mut(&mut self) -> &mut dyn Any {
                &mut self.0
            }

            fn as_phys_any(&self) -> &dyn Any {
                self.0.physical()
            }

            fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
                self as _
            }
        }

        impl private::PrivateSeriesNumeric for SeriesWrap<$ca> {
            // The bit representation of a categorical is that of its physical codes.
            fn bit_repr(&self) -> Option<BitRepr> {
                Some(self.0.physical().to_bit_repr())
            }
        }
    }
}
369
370impl_cat_series!(Categorical8Chunked, Categorical8Type, u8);
371impl_cat_series!(Categorical16Chunked, Categorical16Type, u16);
372impl_cat_series!(Categorical32Chunked, Categorical32Type, u32);