polars_utils/
python_function.rs1use polars_error::{PolarsError, polars_bail};
2use pyo3::prelude::*;
3use pyo3::pybacked::PyBackedBytes;
4use pyo3::types::PyBytes;
5#[cfg(feature = "serde")]
6pub use serde_wrap::{
7 PYTHON3_VERSION, PySerializeWrap, SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK,
8 TrySerializeToBytes,
9};
10
11#[derive(Debug)]
13pub struct PythonObject(pub PyObject);
14pub type PythonFunction = PythonObject;
17
18impl std::ops::Deref for PythonObject {
19 type Target = PyObject;
20
21 fn deref(&self) -> &Self::Target {
22 &self.0
23 }
24}
25
26impl std::ops::DerefMut for PythonObject {
27 fn deref_mut(&mut self) -> &mut Self::Target {
28 &mut self.0
29 }
30}
31
32impl Clone for PythonObject {
33 fn clone(&self) -> Self {
34 Python::with_gil(|py| Self(self.0.clone_ref(py)))
35 }
36}
37
38impl From<PyObject> for PythonObject {
39 fn from(value: PyObject) -> Self {
40 Self(value)
41 }
42}
43
44impl Eq for PythonObject {}
45
46impl PartialEq for PythonObject {
47 fn eq(&self, other: &Self) -> bool {
48 Python::with_gil(|py| {
49 let eq = self.0.getattr(py, "__eq__").unwrap();
50 eq.call1(py, (other.0.clone_ref(py),))
51 .unwrap()
52 .extract::<bool>(py)
53 .unwrap_or(false)
55 })
56 }
57}
58
/// serde glue for [`PythonObject`], compiled only with the `serde` feature.
#[cfg(feature = "serde")]
mod _serde_impls {
    use super::{PySerializeWrap, PythonObject, TrySerializeToBytes, serde_wrap};
    use crate::pl_serialize::deserialize_map_bytes;

    impl PythonObject {
        /// Serializes `value` through [`PySerializeWrap`], i.e. using the
        /// wrapper's `Serialize` implementation instead of the plain one on
        /// `PythonObject`.
        pub fn serialize_with_pyversion<T, S>(
            value: &T,
            serializer: S,
        ) -> std::result::Result<S::Ok, S::Error>
        where
            T: AsRef<PythonObject>,
            S: serde::ser::Serializer,
        {
            let wrapped = PySerializeWrap(value.as_ref());
            serde::Serialize::serialize(&wrapped, serializer)
        }

        /// Counterpart of [`PythonObject::serialize_with_pyversion`]: decodes a
        /// `PySerializeWrap<PythonObject>` and converts the inner object to `T`.
        pub fn deserialize_with_pyversion<'de, T, D>(d: D) -> Result<T, D::Error>
        where
            T: From<PythonObject>,
            D: serde::de::Deserializer<'de>,
        {
            let PySerializeWrap(object) =
                <PySerializeWrap<PythonObject> as serde::Deserialize>::deserialize(d)?;

            Ok(T::from(object))
        }
    }

    /// Byte-level (de)serialization is delegated to the helpers in `serde_wrap`.
    impl TrySerializeToBytes for PythonObject {
        fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult<Vec<u8>> {
            let py_object = &self.0;
            serde_wrap::serialize_pyobject_with_cloudpickle_fallback(py_object)
        }

        fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult<Self> {
            serde_wrap::deserialize_pyobject_bytes_maybe_cloudpickle::<Self>(bytes)
        }
    }

    /// Serializes the object as the `Vec<u8>` returned by
    /// `try_serialize_to_bytes`; failures become serde custom errors.
    impl serde::Serialize for PythonObject {
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::Error;

