polars_utils/
python_function.rs

1use polars_error::{PolarsError, polars_bail};
2use pyo3::prelude::*;
3use pyo3::pybacked::PyBackedBytes;
4use pyo3::types::PyBytes;
5#[cfg(feature = "serde")]
6pub use serde_wrap::{
7    PYTHON3_VERSION, PySerializeWrap, SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK,
8    TrySerializeToBytes,
9};
10
11/// Wrapper around PyObject from pyo3 with additional trait impls.
12#[derive(Debug)]
13pub struct PythonObject(pub PyObject);
14// Note: We have this because the struct itself used to be called `PythonFunction`, so it's
15// referred to as such from a lot of places.
16pub type PythonFunction = PythonObject;
17
18impl std::ops::Deref for PythonObject {
19    type Target = PyObject;
20
21    fn deref(&self) -> &Self::Target {
22        &self.0
23    }
24}
25
26impl std::ops::DerefMut for PythonObject {
27    fn deref_mut(&mut self) -> &mut Self::Target {
28        &mut self.0
29    }
30}
31
32impl Clone for PythonObject {
33    fn clone(&self) -> Self {
34        Python::with_gil(|py| Self(self.0.clone_ref(py)))
35    }
36}
37
38impl From<PyObject> for PythonObject {
39    fn from(value: PyObject) -> Self {
40        Self(value)
41    }
42}
43
44impl<'py> pyo3::conversion::IntoPyObject<'py> for PythonObject {
45    type Target = PyAny;
46    type Output = Bound<'py, Self::Target>;
47    type Error = PyErr;
48
49    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
50        Ok(self.0.into_bound(py))
51    }
52}
53
54impl<'py> pyo3::conversion::IntoPyObject<'py> for &PythonObject {
55    type Target = PyAny;
56    type Output = Bound<'py, Self::Target>;
57    type Error = PyErr;
58
59    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
60        Ok(self.0.bind(py).clone())
61    }
62}
63
64impl Eq for PythonObject {}
65
66impl PartialEq for PythonObject {
67    fn eq(&self, other: &Self) -> bool {
68        Python::with_gil(|py| {
69            let eq = self.0.getattr(py, "__eq__").unwrap();
70            eq.call1(py, (other.0.clone_ref(py),))
71                .unwrap()
72                .extract::<bool>(py)
73                // equality can be not implemented, so default to false
74                .unwrap_or(false)
75        })
76    }
77}
78
#[cfg(feature = "serde")]
mod _serde_impls {
    use super::{PySerializeWrap, PythonObject, TrySerializeToBytes, serde_wrap};
    use crate::pl_serialize::deserialize_map_bytes;

    impl PythonObject {
        /// Serializes `value` through [`PySerializeWrap`], i.e. with the Python version
        /// metadata attached.
        pub fn serialize_with_pyversion<T, S>(
            value: &T,
            serializer: S,
        ) -> std::result::Result<S::Ok, S::Error>
        where
            T: AsRef<PythonObject>,
            S: serde::ser::Serializer,
        {
            let wrapped = PySerializeWrap(value.as_ref());
            serde::Serialize::serialize(&wrapped, serializer)
        }

        /// Counterpart of [`PythonObject::serialize_with_pyversion`]: deserializes through
        /// [`PySerializeWrap`] and converts the resulting `PythonObject` into `T`.
        pub fn deserialize_with_pyversion<'de, T, D>(d: D) -> Result<T, D::Error>
        where
            T: From<PythonObject>,
            D: serde::de::Deserializer<'de>,
        {
            let PySerializeWrap(object) =
                <PySerializeWrap<PythonObject> as serde::Deserialize<'de>>::deserialize(d)?;

            Ok(T::from(object))
        }
    }

    /// Byte-level (de)serialization, delegated to the pickle helpers in `serde_wrap`.
    impl TrySerializeToBytes for PythonObject {
        fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult<Vec<u8>> {
            let PythonObject(inner) = self;
            serde_wrap::serialize_pyobject_with_cloudpickle_fallback(inner)
        }

        fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult<Self> {
            serde_wrap::deserialize_pyobject_bytes_maybe_cloudpickle::<Self>(bytes)
        }
    }

    /// Plain serialization: the bytes from `try_serialize_to_bytes`, written as a
    /// `Vec<u8>` and without the version metadata that `PySerializeWrap` adds.
    impl serde::Serialize for PythonObject {
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            match self.try_serialize_to_bytes() {
                Ok(bytes) => serde::Serialize::serialize(&bytes, serializer),
                Err(err) => Err(<S::Error as serde::ser::Error>::custom(err.to_string())),
            }
        }
    }

    /// Plain deserialization: reads the bytes and passes them to `try_deserialize_bytes`.
    impl<'a> serde::Deserialize<'a> for PythonObject {
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            // The outer `Result` is the byte-reading step, the inner one the decoding step.
            deserialize_map_bytes(deserializer, |raw| {
                Self::try_deserialize_bytes(&raw)
                    .map_err(|err| <D::Error as serde::de::Error>::custom(err.to_string()))
            })?
        }
    }
}
145
#[cfg(feature = "serde")]
mod serde_wrap {
    use std::sync::LazyLock;

    use polars_error::PolarsResult;

    use super::*;
    use crate::config;
    use crate::pl_serialize::deserialize_map_bytes;

    /// Byte prefix written in front of every `PySerializeWrap` payload and checked when
    /// reading one back.
    pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes();
    /// [minor, micro]
    pub static PYTHON3_VERSION: LazyLock<[u8; 2]> = LazyLock::new(super::get_python3_version);

    /// Serializes a Python object without additional system metadata. This is intended to be used
    /// together with `PySerializeWrap`, which attaches e.g. Python version metadata.
    pub trait TrySerializeToBytes: Sized {
        fn try_serialize_to_bytes(&self) -> PolarsResult<Vec<u8>>;
        fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult<Self>;
    }

    /// Serialization wrapper for T: TrySerializeToBytes that attaches Python
    /// version metadata.
    ///
    /// The serialized form is: `SERDE_MAGIC_BYTE_MARK`, then the two `PYTHON3_VERSION`
    /// bytes, then the bytes produced by `T::try_serialize_to_bytes`.
    pub struct PySerializeWrap<T>(pub T);

    impl<T: TrySerializeToBytes> serde::Serialize for PySerializeWrap<&T> {
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::Error;
            let dumped = self
                .0
                .try_serialize_to_bytes()
                .map_err(|e| S::Error::custom(e.to_string()))?;

