Skip to main content

polars_utils/
python_function.rs

1use pyo3::BoundObject;
2use pyo3::prelude::*;
3#[cfg(feature = "serde")]
4pub use serde_wrap::{
5    PYTHON3_VERSION, PySerializeWrap, SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK,
6    TrySerializeToBytes,
7};
8
9/// Wrapper around PyObject from pyo3 with additional trait impls.
10#[derive(Debug)]
11pub struct PythonObject(pub Py<PyAny>);
12// Note: We have this because the struct itself used to be called `PythonFunction`, so it's
13// referred to as such from a lot of places.
14pub type PythonFunction = PythonObject;
15
16impl std::ops::Deref for PythonObject {
17    type Target = Py<PyAny>;
18
19    fn deref(&self) -> &Self::Target {
20        &self.0
21    }
22}
23
24impl std::ops::DerefMut for PythonObject {
25    fn deref_mut(&mut self) -> &mut Self::Target {
26        &mut self.0
27    }
28}
29
30impl std::hash::Hash for PythonObject {
31    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
32        usize::hash(&(self.0.as_ptr() as _), state)
33    }
34}
35
36impl Clone for PythonObject {
37    fn clone(&self) -> Self {
38        Python::attach(|py| Self(self.0.clone_ref(py)))
39    }
40}
41
42impl From<Py<PyAny>> for PythonObject {
43    fn from(value: Py<PyAny>) -> Self {
44        Self(value)
45    }
46}
47
48impl<'a, 'py> FromPyObject<'a, 'py> for PythonObject {
49    type Error = PyErr;
50
51    fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
52        Ok(PythonObject(ob.into_bound().unbind()))
53    }
54}
55
56impl<'py> pyo3::conversion::IntoPyObject<'py> for PythonObject {
57    type Target = PyAny;
58    type Output = Bound<'py, Self::Target>;
59    type Error = PyErr;
60
61    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
62        Ok(self.0.into_bound(py))
63    }
64}
65
66impl<'py> pyo3::conversion::IntoPyObject<'py> for &PythonObject {
67    type Target = PyAny;
68    type Output = Bound<'py, Self::Target>;
69    type Error = PyErr;
70
71    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
72        Ok(self.0.bind(py).clone())
73    }
74}
75
76impl Eq for PythonObject {}
77
78impl PartialEq for PythonObject {
79    fn eq(&self, other: &Self) -> bool {
80        Python::attach(|py| {
81            let eq = self.0.getattr(py, "__eq__").unwrap();
82            eq.call1(py, (other.0.clone_ref(py),))
83                .unwrap()
84                .extract::<bool>(py)
85                // equality can be not implemented, so default to false
86                .unwrap_or(false)
87        })
88    }
89}
90
/// JSON-schema description of a `PythonObject`. The schema body is the one
/// generated for `Vec<u8>`; only the schema name and id are specific to this type.
#[cfg(feature = "dsl-schema")]
impl schemars::JsonSchema for PythonObject {
    /// Short schema name: `"PythonObject"`.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("PythonObject")
    }

    /// Schema id: this module's path followed by `::PythonObject`.
    fn schema_id() -> std::borrow::Cow<'static, str> {
        const ID: &str = concat!(module_path!(), "::", "PythonObject");
        std::borrow::Cow::Borrowed(ID)
    }

    /// Reuses the schema generated for `Vec<u8>`.
    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        <Vec<u8> as schemars::JsonSchema>::json_schema(generator)
    }
}
105
/// serde support for `PythonObject`: byte-level (de)serialization through
/// `crate::pl_serialize`, plus helpers that route through `PySerializeWrap`.
#[cfg(feature = "serde")]
mod _serde_impls {
    use super::{PySerializeWrap, PythonObject, TrySerializeToBytes};
    use crate::pl_serialize::deserialize_map_bytes;

    impl PythonObject {
        /// Serializes anything that can be viewed as a `PythonObject` by wrapping
        /// the borrowed object in `PySerializeWrap` and serializing the wrapper.
        pub fn serialize_with_pyversion<T, S>(
            value: &T,
            serializer: S,
        ) -> std::result::Result<S::Ok, S::Error>
        where
            T: AsRef<PythonObject>,
            S: serde::ser::Serializer,
        {
            let wrapped: PySerializeWrap<&PythonObject> = PySerializeWrap(value.as_ref());
            serde::Serialize::serialize(&wrapped, serializer)
        }

        /// Deserializes a `PySerializeWrap<PythonObject>` and converts the
        /// contained object into the requested type `T`.
        pub fn deserialize_with_pyversion<'de, T, D>(d: D) -> Result<T, D::Error>
        where
            T: From<PythonObject>,
            D: serde::de::Deserializer<'de>,
        {
            let PySerializeWrap(object) =
                <PySerializeWrap<PythonObject> as serde::Deserialize<'de>>::deserialize(d)?;

            Ok(T::from(object))
        }
    }

