polars_utils/
python_function.rs
1use polars_error::{PolarsError, polars_bail};
2use pyo3::prelude::*;
3use pyo3::pybacked::PyBackedBytes;
4use pyo3::types::PyBytes;
5#[cfg(feature = "serde")]
6pub use serde_wrap::{
7 PYTHON3_VERSION, PySerializeWrap, SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK,
8 TrySerializeToBytes,
9};
10
/// Newtype around a [`PyObject`] (an owned reference to an arbitrary Python object).
///
/// The wrapper exists so that this crate can attach its own trait
/// implementations (cloning, equality, conversions and, with the `serde`
/// feature, serialization) to Python objects. The inner handle is public and
/// also reachable through `Deref`/`DerefMut`.
#[derive(Debug)]
pub struct PythonObject(pub PyObject);
/// Alias of [`PythonObject`]; both names refer to the exact same type.
pub type PythonFunction = PythonObject;
17
18impl std::ops::Deref for PythonObject {
19 type Target = PyObject;
20
21 fn deref(&self) -> &Self::Target {
22 &self.0
23 }
24}
25
26impl std::ops::DerefMut for PythonObject {
27 fn deref_mut(&mut self) -> &mut Self::Target {
28 &mut self.0
29 }
30}
31
32impl Clone for PythonObject {
33 fn clone(&self) -> Self {
34 Python::with_gil(|py| Self(self.0.clone_ref(py)))
35 }
36}
37
38impl From<PyObject> for PythonObject {
39 fn from(value: PyObject) -> Self {
40 Self(value)
41 }
42}
43
44impl<'py> pyo3::conversion::IntoPyObject<'py> for PythonObject {
45 type Target = PyAny;
46 type Output = Bound<'py, Self::Target>;
47 type Error = PyErr;
48
49 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
50 Ok(self.0.into_bound(py))
51 }
52}
53
54impl<'py> pyo3::conversion::IntoPyObject<'py> for &PythonObject {
55 type Target = PyAny;
56 type Output = Bound<'py, Self::Target>;
57 type Error = PyErr;
58
59 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
60 Ok(self.0.bind(py).clone())
61 }
62}
63
64impl Eq for PythonObject {}
65
66impl PartialEq for PythonObject {
67 fn eq(&self, other: &Self) -> bool {
68 Python::with_gil(|py| {
69 let eq = self.0.getattr(py, "__eq__").unwrap();
70 eq.call1(py, (other.0.clone_ref(py),))
71 .unwrap()
72 .extract::<bool>(py)
73 .unwrap_or(false)
75 })
76 }
77}
78
79#[cfg(feature = "serde")]
80mod _serde_impls {
81 use super::{PySerializeWrap, PythonObject, TrySerializeToBytes, serde_wrap};
82 use crate::pl_serialize::deserialize_map_bytes;
83
84 impl PythonObject {
85 pub fn serialize_with_pyversion<T, S>(
86 value: &T,
87 serializer: S,
88 ) -> std::result::Result<S::Ok, S::Error>
89 where
90 T: AsRef<PythonObject>,
91 S: serde::ser::Serializer,
92 {
93 use serde::Serialize;
94 PySerializeWrap(value.as_ref()).serialize(serializer)
95 }
96
97 pub fn deserialize_with_pyversion<'de, T, D>(d: D) -> Result<T, D::Error>
98 where
99 T: From<PythonObject>,
100 D: serde::de::Deserializer<'de>,
101 {
102 use serde::Deserialize;
103 let v: PySerializeWrap<PythonObject> = PySerializeWrap::deserialize(d)?;
104
105 Ok(v.0.into())
106 }
107 }
108
109 impl TrySerializeToBytes for PythonObject {
110 fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult<Vec<u8>> {
111 serde_wrap::serialize_pyobject_with_cloudpickle_fallback(&self.0)
112 }
113
114 fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult<Self> {
115 serde_wrap::deserialize_pyobject_bytes_maybe_cloudpickle(bytes)
116 }
117 }
118
119 impl serde::Serialize for PythonObject {
120 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
121 where
122 S: serde::Serializer,
123 {
124 use serde::ser::Error;
125 let bytes = self
126 .try_serialize_to_bytes()
127 .map_err(|e| S::Error::custom(e.to_string()))?;
128
129 Vec::<u8>::serialize(&bytes, serializer)
130 }
131 }
132
133 impl<'a> serde::Deserialize<'a> for PythonObject {
134 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
135 where
136 D: serde::Deserializer<'a>,
137 {
138 use serde::de::Error;
139 deserialize_map_bytes(deserializer, |bytes| {
140 Self::try_deserialize_bytes(&bytes).map_err(|e| D::Error::custom(e.to_string()))
141 })?
142 }
143 }
144}
145
146#[cfg(feature = "serde")]
147mod serde_wrap {
148 use std::sync::LazyLock;
149
150 use polars_error::PolarsResult;
151
152 use super::*;
153 use crate::config;
154 use crate::pl_serialize::deserialize_map_bytes;
155
    /// Marker bytes (`"PLPYFN"`) that [`PySerializeWrap`] writes at the start of
    /// every payload and requires when reading one back.
    pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes();
    /// `[minor, micro]` version of the running Python 3 interpreter, queried
    /// from Python on first access and cached afterwards.
    pub static PYTHON3_VERSION: LazyLock<[u8; 2]> = LazyLock::new(super::get_python3_version);
159
    /// Fallible conversion of a value to and from a raw byte payload.
    ///
    /// [`PySerializeWrap`] relies on this trait to obtain the payload it frames
    /// with the magic marker and the Python version.
    pub trait TrySerializeToBytes: Sized {
        /// Encodes `self` into bytes.
        fn try_serialize_to_bytes(&self) -> PolarsResult<Vec<u8>>;
        /// Decodes a value from `bytes`.
        fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult<Self>;
    }
166
    /// Serde adapter that frames the wrapped value's byte payload with
    /// [`SERDE_MAGIC_BYTE_MARK`] and [`PYTHON3_VERSION`].
    ///
    /// Serialization is implemented for `PySerializeWrap<&T>`, deserialization
    /// for `PySerializeWrap<T>`.
    pub struct PySerializeWrap<T>(pub T);
170
171 impl<T: TrySerializeToBytes> serde::Serialize for PySerializeWrap<&T> {
172 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
173 where
174 S: serde::Serializer,
175 {
176 use serde::ser::Error;
177 let dumped = self
178 .0
179 .try_serialize_to_bytes()
180 .map_err(|e| S::Error::custom(e.to_string()))?;
181
182 serializer.serialize_bytes(
183 &[SERDE_MAGIC_BYTE_MARK, &*PYTHON3_VERSION, dumped.as_slice()].concat(),
184 )
185 }
186 }
187
188 impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap<T> {
189 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
190 where
191 D: serde::Deserializer<'a>,
192 {
193 use serde::de::Error;
194
195 deserialize_map_bytes(deserializer, |bytes| {
196 let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else {
197 return Err(D::Error::custom(
198 "unexpected EOF when reading serialized pyobject version",
199 ));
200 };
201
202 if magic != SERDE_MAGIC_BYTE_MARK {
203 return Err(D::Error::custom(
204 "serialized pyobject did not begin with magic byte mark",
205 ));
206 }
207
208 let bytes = rem;
209
210 let [a, b, rem @ ..] = bytes else {
211 return Err(D::Error::custom(
212 "unexpected EOF when reading serialized pyobject metadata",
213 ));
214 };
215
216 let py3_version = [*a, *b];
217 let used_cloud_pickle = rem.first();
219
220 if py3_version != *PYTHON3_VERSION && used_cloud_pickle == Some(&1) {
223 return Err(D::Error::custom(format!(
224 "python version that pyobject was serialized with {:?} \
225 differs from system python version {:?}",
226 (3, py3_version[0], py3_version[1]),
227 (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]),
228 )));
229 }
230
231 let bytes = rem;
232
233 T::try_deserialize_bytes(bytes)
234 .map(Self)
235 .map_err(|e| D::Error::custom(e.to_string()))
236 })?
237 }
238 }
239
240 pub fn serialize_pyobject_with_cloudpickle_fallback(
241 py_object: &PyObject,
242 ) -> PolarsResult<Vec<u8>> {
243 Python::with_gil(|py| {
244 let pickle = PyModule::import(py, "pickle")
245 .expect("unable to import 'pickle'")
246 .getattr("dumps")
247 .unwrap();
248
249 let dumped = pickle.call1((py_object.clone_ref(py),));
250
251 let (dumped, used_cloudpickle) = match dumped {
252 Ok(v) => (v, false),
253 Err(e) => {
254 if config::verbose() {
255 eprintln!(
256 "serialize_pyobject_with_cloudpickle_fallback(): \
257 retrying with cloudpickle due to error: {:?}",
258 e
259 );
260 }
261
262 let cloudpickle = PyModule::import(py, "cloudpickle")?
263 .getattr("dumps")
264 .unwrap();
265 let dumped = cloudpickle.call1((py_object.clone_ref(py),))?;
266 (dumped, true)
267 },
268 };
269
270 let py_bytes = dumped.extract::<PyBackedBytes>()?;
271
272 Ok([&[used_cloudpickle as u8, b'C'][..], py_bytes.as_ref()].concat())
273 })
274 .map_err(from_pyerr)
275 }
276
277 pub fn deserialize_pyobject_bytes_maybe_cloudpickle<T: for<'a> From<PyObject>>(
278 bytes: &[u8],
279 ) -> PolarsResult<T> {
280 let [used_cloudpickle @ 0 | used_cloudpickle @ 1, b'C', rem @ ..] = bytes else {
282 polars_bail!(ComputeError: "deserialize_pyobject_bytes_maybe_cloudpickle: invalid start bytes")
283 };
284
285 let bytes = rem;
286
287 Python::with_gil(|py| {
288 let p = if *used_cloudpickle == 1 {
289 "cloudpickle"
290 } else {
291 "pickle"
292 };
293
294 let pickle = PyModule::import(py, p)
295 .expect("unable to import 'pickle'")
296 .getattr("loads")
297 .unwrap();
298 let arg = (PyBytes::new(py, bytes),);
299 let pyany_bound = pickle.call1(arg)?;
300 Ok(PyObject::from(pyany_bound).into())
301 })
302 .map_err(from_pyerr)
303 }
304}
305
306fn get_python3_version() -> [u8; 2] {
308 Python::with_gil(|py| {
309 let version_info = PyModule::import(py, "sys")
310 .unwrap()
311 .getattr("version_info")
312 .unwrap();
313
314 [
315 version_info.getattr("minor").unwrap().extract().unwrap(),
316 version_info.getattr("micro").unwrap().extract().unwrap(),
317 ]
318 })
319}
320
321fn from_pyerr(e: PyErr) -> PolarsError {
322 PolarsError::ComputeError(format!("error raised in python: {e}").into())
323}