            // Magic mark, Python [minor, micro], payload - emitted as one byte string.
            serializer.serialize_bytes(
                &[SERDE_MAGIC_BYTE_MARK, &*PYTHON3_VERSION, dumped.as_slice()].concat(),
            )
        }
    }

    impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap<T> {
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            use serde::de::Error;

            deserialize_map_bytes(deserializer, |bytes| {
                // 1. Magic mark.
                let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else {
                    return Err(D::Error::custom(
                        "unexpected EOF when reading serialized pyobject version",
                    ));
                };

                if magic != SERDE_MAGIC_BYTE_MARK {
                    return Err(D::Error::custom(
                        "serialized pyobject did not begin with magic byte mark",
                    ));
                }

                let bytes = rem;

                // 2. Python [minor, micro] version the payload was written with.
                let [a, b, rem @ ..] = bytes else {
                    return Err(D::Error::custom(
                        "unexpected EOF when reading serialized pyobject metadata",
                    ));
                };

                let py3_version = [*a, *b];
                // The validity of the cloudpickle flag is checked later, when
                // `try_deserialize_bytes` is called.
                let used_cloud_pickle = rem.first();

                // Cloudpickle uses bytecode to serialize, which is unstable between versions
                // So we only allow strict python versions if cloudpickle is used.
                if py3_version != *PYTHON3_VERSION && used_cloud_pickle == Some(&1) {
                    return Err(D::Error::custom(format!(
                        "python version that pyobject was serialized with {:?} \
                        differs from system python version {:?}",
                        (3, py3_version[0], py3_version[1]),
                        (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]),
                    )));
                }

                // 3. The remainder (still starting with the cloudpickle flag) is the payload.
                let bytes = rem;

                T::try_deserialize_bytes(bytes)
                    .map(Self)
                    .map_err(|e| D::Error::custom(e.to_string()))
            })?
        }
    }

    /// Pickles `py_object` with `pickle.dumps`, retrying with `cloudpickle.dumps` if that
    /// raises.
    ///
    /// The returned buffer is `[used_cloudpickle as u8, b'C']` followed by the pickled bytes.
    ///
    /// # Errors
    ///
    /// Any Python exception (a failed import, a missing `dumps` attribute, a failing
    /// cloudpickle call or a non-bytes result) is converted with `from_pyerr` and returned
    /// instead of panicking.
    pub fn serialize_pyobject_with_cloudpickle_fallback(
        py_object: &PyObject,
    ) -> PolarsResult<Vec<u8>> {
        Python::with_gil(|py| {
            let pickle = PyModule::import(py, "pickle")?.getattr("dumps")?;

            let dumped = pickle.call1((py_object.clone_ref(py),));

            let (dumped, used_cloudpickle) = match dumped {
                Ok(v) => (v, false),
                Err(e) => {
                    if config::verbose() {
                        eprintln!(
                            "serialize_pyobject_with_cloudpickle_fallback(): \
                            retrying with cloudpickle due to error: {:?}",
                            e
                        );
                    }

                    let cloudpickle = PyModule::import(py, "cloudpickle")?.getattr("dumps")?;
                    let dumped = cloudpickle.call1((py_object.clone_ref(py),))?;
                    (dumped, true)
                },
            };

            let py_bytes = dumped.extract::<PyBackedBytes>()?;

            Ok([&[used_cloudpickle as u8, b'C'][..], py_bytes.as_ref()].concat())
        })
        .map_err(from_pyerr)
    }

    /// Inverse of `serialize_pyobject_with_cloudpickle_fallback`.
    ///
    /// `bytes` must start with the cloudpickle flag (`0` or `1`) and the marker `b'C'`;
    /// the rest is passed to `loads` of `cloudpickle` when the flag is `1`, otherwise of
    /// `pickle`.
    ///
    /// NOTE(review): unpickling can execute arbitrary code, so `bytes` must come from a
    /// trusted source.
    ///
    /// # Errors
    ///
    /// Returns a `ComputeError` if the two start bytes are invalid, or if importing the
    /// module, looking up `loads`, or the `loads` call itself raises in Python. A missing
    /// `cloudpickle` installation is reported this way rather than as a panic.
    pub fn deserialize_pyobject_bytes_maybe_cloudpickle<T: for<'a> From<PyObject>>(
        bytes: &[u8],
    ) -> PolarsResult<T> {
        let [used_cloudpickle @ 0 | used_cloudpickle @ 1, b'C', rem @ ..] = bytes else {
            polars_bail!(ComputeError: "deserialize_pyobject_bytes_maybe_cloudpickle: invalid start bytes")
        };

        let bytes = rem;

        Python::with_gil(|py| {
            let module_name = if *used_cloudpickle == 1 {
                "cloudpickle"
            } else {
                "pickle"
            };

            let loads = PyModule::import(py, module_name)?.getattr("loads")?;
            let arg = (PyBytes::new(py, bytes),);
            let pyany_bound = loads.call1(arg)?;
            Ok(PyObject::from(pyany_bound).into())
        })
        .map_err(from_pyerr)
    }
}
305
306/// Get the [minor, micro] Python3 version from the `sys` module.
307fn get_python3_version() -> [u8; 2] {
308    Python::with_gil(|py| {
309        let version_info = PyModule::import(py, "sys")
310            .unwrap()
311            .getattr("version_info")
312            .unwrap();
313
314        [
315            version_info.getattr("minor").unwrap().extract().unwrap(),
316            version_info.getattr("micro").unwrap().extract().unwrap(),
317        ]
318    })
319}
320
321fn from_pyerr(e: PyErr) -> PolarsError {
322    PolarsError::ComputeError(format!("error raised in python: {e}").into())
323}