polars_core/chunked_array/object/extension/
mod.rs

1pub(crate) mod drop;
2pub(super) mod list;
3pub(crate) mod polars_extension;
4
5use std::mem;
6
7use arrow::array::FixedSizeBinaryArray;
8use arrow::bitmap::BitmapBuilder;
9use arrow::datatypes::ExtensionType;
10use polars_buffer::Buffer;
11use polars_extension::PolarsExtension;
12use polars_utils::format_pl_smallstr;
13use polars_utils::relaxed_cell::RelaxedCell;
14
15use crate::PROCESS_ID;
16use crate::prelude::*;
17
18static POLARS_ALLOW_EXTENSION: RelaxedCell<bool> = RelaxedCell::new_bool(false);
19
20/// Control whether extension types may be created.
21///
22/// If the environment variable POLARS_ALLOW_EXTENSION is set, this function has no effect.
23pub fn set_polars_allow_extension(toggle: bool) {
24    POLARS_ALLOW_EXTENSION.store(toggle)
25}
26
/// Build a closure that drops `n_t_vals` values of `T` stored back to back, starting at `ptr`.
///
/// Each value is moved out of the buffer with an unaligned read and dropped right away,
/// after which the cursor advances by `size_of::<T>()` bytes.
///
/// # Safety
/// - `ptr` must point to the start of a `T` allocation.
/// - `n_t_vals` must represent the correct number of `T` values in that allocation.
unsafe fn create_drop<T: Sized>(ptr: *const u8, n_t_vals: usize) -> Box<dyn FnMut()> {
    // The cursor is kept as a byte pointer; it is only viewed as `*const T` at the read site.
    let mut cursor = ptr;
    Box::new(move || {
        let stride = size_of::<T>();
        for _ in 0..n_t_vals {
            // SAFETY: the caller guarantees that `n_t_vals` values of `T` live contiguously
            // from `ptr`, so every read is in bounds and the cursor never moves further than
            // one past the end of that allocation.
            unsafe {
                drop(cursor.cast::<T>().read_unaligned());
                cursor = cursor.add(stride);
            }
        }
    })
}
39
40#[allow(clippy::type_complexity)]
41struct ExtensionSentinel {
42    drop_fn: Option<Box<dyn FnMut()>>,
43    // A function on the heap that take a `array: FixedSizeBinary` and a `name: PlSmallStr`
44    // and returns a `Series` of `ObjectChunked<T>`
45    pub(crate) to_series_fn: Option<Box<dyn Fn(&FixedSizeBinaryArray, &PlSmallStr) -> Series>>,
46}
47
48impl Drop for ExtensionSentinel {
49    fn drop(&mut self) {
50        let mut drop_fn = self.drop_fn.take().unwrap();
51        drop_fn()
52    }
53}
54
// https://stackoverflow.com/questions/28127165/how-to-convert-struct-to-u8d
/// View the in-memory representation of `p` as a slice of `size_of::<T>()` bytes.
///
/// The returned slice borrows from `p`.
///
/// # Safety
/// NOTE(review): it is not clear whether padding bytes inside `T` are initialized; callers
/// should only copy the returned bytes around and not interpret padding.
unsafe fn any_as_u8_slice<T: Sized>(p: &T) -> &[u8] {
    let first_byte = (p as *const T).cast::<u8>();
    let n_bytes = size_of::<T>();
    // SAFETY: `p` is a live reference, so `n_bytes` bytes starting at `first_byte` belong to
    // a single allocation that stays borrowed for the lifetime of the returned slice.
    unsafe { std::slice::from_raw_parts(first_byte, n_bytes) }
}
60
61/// Create an extension Array that can be sent to arrow and (once wrapped in [`PolarsExtension`] will
62/// also call drop on `T`, when the array is dropped.
63pub(crate) fn create_extension<I: Iterator<Item = Option<T>> + TrustedLen, T: Sized + Default>(
64    iter: I,
65) -> PolarsExtension {
66    let env = "POLARS_ALLOW_EXTENSION";
67    if !(POLARS_ALLOW_EXTENSION.load() || std::env::var(env).is_ok()) {
68        panic!("creating extension types not allowed - try setting the environment variable {env}")
69    }
70    let t_size = size_of::<T>();
71    let t_alignment = align_of::<T>();
72    let n_t_vals = iter.size_hint().1.unwrap();
73
74    let mut buf = Vec::with_capacity(n_t_vals * t_size);
75    let mut validity = BitmapBuilder::with_capacity(n_t_vals);
76
77    // when we transmute from &[u8] to T, T must be aligned correctly,
78    // so we pad with bytes until the alignment matches
79    let n_padding = (buf.as_ptr() as usize) % t_alignment;
80    buf.extend(std::iter::repeat_n(0, n_padding));
81
82    // transmute T as bytes and copy in buffer
83    for opt_t in iter.into_iter() {
84        match opt_t {
85            Some(t) => {
86                unsafe {
87                    buf.extend_from_slice(any_as_u8_slice(&t));
88                    // SAFETY: we allocated upfront
89                    validity.push_unchecked(true)
90                }
91                mem::forget(t);
92            },
93            None => {
94                unsafe {
95                    buf.extend_from_slice(any_as_u8_slice(&T::default()));
96                    // SAFETY: we allocated upfront
97                    validity.push_unchecked(false)
98                }
99            },
100        }
101    }
102
103    // We slice the buffer because we want to ignore the padding bytes from here
104    // they can be forgotten.
105    let buf: Buffer<u8> = Buffer::from_vec(buf).sliced(n_padding..);
106    // ptr to start of T, not to start of padding
107    let ptr = buf.as_slice().as_ptr();
108
109    // SAFETY: ptr and t are correct.
110    let drop_fn = unsafe { create_drop::<T>(ptr, n_t_vals) };
111    let et = Box::new(ExtensionSentinel {
112        drop_fn: Some(drop_fn),
113        to_series_fn: None,
114    });
115    let et_ptr = &*et as *const ExtensionSentinel;
116    std::mem::forget(et);
117
118    let metadata = format_pl_smallstr!("{};{}", *PROCESS_ID, et_ptr as usize);
119
120    let physical_type = ArrowDataType::FixedSizeBinary(t_size);
121    let extension_type = ArrowDataType::Extension(Box::new(ExtensionType {
122        name: PlSmallStr::from_static(POLARS_OBJECT_EXTENSION_NAME),
123        inner: physical_type,
124        metadata: Some(metadata),
125    }));
126
127    let array = FixedSizeBinaryArray::new(extension_type, buf, validity.into_opt_validity());
128
129    // SAFETY: we just heap allocated the ExtensionSentinel, so its alive.
130    unsafe { PolarsExtension::new(array) }
131}
132
#[cfg(test)]
mod test {
    use std::fmt::{Display, Formatter};
    use std::hash::{Hash, Hasher};

    use polars_utils::total_ord::TotalHash;

    use super::*;

    /// Object type used by the test: two inline fields plus a heap-allocated `String`.
    #[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
    struct Foo {
        pub a: i32,
        pub b: u8,
        pub other_heap: String,
    }

    impl TotalEq for Foo {
        // Total equality is the derived equality.
        fn tot_eq(&self, rhs: &Self) -> bool {
            self == rhs
        }
    }

    impl TotalHash for Foo {
        // Total hashing forwards to the derived `Hash` implementation.
        fn tot_hash<H: Hasher>(&self, state: &mut H) {
            Hash::hash(self, state);
        }
    }

    impl Display for Foo {
        // Displayed using the `Debug` representation.
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{self:?}")
        }
    }

    impl PolarsObject for Foo {
        fn type_name() -> &'static str {
            "object"
        }
    }

    /// Builds an extension array from two valid values and drops it again.
    #[test]
    fn test_create_extension() {
        set_polars_allow_extension(true);
        // Run this under MIRI.
        let make_foo = |heap: &str| Foo {
            a: 1,
            b: 1,
            other_heap: heap.into(),
        };

        let vals = vec![Some(make_foo("foo")), Some(make_foo("bar"))];
        create_extension(vals.into_iter());
    }
}