polars_utils/
python_function.rs

1use polars_error::{PolarsError, polars_bail};
2use pyo3::prelude::*;
3use pyo3::pybacked::PyBackedBytes;
4use pyo3::types::PyBytes;
5#[cfg(feature = "serde")]
6pub use serde_wrap::{
7    PYTHON3_VERSION, PySerializeWrap, SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK,
8    TrySerializeToBytes,
9};
10
/// Wrapper around PyObject from pyo3 with additional trait impls.
///
/// The wrapped object is the public tuple field, so it can be reached either as `.0` or
/// through the `Deref`/`DerefMut` impls defined in this file.
#[derive(Debug)]
pub struct PythonObject(pub PyObject);
/// Alias of [`PythonObject`] under its former name.
///
/// Note: We have this because the struct itself used to be called `PythonFunction`, so it's
/// referred to as such from a lot of places.
pub type PythonFunction = PythonObject;
17
18impl std::ops::Deref for PythonObject {
19    type Target = PyObject;
20
21    fn deref(&self) -> &Self::Target {
22        &self.0
23    }
24}
25
26impl std::ops::DerefMut for PythonObject {
27    fn deref_mut(&mut self) -> &mut Self::Target {
28        &mut self.0
29    }
30}
31
32impl Clone for PythonObject {
33    fn clone(&self) -> Self {
34        Python::with_gil(|py| Self(self.0.clone_ref(py)))
35    }
36}
37
38impl From<PyObject> for PythonObject {
39    fn from(value: PyObject) -> Self {
40        Self(value)
41    }
42}
43
44impl Eq for PythonObject {}
45
46impl PartialEq for PythonObject {
47    fn eq(&self, other: &Self) -> bool {
48        Python::with_gil(|py| {
49            let eq = self.0.getattr(py, "__eq__").unwrap();
50            eq.call1(py, (other.0.clone_ref(py),))
51                .unwrap()
52                .extract::<bool>(py)
53                // equality can be not implemented, so default to false
54                .unwrap_or(false)
55        })
56    }
57}
58
#[cfg(feature = "serde")]
mod _serde_impls {
    use super::{PySerializeWrap, PythonObject, TrySerializeToBytes, serde_wrap};
    use crate::pl_serialize::deserialize_map_bytes;

    impl PythonObject {
        /// Serializes the `PythonObject` behind `value` through [`PySerializeWrap`], which
        /// prefixes the payload with the magic byte mark and the Python version.
        pub fn serialize_with_pyversion<T, S>(
            value: &T,
            serializer: S,
        ) -> std::result::Result<S::Ok, S::Error>
        where
            T: AsRef<PythonObject>,
            S: serde::ser::Serializer,
        {
            let object: &PythonObject = value.as_ref();
            serde::Serialize::serialize(&PySerializeWrap(object), serializer)
        }

        /// Counterpart of `serialize_with_pyversion`: reads a [`PySerializeWrap`] payload and
        /// converts the contained `PythonObject` into `T`.
        pub fn deserialize_with_pyversion<'de, T, D>(d: D) -> Result<T, D::Error>
        where
            T: From<PythonObject>,
            D: serde::de::Deserializer<'de>,
        {
            let PySerializeWrap(object) =
                <PySerializeWrap<PythonObject> as serde::Deserialize>::deserialize(d)?;

            Ok(T::from(object))
        }
    }

    impl TrySerializeToBytes for PythonObject {
        /// Delegates to `serde_wrap::serialize_pyobject_with_cloudpickle_fallback`.
        fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult<Vec<u8>> {
            let Self(inner) = self;
            serde_wrap::serialize_pyobject_with_cloudpickle_fallback(inner)
        }

        /// Delegates to `serde_wrap::deserialize_pyobject_bytes_maybe_cloudpickle`.
        fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult<Self> {
            serde_wrap::deserialize_pyobject_bytes_maybe_cloudpickle(bytes)
        }
    }

