// polars_utils/python_function.rs
use polars_error::{polars_bail, PolarsError, PolarsResult};
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedBytes;
use pyo3::types::PyBytes;
#[cfg(feature = "serde")]
pub use serde_wrap::{
PySerializeWrap, TrySerializeToBytes, PYTHON3_VERSION,
SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK,
};
/// Newtype wrapper around a [`PyObject`], exposed through the public tuple field.
///
/// NOTE(review): judging by the name, the wrapped object is presumably a Python
/// callable, but nothing in this type enforces that — confirm against callers.
#[derive(Debug)]
pub struct PythonFunction(pub PyObject);
impl Clone for PythonFunction {
    /// Returns a new handle pointing at the same Python object.
    ///
    /// The GIL is acquired because `clone_ref` needs a `Python` token.
    fn clone(&self) -> Self {
        let handle = Python::with_gil(|py| self.0.clone_ref(py));
        Self(handle)
    }
}
impl From<PyObject> for PythonFunction {
fn from(value: PyObject) -> Self {
Self(value)
}
}
impl Eq for PythonFunction {}

impl PartialEq for PythonFunction {
    /// Compares the wrapped objects by calling `self.__eq__(other)` in Python.
    ///
    /// Returns `false` whenever the comparison cannot produce a `bool`: the
    /// `__eq__` attribute lookup fails, the call raises, or the result is not
    /// extractable as `bool` (for example `NotImplemented`). `eq` has no way to
    /// report an error, so a failing Python-side comparison is treated as
    /// "not equal" instead of panicking.
    fn eq(&self, other: &Self) -> bool {
        Python::with_gil(|py| {
            self.0
                .getattr(py, "__eq__")
                .and_then(|eq| eq.call1(py, (other.0.clone_ref(py),)))
                .and_then(|outcome| outcome.extract::<bool>(py))
                .unwrap_or(false)
        })
    }
}
#[cfg(feature = "serde")]
impl serde::Serialize for PythonFunction {
    /// Serializes the wrapped object as a single byte string produced by
    /// `try_serialize_to_bytes`; a failure there is reported as a custom
    /// serializer error carrying the error's text.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::Error;

        match self.try_serialize_to_bytes() {
            Ok(pickled) => serializer.serialize_bytes(&pickled),
            Err(err) => Err(S::Error::custom(err.to_string())),
        }
    }
}
#[cfg(feature = "serde")]
impl<'a> serde::Deserialize<'a> for PythonFunction {
    /// Reads a byte buffer from the deserializer and rebuilds the wrapped
    /// object via `try_deserialize_bytes`; a failure there is reported as a
    /// custom deserializer error carrying the error's text.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'a>,
    {
        use serde::de::Error;

        let raw = <Vec<u8> as serde::Deserialize>::deserialize(deserializer)?;
        match Self::try_deserialize_bytes(&raw) {
            Ok(function) => Ok(function),
            Err(err) => Err(D::Error::custom(err.to_string())),
        }
    }
}
#[cfg(feature = "serde")]
impl TrySerializeToBytes for PythonFunction {
    /// Delegates to `serialize_pyobject_with_cloudpickle_fallback` on the wrapped object.
    fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult<Vec<u8>> {
        let Self(object) = self;
        serialize_pyobject_with_cloudpickle_fallback(object)
    }

    /// Delegates to `deserialize_pyobject_bytes_maybe_cloudpickle`, converting
    /// the resulting `PyObject` into `Self`.
    fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult<Self> {
        deserialize_pyobject_bytes_maybe_cloudpickle::<Self>(bytes)
    }
}
/// Pickles `py_object` into bytes, trying `pickle.dumps` first and falling
/// back to `cloudpickle.dumps` when that call fails.
///
/// # Output layout
///
/// * byte 0: `1` if `cloudpickle` produced the payload, `0` if `pickle` did
/// * byte 1: the literal `b'C'`
/// * remaining bytes: the pickled payload
///
/// # Errors
///
/// Returns a `ComputeError` (built by `from_pyerr`) when `pickle.dumps`
/// cannot be resolved, when the `cloudpickle` fallback cannot be imported or
/// itself fails, or when the dumped value cannot be extracted as bytes. The
/// error raised by the initial `pickle.dumps` call is discarded once the
/// fallback is attempted.
pub fn serialize_pyobject_with_cloudpickle_fallback(py_object: &PyObject) -> PolarsResult<Vec<u8>> {
    Python::with_gil(|py| {
        // Report a broken interpreter environment as an error instead of
        // panicking; this function already returns a `PolarsResult`.
        let pickle_dumps = PyModule::import_bound(py, "pickle")
            .and_then(|module| module.getattr("dumps"))
            .map_err(from_pyerr)?;

        let (dumped, used_cloudpickle) = match pickle_dumps.call1((py_object.clone_ref(py),)) {
            Ok(dumped) => (dumped, false),
            Err(_) => {
                let cloudpickle_dumps = PyModule::import_bound(py, "cloudpickle")
                    .and_then(|module| module.getattr("dumps"))
                    .map_err(from_pyerr)?;
                let dumped = cloudpickle_dumps
                    .call1((py_object.clone_ref(py),))
                    .map_err(from_pyerr)?;
                (dumped, true)
            },
        };

        let py_bytes = dumped.extract::<PyBackedBytes>().map_err(from_pyerr)?;
        let payload: &[u8] = &py_bytes;

        // Two header bytes followed by the payload, allocated once.
        let mut out = Vec::with_capacity(2 + payload.len());
        out.push(u8::from(used_cloudpickle));
        out.push(b'C');
        out.extend_from_slice(payload);
        Ok(out)
    })
}
/// Rebuilds a Python object from bytes laid out as
/// `[0 or 1, b'C', pickled payload...]` and converts it into `T`.
///
/// The first byte (whether `cloudpickle` was used) is validated but otherwise
/// ignored: the payload is always loaded with `pickle.loads`.
///
/// # Errors
///
/// Returns a `ComputeError` when the two header bytes are missing or invalid,
/// when `pickle.loads` cannot be resolved, or when unpickling raises.
///
/// # Security
///
/// SECURITY: `pickle.loads` can execute arbitrary code while loading. Only
/// pass bytes that come from a trusted source.
pub fn deserialize_pyobject_bytes_maybe_cloudpickle<T: for<'a> From<PyObject>>(
    bytes: &[u8],
) -> PolarsResult<T> {
    let [0 | 1, b'C', payload @ ..] = bytes else {
        polars_bail!(ComputeError: "deserialize_pyobject_bytes_maybe_cloudpickle: invalid start bytes")
    };

    Python::with_gil(|py| {
        // Report a broken interpreter environment as an error instead of
        // panicking; this function already returns a `PolarsResult`.
        let pickle_loads = PyModule::import_bound(py, "pickle")
            .and_then(|module| module.getattr("loads"))
            .map_err(from_pyerr)?;
        let unpickled = pickle_loads
            .call1((PyBytes::new_bound(py, payload),))
            .map_err(from_pyerr)?;
        Ok(PyObject::from(unpickled).into())
    })
}
#[cfg(feature = "serde")]
mod serde_wrap {
    use once_cell::sync::Lazy;
    use polars_error::PolarsResult;

