polars_utils/
python_function.rs1use pyo3::BoundObject;
2use pyo3::prelude::*;
3#[cfg(feature = "serde")]
4pub use serde_wrap::{
5 PYTHON3_VERSION, PySerializeWrap, SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK,
6 TrySerializeToBytes,
7};
8
9#[derive(Debug)]
11pub struct PythonObject(pub Py<PyAny>);
12pub type PythonFunction = PythonObject;
15
16impl std::ops::Deref for PythonObject {
17 type Target = Py<PyAny>;
18
19 fn deref(&self) -> &Self::Target {
20 &self.0
21 }
22}
23
24impl std::ops::DerefMut for PythonObject {
25 fn deref_mut(&mut self) -> &mut Self::Target {
26 &mut self.0
27 }
28}
29
30impl std::hash::Hash for PythonObject {
31 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
32 usize::hash(&(self.0.as_ptr() as _), state)
33 }
34}
35
36impl Clone for PythonObject {
37 fn clone(&self) -> Self {
38 Python::attach(|py| Self(self.0.clone_ref(py)))
39 }
40}
41
42impl From<Py<PyAny>> for PythonObject {
43 fn from(value: Py<PyAny>) -> Self {
44 Self(value)
45 }
46}
47
48impl<'a, 'py> FromPyObject<'a, 'py> for PythonObject {
49 type Error = PyErr;
50
51 fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
52 Ok(PythonObject(ob.into_bound().unbind()))
53 }
54}
55
56impl<'py> pyo3::conversion::IntoPyObject<'py> for PythonObject {
57 type Target = PyAny;
58 type Output = Bound<'py, Self::Target>;
59 type Error = PyErr;
60
61 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
62 Ok(self.0.into_bound(py))
63 }
64}
65
66impl<'py> pyo3::conversion::IntoPyObject<'py> for &PythonObject {
67 type Target = PyAny;
68 type Output = Bound<'py, Self::Target>;
69 type Error = PyErr;
70
71 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
72 Ok(self.0.bind(py).clone())
73 }
74}
75
76impl Eq for PythonObject {}
77
78impl PartialEq for PythonObject {
79 fn eq(&self, other: &Self) -> bool {
80 Python::attach(|py| {
81 let eq = self.0.getattr(py, "__eq__").unwrap();
82 eq.call1(py, (other.0.clone_ref(py),))
83 .unwrap()
84 .extract::<bool>(py)
85 .unwrap_or(false)
87 })
88 }
89}
90
/// JSON schema for `PythonObject`: it is described by the schema of `Vec<u8>`.
#[cfg(feature = "dsl-schema")]
impl schemars::JsonSchema for PythonObject {
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("PythonObject")
    }

    fn schema_id() -> std::borrow::Cow<'static, str> {
        // Qualified with the module path so the id is unique to this type.
        concat!(module_path!(), "::", "PythonObject").into()
    }

    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        <Vec<u8> as schemars::JsonSchema>::json_schema(generator)
    }
}
105
/// serde support for `PythonObject`, built on the byte-level helpers in
/// `crate::pl_serialize`.
#[cfg(feature = "serde")]
mod _serde_impls {
    use super::{PySerializeWrap, PythonObject, TrySerializeToBytes};
    use crate::pl_serialize::deserialize_map_bytes;

    impl PythonObject {
        /// Serialization hook, suitable for `#[serde(serialize_with = ...)]`,
        /// that writes the referenced object through [`PySerializeWrap`].
        pub fn serialize_with_pyversion<T, S>(
            value: &T,
            serializer: S,
        ) -> std::result::Result<S::Ok, S::Error>
        where
            T: AsRef<PythonObject>,
            S: serde::ser::Serializer,
        {
            let wrapped = PySerializeWrap(value.as_ref());
            serde::Serialize::serialize(&wrapped, serializer)
        }

        /// Deserialization hook mirroring `serialize_with_pyversion`: reads a
        /// `PythonObject` through [`PySerializeWrap`] and converts it into `T`.
        pub fn deserialize_with_pyversion<'de, T, D>(d: D) -> Result<T, D::Error>
        where
            T: From<PythonObject>,
            D: serde::de::Deserializer<'de>,
        {
            let PySerializeWrap(object) =
                <PySerializeWrap<PythonObject> as serde::Deserialize<'de>>::deserialize(d)?;
            Ok(T::from(object))
        }
    }

