polars_utils/
pl_serialize.rs

1//! Centralized Polars serialization entry.
2//!
3//! Currently provides two serialization schemes.
4//! - Self-describing (and thus more forward compatible) activated with `FC: true`
5//! - Compact activated with `FC: false`
6use polars_error::{PolarsResult, to_compute_err};
7
8fn config() -> bincode::config::Configuration {
9    bincode::config::standard()
10        .with_no_limit()
11        .with_variable_int_encoding()
12}
13
14fn serialize_impl<W, T, const FC: bool>(mut writer: W, value: &T) -> PolarsResult<()>
15where
16    W: std::io::Write,
17    T: serde::ser::Serialize,
18{
19    if FC {
20        let mut s = rmp_serde::Serializer::new(writer).with_struct_map();
21        value.serialize(&mut s).map_err(to_compute_err)
22    } else {
23        bincode::serde::encode_into_std_write(value, &mut writer, config())
24            .map_err(to_compute_err)
25            .map(|_| ())
26    }
27}
28
29pub fn deserialize_impl<T, R, const FC: bool>(mut reader: R) -> PolarsResult<T>
30where
31    T: serde::de::DeserializeOwned,
32    R: std::io::Read,
33{
34    if FC {
35        rmp_serde::from_read(reader).map_err(to_compute_err)
36    } else {
37        bincode::serde::decode_from_std_read(&mut reader, config()).map_err(to_compute_err)
38    }
39}
40
/// Mainly used to enable compression when serializing the final outer value.
/// For intermediate serialization steps, the free functions in this module
/// should be used instead.
pub struct SerializeOptions {
    /// When `true`, output is zlib-compressed on write and inflated on read.
    compression: bool,
}
47
48impl SerializeOptions {
49    pub fn with_compression(mut self, compression: bool) -> Self {
50        self.compression = compression;
51        self
52    }
53
54    pub fn serialize_into_writer<W, T, const FC: bool>(
55        &self,
56        writer: W,
57        value: &T,
58    ) -> PolarsResult<()>
59    where
60        W: std::io::Write,
61        T: serde::ser::Serialize,
62    {
63        if self.compression {
64            let writer = flate2::write::ZlibEncoder::new(writer, flate2::Compression::fast());
65            serialize_impl::<_, _, FC>(writer, value)
66        } else {
67            serialize_impl::<_, _, FC>(writer, value)
68        }
69    }
70
71    pub fn deserialize_from_reader<T, R, const FC: bool>(&self, reader: R) -> PolarsResult<T>
72    where
73        T: serde::de::DeserializeOwned,
74        R: std::io::Read,
75    {
76        if self.compression {
77            deserialize_impl::<_, _, FC>(flate2::read::ZlibDecoder::new(reader))
78        } else {
79            deserialize_impl::<_, _, FC>(reader)
80        }
81    }
82
83    pub fn serialize_to_bytes<T, const FC: bool>(&self, value: &T) -> PolarsResult<Vec<u8>>
84    where
85        T: serde::ser::Serialize,
86    {
87        let mut v = vec![];
88
89        self.serialize_into_writer::<_, _, FC>(&mut v, value)?;
90
91        Ok(v)
92    }
93}
94
// NOTE(review): the impl is presumably hand-written so the default stays
// explicit at the definition site — confirm before replacing it with a derive.
#[allow(clippy::derivable_impls)]
impl Default for SerializeOptions {
    /// Returns options with compression disabled.
    fn default() -> Self {
        Self { compression: false }
    }
}
101
102pub fn serialize_into_writer<W, T, const FC: bool>(writer: W, value: &T) -> PolarsResult<()>
103where
104    W: std::io::Write,
105    T: serde::ser::Serialize,
106{
107    serialize_impl::<_, _, FC>(writer, value)
108}
109
110pub fn deserialize_from_reader<T, R, const FC: bool>(reader: R) -> PolarsResult<T>
111where
112    T: serde::de::DeserializeOwned,
113    R: std::io::Read,
114{
115    deserialize_impl::<_, _, FC>(reader)
116}
117
118pub fn serialize_to_bytes<T, const FC: bool>(value: &T) -> PolarsResult<Vec<u8>>
119where
120    T: serde::ser::Serialize,
121{
122    let mut v = vec![];
123
124    serialize_into_writer::<_, _, FC>(&mut v, value)?;
125
126    Ok(v)
127}
128
129/// Serialize function customized for `DslPlan`, with stack overflow protection.
130pub fn serialize_dsl<W, T>(writer: W, value: &T) -> PolarsResult<()>
131where
132    W: std::io::Write,
133    T: serde::ser::Serialize,
134{
135    let mut s = rmp_serde::Serializer::new(writer).with_struct_map();
136    let s = serde_stacker::Serializer::new(&mut s);
137    value.serialize(s).map_err(to_compute_err)
138}
139
140/// Deserialize function customized for `DslPlan`, with stack overflow protection.
141pub fn deserialize_dsl<T, R>(reader: R) -> PolarsResult<T>
142where
143    T: serde::de::DeserializeOwned,
144    R: std::io::Read,
145{
146    let mut de = rmp_serde::Deserializer::new(reader);
147    de.set_max_depth(usize::MAX);
148    let de = serde_stacker::Deserializer::new(&mut de);
149    T::deserialize(de).map_err(to_compute_err)
150}
151
152/// Potentially avoids copying memory compared to a naive `Vec::<u8>::deserialize`.
153///
154/// This is essentially boilerplate for visiting bytes without copying where possible.
155pub fn deserialize_map_bytes<'de, D, O>(
156    deserializer: D,
157    mut func: impl for<'b> FnMut(std::borrow::Cow<'b, [u8]>) -> O,
158) -> Result<O, D::Error>
159where
160    D: serde::de::Deserializer<'de>,
161{
162    // Lets us avoid monomorphizing the visitor
163    let mut out: Option<O> = None;
164    struct V<'f>(&'f mut dyn for<'b> FnMut(std::borrow::Cow<'b, [u8]>));
165
166    deserializer.deserialize_bytes(V(&mut |v| drop(out.replace(func(v)))))?;
167
168    return Ok(out.unwrap());
169
170    impl<'de> serde::de::Visitor<'de> for V<'_> {
171        type Value = ();
172
173        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
174            formatter.write_str("deserialize_map_bytes")
175        }
176
177        fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
178        where
179            E: serde::de::Error,
180        {
181            self.0(std::borrow::Cow::Borrowed(v));
182            Ok(())
183        }
184
185        fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
186        where
187            E: serde::de::Error,
188        {
189            self.0(std::borrow::Cow::Owned(v));
190            Ok(())
191        }
192
193        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
194        where
195            A: serde::de::SeqAccess<'de>,
196        {
197            // This is not ideal, but we hit here if the serialization format is JSON.
198            let bytes = std::iter::from_fn(|| seq.next_element::<u8>().transpose())
199                .collect::<Result<Vec<_>, A::Error>>()?;
200
201            self.0(std::borrow::Cow::Owned(bytes));
202            Ok(())
203        }
204    }
205}
206
thread_local! {
    /// Per-thread flag, initially `false`, read by `python_object_serialize`:
    /// when set, objects are pickled with `cloudpickle` directly instead of
    /// trying the standard `pickle` module first.
    pub static USE_CLOUDPICKLE: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
210
/// Pickles `pyobj` and appends it to `buf`, preceded by a three-byte header:
/// one byte recording whether `cloudpickle` produced the payload, followed by
/// the bytes of `PYTHON3_VERSION`.
///
/// When `USE_CLOUDPICKLE` is set, `cloudpickle.dumps` is used directly.
/// Otherwise `pickle.dumps` is tried first and `cloudpickle.dumps` is the
/// fallback if that call fails.
#[cfg(feature = "python")]
pub fn python_object_serialize(
    pyobj: &pyo3::Py<pyo3::PyAny>,
    buf: &mut Vec<u8>,
) -> PolarsResult<()> {
    use pyo3::Python;
    use pyo3::pybacked::PyBackedBytes;
    use pyo3::types::{PyAnyMethods, PyModule};

    use crate::python_function::PYTHON3_VERSION;

    /// Looks up the `dumps` callable of the named Python module.
    fn dumps_of<'py>(
        py: Python<'py>,
        module: &str,
    ) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::PyAny>> {
        PyModule::import(py, module)?.getattr("dumps")
    }

    let prefer_cloudpickle = USE_CLOUDPICKLE.get();
    let (payload, via_cloudpickle) = Python::with_gil(|py| {
        // Pickle with whatever pickling method was selected.
        let (pickled, via_cloudpickle) = if prefer_cloudpickle {
            let dumps = dumps_of(py, "cloudpickle")?;
            (dumps.call1((pyobj.clone_ref(py),))?, true)
        } else {
            let dumps = dumps_of(py, "pickle")?;
            match dumps.call1((pyobj.clone_ref(py),)) {
                Ok(pickled) => (pickled, false),
                // Plain pickle could not handle the object; retry with cloudpickle.
                Err(_) => {
                    let fallback = dumps_of(py, "cloudpickle")?;
                    (fallback.call1((pyobj.clone_ref(py),))?, true)
                },
            }
        };

        pickled
            .extract::<PyBackedBytes>()
            .map(|bytes| (bytes, via_cloudpickle))
    })?;

    // Write pickle metadata
    buf.push(u8::from(via_cloudpickle));
    buf.extend_from_slice(&*PYTHON3_VERSION);

    // Write UDF
    buf.extend_from_slice(&payload);
    Ok(())
}
250
/// Reconstructs a Python object from bytes produced by `python_object_serialize`.
///
/// The buffer starts with a three-byte header: a flag byte (non-zero when the
/// payload was produced by `cloudpickle`) followed by two Python version
/// bytes. The remainder is passed to `pickle.loads`.
///
/// # Errors
///
/// Returns an error if `buf` is shorter than the header, if the payload was
/// cloudpickled under a different Python version than the current one, or if
/// unpickling fails.
///
/// # Security
///
/// Unpickling can execute arbitrary code. Only call this on data from a
/// trusted source.
#[cfg(feature = "python")]
pub fn python_object_deserialize(buf: &[u8]) -> PolarsResult<pyo3::Py<pyo3::PyAny>> {
    use polars_error::polars_ensure;
    use pyo3::Python;
    use pyo3::types::{PyAnyMethods, PyBytes, PyModule};

    use crate::python_function::PYTHON3_VERSION;

    // One flag byte plus two Python version bytes.
    const HEADER_LEN: usize = 3;

    // Reject truncated input up front instead of panicking on the slice
    // accesses below.
    polars_ensure!(
        buf.len() >= HEADER_LEN,
        ComputeError:
        "serialized Python object is truncated: expected at least {} header bytes, got {}",
        HEADER_LEN,
        buf.len()
    );
    let (header, payload) = buf.split_at(HEADER_LEN);

    // Handle pickle metadata
    let use_cloudpickle = header[0] != 0;
    if use_cloudpickle {
        let ser_py_version = &header[1..];
        let cur_py_version = *PYTHON3_VERSION;
        polars_ensure!(
            ser_py_version == cur_py_version,
            InvalidOperation:
            "current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
            (3, cur_py_version[0], cur_py_version[1]),
            (3, ser_py_version[0], ser_py_version[1] )
        );
    }

    Python::with_gil(|py| {
        let loads = PyModule::import(py, "pickle")?.getattr("loads")?;
        let arg = (PyBytes::new(py, payload),);
        let python_function = loads.call1(arg)?;
        Ok(python_function.into())
    })
}
281
#[cfg(test)]
mod tests {
    /// Round-trips an enum that has a `#[serde(skip)]` variant through the
    /// compact scheme, via both the free functions and `SerializeOptions`.
    #[test]
    fn test_serde_skip_enum() {
        #[derive(Default, Debug, PartialEq)]
        struct MyType(Option<usize>);

        // Note: serde(skip) must be at the end of enums
        #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
        enum Enum {
            A,
            #[serde(skip)]
            B(MyType),
        }

        impl Default for Enum {
            fn default() -> Self {
                Self::B(MyType(None))
            }
        }

        // Free-function path.
        let expected = Enum::A;
        let encoded = super::serialize_to_bytes::<_, false>(&expected).unwrap();
        let decoded: Enum = super::deserialize_from_reader::<_, _, false>(&encoded[..]).unwrap();
        assert_eq!(decoded, expected);

        // `SerializeOptions` path with default (uncompressed) options.
        let expected = Enum::A;
        let options = super::SerializeOptions::default();
        let encoded = options.serialize_to_bytes::<_, false>(&expected).unwrap();
        let decoded: Enum = options
            .deserialize_from_reader::<_, _, false>(&encoded[..])
            .unwrap();
        assert_eq!(decoded, expected);
    }
}
319}