    impl serde::Serialize for PythonObject {
        /// Serializes the object's byte representation (without Python version metadata)
        /// as a `Vec<u8>`. Failures to produce the bytes become custom serializer errors.
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            match self.try_serialize_to_bytes() {
                Ok(payload) => serde::Serialize::serialize(&payload, serializer),
                Err(e) => Err(<S::Error as serde::ser::Error>::custom(e.to_string())),
            }
        }
    }

    impl<'a> serde::Deserialize<'a> for PythonObject {
        /// Reads the bytes via `deserialize_map_bytes` and rebuilds the object from them.
        /// Failures to rebuild the object become custom deserializer errors.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            deserialize_map_bytes(deserializer, |raw| {
                Self::try_deserialize_bytes(&raw)
                    .map_err(|e| <D::Error as serde::de::Error>::custom(e.to_string()))
            })?
        }
    }
}
125
#[cfg(feature = "serde")]
mod serde_wrap {
    use std::sync::LazyLock;

    use polars_error::PolarsResult;

    use super::*;
    use crate::config;
    use crate::pl_serialize::deserialize_map_bytes;

    /// Marker bytes written at the start of every payload produced by [`PySerializeWrap`].
    pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes();
    /// [minor, micro]
    pub static PYTHON3_VERSION: LazyLock<[u8; 2]> = LazyLock::new(super::get_python3_version);

    /// Serializes a Python object without additional system metadata. This is intended to be used
    /// together with `PySerializeWrap`, which attaches e.g. Python version metadata.
    pub trait TrySerializeToBytes: Sized {
        fn try_serialize_to_bytes(&self) -> PolarsResult<Vec<u8>>;
        fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult<Self>;
    }

    /// Serialization wrapper for T: TrySerializeToBytes that attaches Python
    /// version metadata.
    ///
    /// The emitted byte string is `SERDE_MAGIC_BYTE_MARK`, followed by the two bytes of
    /// `PYTHON3_VERSION`, followed by whatever `T::try_serialize_to_bytes` returned.
    pub struct PySerializeWrap<T>(pub T);

    impl<T: TrySerializeToBytes> serde::Serialize for PySerializeWrap<&T> {
        /// Writes magic mark + Python version + object bytes as a single byte string.
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::Error;
            let dumped = self
                .0
                .try_serialize_to_bytes()
                .map_err(|e| S::Error::custom(e.to_string()))?;

