Skip to main content

polars_core/chunked_array/ops/
bit_repr.rs

1use polars_buffer::Buffer;
2use polars_error::feature_gated;
3
4use crate::prelude::*;
5use crate::series::BitRepr;
6
7/// Reinterprets the type of a [`ChunkedArray`]. T and U must have the same size
8/// and alignment.
9fn reinterpret_chunked_array<T: PolarsNumericType, U: PolarsNumericType>(
10    ca: &ChunkedArray<T>,
11) -> ChunkedArray<U> {
12    assert!(size_of::<T::Native>() == size_of::<U::Native>());
13    assert!(align_of::<T::Native>() == align_of::<U::Native>());
14
15    let chunks = ca.downcast_iter().map(|array| {
16        let buf = array.values().clone();
17        let reinterpreted_buf = Buffer::try_transmute::<U::Native>(buf).unwrap();
18        PrimitiveArray::from_data_default(reinterpreted_buf, array.validity().cloned())
19    });
20
21    ChunkedArray::from_chunk_iter(ca.name().clone(), chunks)
22}
23
24impl<T> ToBitRepr for ChunkedArray<T>
25where
26    T: PolarsNumericType,
27{
28    fn to_bit_repr(&self) -> BitRepr {
29        match size_of::<T::Native>() {
30            16 => {
31                feature_gated!("dtype-u128", {
32                    if matches!(self.dtype(), DataType::UInt128) {
33                        let ca: &UInt128Chunked = self.as_any().downcast_ref().unwrap();
34                        return BitRepr::U128(ca.clone());
35                    }
36
37                    BitRepr::U128(reinterpret_chunked_array(self))
38                })
39            },
40
41            8 => {
42                if matches!(self.dtype(), DataType::UInt64) {
43                    let ca: &UInt64Chunked = self.as_any().downcast_ref().unwrap();
44                    return BitRepr::U64(ca.clone());
45                }
46
47                BitRepr::U64(reinterpret_chunked_array(self))
48            },
49
50            4 => {
51                if matches!(self.dtype(), DataType::UInt32) {
52                    let ca: &UInt32Chunked = self.as_any().downcast_ref().unwrap();
53                    return BitRepr::U32(ca.clone());
54                }
55
56                BitRepr::U32(reinterpret_chunked_array(self))
57            },
58
59            2 => {
60                if matches!(self.dtype(), DataType::UInt16) {
61                    let ca: &UInt16Chunked = self.as_any().downcast_ref().unwrap();
62                    return BitRepr::U16(ca.clone());
63                }
64
65                BitRepr::U16(reinterpret_chunked_array(self))
66            },
67
68            1 => {
69                if matches!(self.dtype(), DataType::UInt8) {
70                    let ca: &UInt8Chunked = self.as_any().downcast_ref().unwrap();
71                    return BitRepr::U8(ca.clone());
72                }
73
74                BitRepr::U8(reinterpret_chunked_array(self))
75            },
76
77            _ => unreachable!(),
78        }
79    }
80}
81
82pub fn reinterpret(s: &Series, dtype: &DataType) -> PolarsResult<Series> {
83    if s.dtype().is_numeric() && s.dtype() == dtype {
84        return Ok(s.clone());
85    }
86
87    Ok(match (s.dtype(), dtype) {
88        #[cfg(all(feature = "dtype-u8", feature = "dtype-i8"))]
89        (DataType::UInt8, DataType::Int8) => {
90            reinterpret_chunked_array::<_, Int8Type>(s.u8().unwrap()).into_series()
91        },
92        #[cfg(all(feature = "dtype-u16", feature = "dtype-i16"))]
93        (DataType::UInt16, DataType::Int16) => {
94            reinterpret_chunked_array::<_, Int16Type>(s.u16().unwrap()).into_series()
95        },
96        (DataType::UInt32, DataType::Int32) => {
97            reinterpret_chunked_array::<_, Int32Type>(s.u32().unwrap()).into_series()
98        },
99        (DataType::UInt64, DataType::Int64) => {
100            reinterpret_chunked_array::<_, Int64Type>(s.u64().unwrap()).into_series()
101        },
102        #[cfg(all(feature = "dtype-u128", feature = "dtype-i128"))]
103        (DataType::UInt128, DataType::Int128) => {
104            reinterpret_chunked_array::<_, Int128Type>(s.u128().unwrap()).into_series()
105        },
106
107        #[cfg(all(feature = "dtype-u16", feature = "dtype-f16"))]
108        (DataType::UInt16, DataType::Float16) => {
109            reinterpret_chunked_array::<_, Float16Type>(s.u16().unwrap()).into_series()
110        },
111        (DataType::UInt32, DataType::Float32) => {
112            reinterpret_chunked_array::<_, Float32Type>(s.u32().unwrap()).into_series()
113        },
114        (DataType::UInt64, DataType::Float64) => {
115            reinterpret_chunked_array::<_, Float64Type>(s.u64().unwrap()).into_series()
116        },
117
118        #[cfg(all(feature = "dtype-i8", feature = "dtype-u8"))]
119        (DataType::Int8, DataType::UInt8) => {
120            reinterpret_chunked_array::<_, UInt8Type>(s.i8().unwrap()).into_series()
121        },
122        #[cfg(all(feature = "dtype-i16", feature = "dtype-u16"))]
123        (DataType::Int16, DataType::UInt16) => {
124            reinterpret_chunked_array::<_, UInt16Type>(s.i16().unwrap()).into_series()
125        },
126        (DataType::Int32, DataType::UInt32) => {
127            reinterpret_chunked_array::<_, UInt32Type>(s.i32().unwrap()).into_series()
128        },
129        (DataType::Int64, DataType::UInt64) => {
130            reinterpret_chunked_array::<_, UInt64Type>(s.i64().unwrap()).into_series()
131        },
132        #[cfg(all(feature = "dtype-i128", feature = "dtype-u128"))]
133        (DataType::Int128, DataType::UInt128) => {
134            reinterpret_chunked_array::<_, UInt128Type>(s.i128().unwrap()).into_series()
135        },
136
137        #[cfg(all(feature = "dtype-i16", feature = "dtype-f16"))]
138        (DataType::Int16, DataType::Float16) => {
139            reinterpret_chunked_array::<_, Float16Type>(s.i16().unwrap()).into_series()
140        },
141        (DataType::Int32, DataType::Float32) => {
142            reinterpret_chunked_array::<_, Float32Type>(s.i32().unwrap()).into_series()
143        },
144        (DataType::Int64, DataType::Float64) => {
145            reinterpret_chunked_array::<_, Float64Type>(s.i64().unwrap()).into_series()
146        },
147
148        #[cfg(all(feature = "dtype-f16", feature = "dtype-u16"))]
149        (DataType::Float16, DataType::UInt16) => {
150            reinterpret_chunked_array::<_, UInt16Type>(s.f16().unwrap()).into_series()
151        },
152        (DataType::Float32, DataType::UInt32) => {
153            reinterpret_chunked_array::<_, UInt32Type>(s.f32().unwrap()).into_series()
154        },
155        (DataType::Float64, DataType::UInt64) => {
156            reinterpret_chunked_array::<_, UInt64Type>(s.f64().unwrap()).into_series()
157        },
158
159        #[cfg(all(feature = "dtype-f16", feature = "dtype-i16"))]
160        (DataType::Float16, DataType::Int16) => {
161            reinterpret_chunked_array::<_, Int16Type>(s.f16().unwrap()).into_series()
162        },
163        (DataType::Float32, DataType::Int32) => {
164            reinterpret_chunked_array::<_, Int32Type>(s.f32().unwrap()).into_series()
165        },
166        (DataType::Float64, DataType::Int64) => {
167            reinterpret_chunked_array::<_, Int64Type>(s.f64().unwrap()).into_series()
168        },
169
170        (source_dtype, target_dtype) => polars_bail!(
171            ComputeError:
172            "cannot reinterpret from {source_dtype:?} to {target_dtype:?}"
173        ),
174    })
175}
176
#[cfg(feature = "dtype-f16")]
impl UInt16Chunked {
    /// Reinterprets the values of this array as a [`Float16Chunked`] via
    /// `reinterpret_chunked_array`.
    #[doc(hidden)]
    pub fn _reinterpret_float(&self) -> Float16Chunked {
        reinterpret_chunked_array::<_, Float16Type>(self)
    }
}
184
185impl UInt32Chunked {
186    #[doc(hidden)]
187    pub fn _reinterpret_float(&self) -> Float32Chunked {
188        reinterpret_chunked_array(self)
189    }
190}
191
192impl UInt64Chunked {
193    #[doc(hidden)]
194    pub fn _reinterpret_float(&self) -> Float64Chunked {
195        reinterpret_chunked_array(self)
196    }
197}
198
/// Used to save compilation paths. Use carefully. Although this is safe,
/// if misused it can lead to incorrect results.
#[cfg(feature = "dtype-f16")]
impl Float16Chunked {
    /// Runs `f` on the `u16` bit representation of this array and
    /// reinterprets the `u16` result back as floats.
    ///
    /// # Panics
    ///
    /// Panics if calling `u16()` on the series returned by `f` fails
    /// (presumably when `f` returns a non-`UInt16` series).
    pub fn apply_as_ints<F>(&self, f: F) -> Series
    where
        F: Fn(&Series) -> Series,
    {
        let bits = match self.to_bit_repr() {
            BitRepr::U16(ca) => ca.into_series(),
            _ => unreachable!(),
        };
        let transformed = f(&bits);
        transformed.u16().unwrap()._reinterpret_float().into()
    }
}
216
217impl Float32Chunked {
218    pub fn apply_as_ints<F>(&self, f: F) -> Series
219    where
220        F: Fn(&Series) -> Series,
221    {
222        let BitRepr::U32(s) = self.to_bit_repr() else {
223            unreachable!()
224        };
225        let s = s.into_series();
226        let out = f(&s);
227        let out = out.u32().unwrap();
228        out._reinterpret_float().into()
229    }
230}
231impl Float64Chunked {
232    pub fn apply_as_ints<F>(&self, f: F) -> Series
233    where
234        F: Fn(&Series) -> Series,
235    {
236        let BitRepr::U64(s) = self.to_bit_repr() else {
237            unreachable!()
238        };
239        let s = s.into_series();
240        let out = f(&s);
241        let out = out.u64().unwrap();
242        out._reinterpret_float().into()
243    }
244}