    impl TrySerializeToBytes for PythonObject {
        /// Writes the object into a fresh byte buffer using
        /// `pl_serialize::python_object_serialize`.
        fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult<Vec<u8>> {
            let mut encoded = Vec::new();
            crate::pl_serialize::python_object_serialize(&self.0, &mut encoded)?;
            Ok(encoded)
        }

        /// Rebuilds the object from bytes using
        /// `pl_serialize::python_object_deserialize`.
        fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult<Self> {
            let object = crate::pl_serialize::python_object_deserialize(bytes)?;
            Ok(Self(object))
        }
    }

    impl serde::Serialize for PythonObject {
        /// Serializes the object as a `Vec<u8>`; a failure to produce the bytes is
        /// reported as a custom serializer error carrying the error's message.
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            match self.try_serialize_to_bytes() {
                Ok(bytes) => serde::Serialize::serialize(&bytes, serializer),
                Err(e) => Err(<S::Error as serde::ser::Error>::custom(e.to_string())),
            }
        }
    }

    impl<'a> serde::Deserialize<'a> for PythonObject {
        /// Reads bytes through `deserialize_map_bytes` and decodes them with
        /// `try_deserialize_bytes`; decode failures become custom deserializer
        /// errors carrying the error's message.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            // The outer `Result` is the byte-reading step, the inner one is the
            // Python-object decoding step.
            let decoded = deserialize_map_bytes(deserializer, |raw| {
                Self::try_deserialize_bytes(&raw)
                    .map_err(|e| <D::Error as serde::de::Error>::custom(e.to_string()))
            })?;

            decoded
        }
    }
}
174
/// Generic serde plumbing for Python-backed values: the `TrySerializeToBytes`
/// trait and the `PySerializeWrap` adapter built on top of it.
#[cfg(feature = "serde")]
mod serde_wrap {
    use std::sync::LazyLock;

    use polars_error::PolarsResult;

    use crate::pl_serialize::deserialize_map_bytes;

    /// The ASCII bytes `PLPYFN`.
    // NOTE(review): not referenced inside this module; presumably used as a
    // marker by the serialization code elsewhere — confirm.
    pub const SERDE_MAGIC_BYTE_MARK: &[u8] = b"PLPYFN";
    /// [minor, micro], computed on first access by querying the interpreter.
    pub static PYTHON3_VERSION: LazyLock<[u8; 2]> = LazyLock::new(super::get_python3_version);

    /// Byte-level (de)serialization of a Python object, without additional
    /// system metadata. Meant to be combined with `PySerializeWrap`.
    pub trait TrySerializeToBytes: Sized {
        /// Encodes `self` into a byte vector.
        fn try_serialize_to_bytes(&self) -> PolarsResult<Vec<u8>>;
        /// Decodes a value from bytes previously produced for this type.
        fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult<Self>;
    }

    /// Serialization wrapper for `T: TrySerializeToBytes`, intended to attach
    /// Python version metadata.
    // NOTE(review): the impls in this module only forward the bytes produced by
    // `T`; if version metadata is attached, it happens outside this module —
    // confirm.
    pub struct PySerializeWrap<T>(pub T);

    impl<T: TrySerializeToBytes> serde::Serialize for PySerializeWrap<&T> {
        /// Emits the wrapped value's bytes via `serialize_bytes`; encoding
        /// failures become custom serializer errors carrying the error's message.
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let payload = match self.0.try_serialize_to_bytes() {
                Ok(bytes) => bytes,
                Err(e) => return Err(<S::Error as serde::ser::Error>::custom(e.to_string())),
            };

            serializer.serialize_bytes(&payload)
        }
    }

    impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap<T> {
        /// Reads bytes through `deserialize_map_bytes` and decodes them with
        /// `T::try_deserialize_bytes`; decode failures become custom deserializer
        /// errors carrying the error's message.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            // Outer `Result`: reading the bytes. Inner `Result`: decoding them.
            let decoded = deserialize_map_bytes(deserializer, |raw| {
                match T::try_deserialize_bytes(raw.as_ref()) {
                    Ok(inner) => Ok(PySerializeWrap(inner)),
                    Err(e) => Err(<D::Error as serde::de::Error>::custom(e.to_string())),
                }
            })?;

            decoded
        }
    }
}
228
229/// Get the [minor, micro] Python3 version from the `sys` module.
230fn get_python3_version() -> [u8; 2] {
231    Python::attach(|py| {
232        let version_info = PyModule::import(py, "sys")
233            .unwrap()
234            .getattr("version_info")
235            .unwrap();
236
237        [
238            version_info.getattr("minor").unwrap().extract().unwrap(),
239            version_info.getattr("micro").unwrap().extract().unwrap(),
240        ]
241    })
242}