    impl TrySerializeToBytes for PythonObject {
        /// Encodes the wrapped object with `pl_serialize::python_object_serialize`.
        fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult<Vec<u8>> {
            let mut out: Vec<u8> = Vec::new();
            crate::pl_serialize::python_object_serialize(&self.0, &mut out)?;
            Ok(out)
        }

        /// Decodes bytes with `pl_serialize::python_object_deserialize`.
        fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult<Self> {
            let object = crate::pl_serialize::python_object_deserialize(bytes)?;
            Ok(PythonObject(object))
        }
    }

    impl serde::Serialize for PythonObject {
        /// Writes the bytes from `try_serialize_to_bytes` as a `Vec<u8>`.
        /// Encoding failures surface as a custom serializer error.
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            match self.try_serialize_to_bytes() {
                Ok(bytes) => serde::Serialize::serialize(&bytes, serializer),
                Err(err) => Err(<S::Error as serde::ser::Error>::custom(err.to_string())),
            }
        }
    }

    impl<'de> serde::Deserialize<'de> for PythonObject {
        /// Reads the byte payload via `deserialize_map_bytes` and decodes it
        /// with `try_deserialize_bytes`. Decoding failures surface as a custom
        /// deserializer error.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            deserialize_map_bytes(deserializer, |raw| {
                Self::try_deserialize_bytes(&raw)
                    .map_err(|err| <D::Error as serde::de::Error>::custom(err.to_string()))
            })?
        }
    }
}
174
/// Byte-level (de)serialization building blocks shared by Python-backed values.
#[cfg(feature = "serde")]
mod serde_wrap {
    use std::sync::LazyLock;

    use polars_error::PolarsResult;

    use crate::pl_serialize::deserialize_map_bytes;

    /// Magic byte sequence `PLPYFN`.
    ///
    /// NOTE(review): presumably written in front of serialized Python payloads;
    /// the code that writes/checks it is not in this module — confirm.
    pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes();

    /// `[minor, micro]` of the running Python interpreter, computed once on
    /// first access.
    pub static PYTHON3_VERSION: LazyLock<[u8; 2]> = LazyLock::new(super::get_python3_version);

    /// Values that can be encoded to, and decoded from, an opaque byte buffer.
    pub trait TrySerializeToBytes: Sized {
        /// Encodes `self` into a freshly allocated byte vector.
        fn try_serialize_to_bytes(&self) -> PolarsResult<Vec<u8>>;
        /// Decodes a value from bytes produced by `try_serialize_to_bytes`.
        fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult<Self>;
    }

    /// serde adapter. It serializes the wrapped value as a single byte string
    /// and deserializes it back through [`TrySerializeToBytes`].
    pub struct PySerializeWrap<T>(pub T);

    impl<T: TrySerializeToBytes> serde::Serialize for PySerializeWrap<&T> {
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let PySerializeWrap(inner) = self;
            match inner.try_serialize_to_bytes() {
                Ok(payload) => serializer.serialize_bytes(&payload),
                Err(err) => Err(<S::Error as serde::ser::Error>::custom(err.to_string())),
            }
        }
    }

    impl<'de, T: TrySerializeToBytes> serde::Deserialize<'de> for PySerializeWrap<T> {
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            deserialize_map_bytes(deserializer, |raw| {
                match T::try_deserialize_bytes(raw.as_ref()) {
                    Ok(value) => Ok(PySerializeWrap(value)),
                    Err(err) => Err(<D::Error as serde::de::Error>::custom(err.to_string())),
                }
            })?
        }
    }
}
228
229fn get_python3_version() -> [u8; 2] {
231 Python::attach(|py| {
232 let version_info = PyModule::import(py, "sys")
233 .unwrap()
234 .getattr("version_info")
235 .unwrap();
236
237 [
238 version_info.getattr("minor").unwrap().extract().unwrap(),
239 version_info.getattr("micro").unwrap().extract().unwrap(),
240 ]
241 })
242}