    /// Marker bytes placed at the very start of every payload written by
    /// [`PySerializeWrap`].
    pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes();

    /// The two bytes returned by `super::get_python3_version`, computed on
    /// first access.
    pub static PYTHON3_VERSION: Lazy<[u8; 2]> = Lazy::new(super::get_python3_version);

    /// Fallible conversion of a value to and from a raw byte buffer.
    pub trait TrySerializeToBytes: Sized {
        fn try_serialize_to_bytes(&self) -> PolarsResult<Vec<u8>>;
        fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult<Self>;
    }

    /// Serde adapter that frames the bytes of a [`TrySerializeToBytes`] value
    /// as `magic mark | 2 version bytes | payload`.
    pub struct PySerializeWrap<T>(pub T);

    impl<T: TrySerializeToBytes> serde::Serialize for PySerializeWrap<&T> {
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::Error;

            let payload = match self.0.try_serialize_to_bytes() {
                Ok(payload) => payload,
                Err(err) => return Err(S::Error::custom(err.to_string())),
            };

            let version: &[u8; 2] = &PYTHON3_VERSION;
            let mut framed =
                Vec::with_capacity(SERDE_MAGIC_BYTE_MARK.len() + version.len() + payload.len());
            framed.extend_from_slice(SERDE_MAGIC_BYTE_MARK);
            framed.extend_from_slice(version);
            framed.extend_from_slice(&payload);

            serializer.serialize_bytes(&framed)
        }
    }

    impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap<T> {
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            use serde::de::Error;

            let raw = <Vec<u8> as serde::Deserialize>::deserialize(deserializer)?;

            // Step 1: the buffer must be long enough to hold the magic mark,
            // and must start with it.
            let mark_len = SERDE_MAGIC_BYTE_MARK.len();
            if raw.len() < mark_len {
                return Err(D::Error::custom(
                    "unexpected EOF when reading serialized pyobject version",
                ));
            }
            let (magic, after_magic) = raw.split_at(mark_len);
            if magic != SERDE_MAGIC_BYTE_MARK {
                return Err(D::Error::custom(
                    "serialized pyobject did not begin with magic byte mark",
                ));
            }

            // Step 2: two version bytes that must match the running interpreter.
            let (py3_version, payload) = match after_magic {
                [first, second, payload @ ..] => ([*first, *second], payload),
                _ => {
                    return Err(D::Error::custom(
                        "unexpected EOF when reading serialized pyobject metadata",
                    ));
                },
            };
            if py3_version != *PYTHON3_VERSION {
                return Err(D::Error::custom(format!(
                    "python version that pyobject was serialized with {:?} \
                     differs from system python version {:?}",
                    (3, py3_version[0], py3_version[1]),
                    (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]),
                )));
            }

            // Step 3: whatever remains belongs to the wrapped type.
            match T::try_deserialize_bytes(payload) {
                Ok(inner) => Ok(Self(inner)),
                Err(err) => Err(D::Error::custom(err.to_string())),
            }
        }
    }
}
/// Reads `sys.version_info` from the running interpreter and returns its
/// `minor` and `micro` attributes as `[minor, micro]`.
///
/// # Panics
///
/// Panics if `sys` cannot be imported, an attribute is missing, or a value
/// cannot be extracted as `u8`.
fn get_python3_version() -> [u8; 2] {
    Python::with_gil(|py| {
        let version_info = PyModule::import_bound(py, "sys")
            .unwrap()
            .getattr("version_info")
            .unwrap();

        let component =
            |name: &str| -> u8 { version_info.getattr(name).unwrap().extract().unwrap() };

        [component("minor"), component("micro")]
    })
}
/// Converts a Python exception into a `PolarsError::ComputeError` whose
/// message embeds the exception's `Display` output.
fn from_pyerr(e: PyErr) -> PolarsError {
    let message = format!("error raised in python: {e}");
    PolarsError::ComputeError(message.into())
}