            serializer.serialize_bytes(
                &[SERDE_MAGIC_BYTE_MARK, &*PYTHON3_VERSION, dumped.as_slice()].concat(),
            )
        }
    }

    impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap<T> {
        /// Validates the magic mark and the Python version header, then hands the remaining
        /// bytes to `T::try_deserialize_bytes`.
        ///
        /// A version mismatch is only an error when the first byte after the version header
        /// is `1`, i.e. when the payload is flagged as produced by cloudpickle.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            use serde::de::Error;

            deserialize_map_bytes(deserializer, |bytes| {
                let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else {
                    return Err(D::Error::custom(
                        "unexpected EOF when reading serialized pyobject version",
                    ));
                };

                if magic != SERDE_MAGIC_BYTE_MARK {
                    return Err(D::Error::custom(
                        "serialized pyobject did not begin with magic byte mark",
                    ));
                }

                let bytes = rem;

                // The two bytes after the magic mark are the [minor, micro] version.
                let [a, b, rem @ ..] = bytes else {
                    return Err(D::Error::custom(
                        "unexpected EOF when reading serialized pyobject metadata",
                    ));
                };

                let py3_version = [*a, *b];
                // The validity of the cloudpickle flag is checked later, when
                // `try_deserialize_bytes` is called.
                let used_cloud_pickle = rem.first();

                // Cloudpickle uses bytecode to serialize, which is unstable between versions
                // So we only allow strict python versions if cloudpickle is used.
                if py3_version != *PYTHON3_VERSION && used_cloud_pickle == Some(&1) {
                    return Err(D::Error::custom(format!(
                        "python version that pyobject was serialized with {:?} \
                        differs from system python version {:?}",
                        (3, py3_version[0], py3_version[1]),
                        (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]),
                    )));
                }

                let bytes = rem;

                T::try_deserialize_bytes(bytes)
                    .map(Self)
                    .map_err(|e| D::Error::custom(e.to_string()))
            })?
        }
    }

    /// Pickles `py_object` with `pickle.dumps`, retrying with `cloudpickle.dumps` when the
    /// former raises.
    ///
    /// The returned bytes are `[used_cloudpickle as u8, b'C']` followed by the pickled data.
    ///
    /// # Errors
    ///
    /// Returns a `ComputeError` when `cloudpickle` cannot be imported or used after the
    /// `pickle` attempt failed, or when the dumped value cannot be extracted as bytes.
    ///
    /// # Panics
    ///
    /// Panics if the standard-library `pickle` module cannot be imported or has no `dumps`.
    pub fn serialize_pyobject_with_cloudpickle_fallback(
        py_object: &PyObject,
    ) -> PolarsResult<Vec<u8>> {
        Python::with_gil(|py| {
            let pickle = PyModule::import(py, "pickle")
                .expect("unable to import 'pickle'")
                .getattr("dumps")
                .unwrap();

            let dumped = pickle.call1((py_object.clone_ref(py),));

            let (dumped, used_cloudpickle) = match dumped {
                Ok(v) => (v, false),
                Err(e) => {
                    if config::verbose() {
                        eprintln!(
                            "serialize_pyobject_with_cloudpickle_fallback(): \
                            retrying with cloudpickle due to error: {:?}",
                            e
                        );
                    }

                    // `cloudpickle` is an optional third-party module, so every failure on
                    // this path is reported as an error instead of a panic.
                    let cloudpickle = PyModule::import(py, "cloudpickle")?.getattr("dumps")?;
                    let dumped = cloudpickle.call1((py_object.clone_ref(py),))?;
                    (dumped, true)
                },
            };

            let py_bytes = dumped.extract::<PyBackedBytes>()?;

            Ok([&[used_cloudpickle as u8, b'C'][..], py_bytes.as_ref()].concat())
        })
        .map_err(from_pyerr)
    }

    /// Inverse of `serialize_pyobject_with_cloudpickle_fallback`.
    ///
    /// Expects `bytes` to start with a flag byte (`0` = pickle, `1` = cloudpickle) and the
    /// byte `b'C'`; the remainder is passed to the `loads` function of the selected module.
    ///
    /// NOTE(review): unpickling can execute arbitrary code. Only payloads from trusted
    /// sources should reach this function.
    ///
    /// # Errors
    ///
    /// Returns a `ComputeError` when the start bytes are invalid, when the selected module
    /// cannot be imported (e.g. `cloudpickle` is not installed), or when `loads` raises.
    pub fn deserialize_pyobject_bytes_maybe_cloudpickle<T: for<'a> From<PyObject>>(
        bytes: &[u8],
    ) -> PolarsResult<T> {
        let [used_cloudpickle @ 0 | used_cloudpickle @ 1, b'C', rem @ ..] = bytes else {
            polars_bail!(ComputeError: "deserialize_pyobject_bytes_maybe_cloudpickle: invalid start bytes")
        };

        let bytes = rem;

        Python::with_gil(|py| {
            let module_name = if *used_cloudpickle == 1 {
                "cloudpickle"
            } else {
                "pickle"
            };

            // Propagate import/lookup failures: the module may be the optional `cloudpickle`,
            // whose absence is an environment problem rather than a bug in this crate.
            let loads = PyModule::import(py, module_name)?.getattr("loads")?;
            let arg = (PyBytes::new(py, bytes),);
            let pyany_bound = loads.call1(arg)?;
            Ok(PyObject::from(pyany_bound).into())
        })
        .map_err(from_pyerr)
    }
}
285
286/// Get the [minor, micro] Python3 version from the `sys` module.
287fn get_python3_version() -> [u8; 2] {
288    Python::with_gil(|py| {
289        let version_info = PyModule::import(py, "sys")
290            .unwrap()
291            .getattr("version_info")
292            .unwrap();
293
294        [
295            version_info.getattr("minor").unwrap().extract().unwrap(),
296            version_info.getattr("micro").unwrap().extract().unwrap(),
297        ]
298    })
299}
300
301fn from_pyerr(e: PyErr) -> PolarsError {
302    PolarsError::ComputeError(format!("error raised in python: {e}").into())
303}