polars_core/chunked_array/object/extension/
mod.rs1pub(crate) mod drop;
2pub(super) mod list;
3pub(crate) mod polars_extension;
4
5use std::mem;
6use std::sync::atomic::{AtomicBool, Ordering};
7
8use arrow::array::FixedSizeBinaryArray;
9use arrow::bitmap::BitmapBuilder;
10use arrow::buffer::Buffer;
11use arrow::datatypes::ExtensionType;
12use polars_extension::PolarsExtension;
13use polars_utils::format_pl_smallstr;
14
15use crate::PROCESS_ID;
16use crate::prelude::*;
17
/// Process-wide opt-in switch for creating extension arrays.
///
/// `create_extension` consults this flag together with the
/// `POLARS_ALLOW_EXTENSION` environment variable; either one permits creation.
static POLARS_ALLOW_EXTENSION: AtomicBool = AtomicBool::new(false);

/// Enable (`true`) or disable (`false`) creation of extension types in this process.
///
/// This only controls the in-process flag. If the `POLARS_ALLOW_EXTENSION`
/// environment variable is set (to any valid Unicode value), `create_extension`
/// allows extension creation regardless of this flag.
pub fn set_polars_allow_extension(toggle: bool) {
    let flag: &AtomicBool = &POLARS_ALLOW_EXTENSION;
    // Relaxed matches the load in `create_extension`; in this module the flag
    // does not guard any other shared data.
    flag.store(toggle, Ordering::Relaxed);
}
26
/// Build a destructor closure that drops `n_t_vals` values of type `T` stored
/// consecutively (one every `size_of::<T>()` bytes) starting at `ptr`.
///
/// The values are read with `std::ptr::read_unaligned`, so `ptr` does not need
/// to be aligned for `T`. The values are dropped on the closure's first
/// invocation only. Later invocations are no-ops, so an accidental second call
/// can neither double-drop nor read past the end of the buffer.
///
/// # Safety
///
/// - `ptr` must point to `n_t_vals` consecutive, initialized `T` values spanning
///   `n_t_vals * size_of::<T>()` bytes.
/// - The closure takes ownership of those values: nobody else may drop or use
///   them after this call.
/// - The memory must remain valid until the closure is first invoked.
unsafe fn create_drop<T: Sized>(ptr: *const u8, n_t_vals: usize) -> Box<dyn FnMut()> {
    let t_size = mem::size_of::<T>();
    let mut cursor = ptr;
    // Counts down to zero, which makes every call after the first a no-op.
    let mut remaining = n_t_vals;
    Box::new(move || {
        while remaining > 0 {
            // SAFETY: per the caller contract, `cursor` points at an initialized,
            // not-yet-dropped `T`. Advancing by one `T` stays within the buffer
            // or lands exactly one past its end.
            let value = unsafe { std::ptr::read_unaligned(cursor.cast::<T>()) };
            cursor = unsafe { cursor.add(t_size) };
            // Record progress before running `T`'s destructor, so this slot is
            // never revisited even if that destructor panics.
            remaining -= 1;
            drop(value);
        }
    })
}
39
40#[allow(clippy::type_complexity)]
41struct ExtensionSentinel {
42 drop_fn: Option<Box<dyn FnMut()>>,
43 pub(crate) to_series_fn: Option<Box<dyn Fn(&FixedSizeBinaryArray, &PlSmallStr) -> Series>>,
46}
47
48impl Drop for ExtensionSentinel {
49 fn drop(&mut self) {
50 let mut drop_fn = self.drop_fn.take().unwrap();
51 drop_fn()
52 }
53}
54
/// View the in-memory representation of `p` as a byte slice of length
/// `size_of::<T>()`, borrowing from `p`.
///
/// # Safety
///
/// Every byte of `T` must be initialized. Types with padding (e.g. structs whose
/// fields are not tightly packed) contain uninitialized padding bytes, and
/// reading those as `u8` is undefined behaviour.
unsafe fn any_as_u8_slice<T: Sized>(p: &T) -> &[u8] {
    let byte_len = mem::size_of::<T>();
    let start: *const u8 = (p as *const T).cast::<u8>();
    // SAFETY: `start` comes from a live reference covering exactly `byte_len`
    // bytes; initialization of those bytes is the caller's obligation.
    unsafe { std::slice::from_raw_parts(start, byte_len) }
}
60
61pub(crate) fn create_extension<I: Iterator<Item = Option<T>> + TrustedLen, T: Sized + Default>(
64 iter: I,
65) -> PolarsExtension {
66 let env = "POLARS_ALLOW_EXTENSION";
67 if !(POLARS_ALLOW_EXTENSION.load(Ordering::Relaxed) || std::env::var(env).is_ok()) {
68 panic!("creating extension types not allowed - try setting the environment variable {env}")
69 }
70 let t_size = size_of::<T>();
71 let t_alignment = align_of::<T>();
72 let n_t_vals = iter.size_hint().1.unwrap();
73
74 let mut buf = Vec::with_capacity(n_t_vals * t_size);
75 let mut validity = BitmapBuilder::with_capacity(n_t_vals);
76
77 let n_padding = (buf.as_ptr() as usize) % t_alignment;
80 buf.extend(std::iter::repeat_n(0, n_padding));
81
82 for opt_t in iter.into_iter() {
84 match opt_t {
85 Some(t) => {
86 unsafe {
87 buf.extend_from_slice(any_as_u8_slice(&t));
88 validity.push_unchecked(true)
90 }
91 mem::forget(t);
92 },
93 None => {
94 unsafe {
95 buf.extend_from_slice(any_as_u8_slice(&T::default()));
96 validity.push_unchecked(false)
98 }
99 },
100 }
101 }
102
103 let buf: Buffer<u8> = buf.into();
106 let len = buf.len() - n_padding;
107 let buf = buf.sliced(n_padding, len);
108
109 let ptr = buf.as_slice().as_ptr();
111
112 let drop_fn = unsafe { create_drop::<T>(ptr, n_t_vals) };
115 let et = Box::new(ExtensionSentinel {
116 drop_fn: Some(drop_fn),
117 to_series_fn: None,
118 });
119 let et_ptr = &*et as *const ExtensionSentinel;
120 std::mem::forget(et);
121
122 let metadata = format_pl_smallstr!("{};{}", *PROCESS_ID, et_ptr as usize);
123
124 let physical_type = ArrowDataType::FixedSizeBinary(t_size);
125 let extension_type = ArrowDataType::Extension(Box::new(ExtensionType {
126 name: PlSmallStr::from_static(EXTENSION_NAME),
127 inner: physical_type,
128 metadata: Some(metadata),
129 }));
130
131 let array = FixedSizeBinaryArray::new(extension_type, buf, validity.into_opt_validity());
132
133 unsafe { PolarsExtension::new(array) }
135}
136
#[cfg(test)]
mod test {
    use std::fmt::{Display, Formatter};
    use std::hash::{Hash, Hasher};

    use polars_utils::total_ord::TotalHash;

    use super::*;

    /// Sample object type that mixes plain fields with a heap-owning `String`.
    #[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
    struct Foo {
        pub a: i32,
        pub b: u8,
        pub other_heap: String,
    }

    impl TotalEq for Foo {
        fn tot_eq(&self, other: &Self) -> bool {
            // Total equality coincides with the derived structural equality.
            *self == *other
        }
    }

    impl TotalHash for Foo {
        fn tot_hash<H>(&self, state: &mut H)
        where
            H: Hasher,
        {
            // Delegate to the derived `Hash` implementation.
            Hash::hash(self, state)
        }
    }

    impl Display for Foo {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            // Render the same text as `Debug`.
            write!(f, "{self:?}")
        }
    }

    impl PolarsObject for Foo {
        fn type_name() -> &'static str {
            "object"
        }
    }

    #[test]
    fn test_create_extension() {
        set_polars_allow_extension(true);

        let make_foo = |heap: &str| Foo {
            a: 1,
            b: 1,
            other_heap: heap.into(),
        };
        let vals = vec![Some(make_foo("foo")), Some(make_foo("bar"))];
        create_extension(vals.into_iter());
    }
}