            match self.try_serialize_to_bytes() {
                Ok(bytes) => serde::Serialize::serialize(&bytes, serializer),
                Err(err) => Err(S::Error::custom(err.to_string())),
            }
        }
    }

    /// Reads the bytes handed over by `deserialize_map_bytes` and rebuilds the
    /// object with `try_deserialize_bytes`; failures become serde custom errors.
    impl<'a> serde::Deserialize<'a> for PythonObject {
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            use serde::de::Error;

            // The outer `Result` comes from the deserializer itself, the inner
            // one from decoding the bytes.
            let decoded = deserialize_map_bytes(deserializer, |raw| {
                Self::try_deserialize_bytes(&raw).map_err(|err| D::Error::custom(err.to_string()))
            })?;

            decoded
        }
    }
}
125
/// Version-tagged byte serialization of Python objects, compiled only with the
/// `serde` feature.
#[cfg(feature = "serde")]
mod serde_wrap {
    use std::sync::LazyLock;

    use polars_error::PolarsResult;

    use super::*;
    use crate::config;
    use crate::pl_serialize::deserialize_map_bytes;

    /// Marker written at the start of every byte stream produced by
    /// [`PySerializeWrap`].
    pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes();
    /// `[minor, micro]` of the running Python 3 interpreter, computed on first
    /// access.
    pub static PYTHON3_VERSION: LazyLock<[u8; 2]> = LazyLock::new(super::get_python3_version);

    /// Conversion of a value to and from a raw byte payload.
    pub trait TrySerializeToBytes: Sized {
        /// Encodes `self` into a byte payload.
        fn try_serialize_to_bytes(&self) -> PolarsResult<Vec<u8>>;
        /// Rebuilds a value from a payload produced by `try_serialize_to_bytes`.
        fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult<Self>;
    }

    /// Wrapper whose serde implementations add the magic mark and the Python
    /// version in front of the payload of the wrapped value.
    pub struct PySerializeWrap<T>(pub T);

    /// Writes `SERDE_MAGIC_BYTE_MARK ++ PYTHON3_VERSION ++ payload` as one
    /// byte string.
    impl<T: TrySerializeToBytes> serde::Serialize for PySerializeWrap<&T> {
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::Error;
            let dumped = self
                .0
                .try_serialize_to_bytes()
                .map_err(|e| S::Error::custom(e.to_string()))?;

            serializer.serialize_bytes(
                &[SERDE_MAGIC_BYTE_MARK, &*PYTHON3_VERSION, dumped.as_slice()].concat(),
            )
        }
    }

    /// Validates the magic mark and the version header, then hands the
    /// remaining bytes to `T::try_deserialize_bytes`.
    impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap<T> {
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            use serde::de::Error;

            deserialize_map_bytes(deserializer, |bytes| {
                let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else {
                    return Err(D::Error::custom(
                        "unexpected EOF when reading serialized pyobject version",
                    ));
                };

                if magic != SERDE_MAGIC_BYTE_MARK {
                    return Err(D::Error::custom(
                        "serialized pyobject did not begin with magic byte mark",
                    ));
                }

                let bytes = rem;

                // Two version bytes follow the magic mark.
                let [a, b, rem @ ..] = bytes else {
                    return Err(D::Error::custom(
                        "unexpected EOF when reading serialized pyobject metadata",
                    ));
                };

                let py3_version = [*a, *b];
                // Peeks at the first payload byte without consuming it. This
                // assumes the payload starts with the cloudpickle flag, as the
                // output of `serialize_pyobject_with_cloudpickle_fallback` does.
                let used_cloud_pickle = rem.first();

                // A version mismatch is only rejected when the flag byte is 1.
                if py3_version != *PYTHON3_VERSION && used_cloud_pickle == Some(&1) {
                    return Err(D::Error::custom(format!(
                        "python version that pyobject was serialized with {:?} \
                        differs from system python version {:?}",
                        (3, py3_version[0], py3_version[1]),
                        (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]),
                    )));
                }

                let bytes = rem;

                T::try_deserialize_bytes(bytes)
                    .map(Self)
                    .map_err(|e| D::Error::custom(e.to_string()))
            })?
        }
    }

    /// Pickles `py_object` with `pickle.dumps`, retrying with
    /// `cloudpickle.dumps` when that raises.
    ///
    /// The returned bytes are `[used_cloudpickle as u8, b'C']` followed by the
    /// pickled payload.
    ///
    /// # Errors
    /// Returns an error when the `cloudpickle` fallback cannot be imported or
    /// raises, or when the dumped value cannot be extracted as bytes.
    ///
    /// # Panics
    /// Panics if `pickle` cannot be imported or either module lacks `dumps`.
    pub fn serialize_pyobject_with_cloudpickle_fallback(
        py_object: &PyObject,
    ) -> PolarsResult<Vec<u8>> {
        Python::with_gil(|py| {
            let pickle = PyModule::import(py, "pickle")
                .expect("unable to import 'pickle'")
                .getattr("dumps")
                .unwrap();

            let dumped = pickle.call1((py_object.clone_ref(py),));

            let (dumped, used_cloudpickle) = match dumped {
                Ok(v) => (v, false),
                Err(e) => {
                    if config::verbose() {
                        eprintln!(
                            "serialize_pyobject_with_cloudpickle_fallback(): \
                            retrying with cloudpickle due to error: {:?}",
                            e
                        );
                    }

                    let cloudpickle = PyModule::import(py, "cloudpickle")?
                        .getattr("dumps")
                        .unwrap();
                    let dumped = cloudpickle.call1((py_object.clone_ref(py),))?;
                    (dumped, true)
                },
            };

            let py_bytes = dumped.extract::<PyBackedBytes>()?;

            Ok([&[used_cloudpickle as u8, b'C'][..], py_bytes.as_ref()].concat())
        })
        .map_err(from_pyerr)
    }

    /// Inverse of [`serialize_pyobject_with_cloudpickle_fallback`].
    ///
    /// `bytes` must start with a flag byte (`0` = `pickle`, `1` =
    /// `cloudpickle`) followed by `b'C'`; the rest is passed to the selected
    /// module's `loads`.
    ///
    /// NOTE: unpickling can execute arbitrary code, so only call this on bytes
    /// from a trusted source.
    ///
    /// # Errors
    /// Returns an error when the start bytes are invalid, when the selected
    /// module cannot be imported or has no `loads`, or when `loads` raises.
    pub fn deserialize_pyobject_bytes_maybe_cloudpickle<T: for<'a> From<PyObject>>(
        bytes: &[u8],
    ) -> PolarsResult<T> {
        let [used_cloudpickle @ 0 | used_cloudpickle @ 1, b'C', rem @ ..] = bytes else {
            polars_bail!(ComputeError: "deserialize_pyobject_bytes_maybe_cloudpickle: invalid start bytes")
        };

        let bytes = rem;

        Python::with_gil(|py| {
            let module_name = if *used_cloudpickle == 1 {
                "cloudpickle"
            } else {
                "pickle"
            };

            // `cloudpickle` is a third-party package and may not be installed.
            // Propagate the import error, as the serializer does, instead of
            // panicking with a message that names the wrong module.
            let loads = PyModule::import(py, module_name)?.getattr("loads")?;
            let arg = (PyBytes::new(py, bytes),);
            let pyany_bound = loads.call1(arg)?;
            Ok(PyObject::from(pyany_bound).into())
        })
        .map_err(from_pyerr)
    }
}
285
286fn get_python3_version() -> [u8; 2] {
288 Python::with_gil(|py| {
289 let version_info = PyModule::import(py, "sys")
290 .unwrap()
291 .getattr("version_info")
292 .unwrap();
293
294 [
295 version_info.getattr("minor").unwrap().extract().unwrap(),
296 version_info.getattr("micro").unwrap().extract().unwrap(),
297 ]
298 })
299}
300
301fn from_pyerr(e: PyErr) -> PolarsError {
302 PolarsError::ComputeError(format!("error raised in python: {e}").into())
303}