1use polars_error::{PolarsResult, to_compute_err};
7
8fn config() -> bincode::config::Configuration {
9 bincode::config::standard()
10 .with_no_limit()
11 .with_variable_int_encoding()
12}
13
14fn serialize_impl<W, T, const FC: bool>(mut writer: W, value: &T) -> PolarsResult<()>
15where
16 W: std::io::Write,
17 T: serde::ser::Serialize,
18{
19 if FC {
20 let mut s = rmp_serde::Serializer::new(writer).with_struct_map();
21 value.serialize(&mut s).map_err(to_compute_err)
22 } else {
23 bincode::serde::encode_into_std_write(value, &mut writer, config())
24 .map_err(to_compute_err)
25 .map(|_| ())
26 }
27}
28
29pub fn deserialize_impl<T, R, const FC: bool>(mut reader: R) -> PolarsResult<T>
30where
31 T: serde::de::DeserializeOwned,
32 R: std::io::Read,
33{
34 if FC {
35 rmp_serde::from_read(reader).map_err(to_compute_err)
36 } else {
37 bincode::serde::decode_from_std_read(&mut reader, config()).map_err(to_compute_err)
38 }
39}
40
/// Options applied when serializing to / deserializing from a byte stream.
///
/// The only setting is whether the stream is zlib-compressed; it is configured
/// through [`SerializeOptions::with_compression`].
pub struct SerializeOptions {
    /// When `true`, output is written through a zlib encoder and input is read
    /// through a zlib decoder.
    compression: bool,
}
47
48impl SerializeOptions {
49 pub fn with_compression(mut self, compression: bool) -> Self {
50 self.compression = compression;
51 self
52 }
53
54 pub fn serialize_into_writer<W, T, const FC: bool>(
55 &self,
56 writer: W,
57 value: &T,
58 ) -> PolarsResult<()>
59 where
60 W: std::io::Write,
61 T: serde::ser::Serialize,
62 {
63 if self.compression {
64 let writer = flate2::write::ZlibEncoder::new(writer, flate2::Compression::fast());
65 serialize_impl::<_, _, FC>(writer, value)
66 } else {
67 serialize_impl::<_, _, FC>(writer, value)
68 }
69 }
70
71 pub fn deserialize_from_reader<T, R, const FC: bool>(&self, reader: R) -> PolarsResult<T>
72 where
73 T: serde::de::DeserializeOwned,
74 R: std::io::Read,
75 {
76 if self.compression {
77 deserialize_impl::<_, _, FC>(flate2::read::ZlibDecoder::new(reader))
78 } else {
79 deserialize_impl::<_, _, FC>(reader)
80 }
81 }
82
83 pub fn serialize_to_bytes<T, const FC: bool>(&self, value: &T) -> PolarsResult<Vec<u8>>
84 where
85 T: serde::ser::Serialize,
86 {
87 let mut v = vec![];
88
89 self.serialize_into_writer::<_, _, FC>(&mut v, value)?;
90
91 Ok(v)
92 }
93}
94
95#[allow(clippy::derivable_impls)]
96impl Default for SerializeOptions {
97 fn default() -> Self {
98 Self { compression: false }
99 }
100}
101
102pub fn serialize_into_writer<W, T, const FC: bool>(writer: W, value: &T) -> PolarsResult<()>
103where
104 W: std::io::Write,
105 T: serde::ser::Serialize,
106{
107 serialize_impl::<_, _, FC>(writer, value)
108}
109
110pub fn deserialize_from_reader<T, R, const FC: bool>(reader: R) -> PolarsResult<T>
111where
112 T: serde::de::DeserializeOwned,
113 R: std::io::Read,
114{
115 deserialize_impl::<_, _, FC>(reader)
116}
117
118pub fn serialize_to_bytes<T, const FC: bool>(value: &T) -> PolarsResult<Vec<u8>>
119where
120 T: serde::ser::Serialize,
121{
122 let mut v = vec![];
123
124 serialize_into_writer::<_, _, FC>(&mut v, value)?;
125
126 Ok(v)
127}
128
129pub fn serialize_dsl<W, T>(writer: W, value: &T) -> PolarsResult<()>
131where
132 W: std::io::Write,
133 T: serde::ser::Serialize,
134{
135 let mut s = rmp_serde::Serializer::new(writer).with_struct_map();
136 let s = serde_stacker::Serializer::new(&mut s);
137 value.serialize(s).map_err(to_compute_err)
138}
139
140pub fn deserialize_dsl<T, R>(reader: R) -> PolarsResult<T>
142where
143 T: serde::de::DeserializeOwned,
144 R: std::io::Read,
145{
146 let mut de = rmp_serde::Deserializer::new(reader);
147 de.set_max_depth(usize::MAX);
148 let de = serde_stacker::Deserializer::new(&mut de);
149 T::deserialize(de).map_err(to_compute_err)
150}
151
152pub fn deserialize_map_bytes<'de, D, O>(
156 deserializer: D,
157 mut func: impl for<'b> FnMut(std::borrow::Cow<'b, [u8]>) -> O,
158) -> Result<O, D::Error>
159where
160 D: serde::de::Deserializer<'de>,
161{
162 let mut out: Option<O> = None;
164 struct V<'f>(&'f mut dyn for<'b> FnMut(std::borrow::Cow<'b, [u8]>));
165
166 deserializer.deserialize_bytes(V(&mut |v| drop(out.replace(func(v)))))?;
167
168 return Ok(out.unwrap());
169
170 impl<'de> serde::de::Visitor<'de> for V<'_> {
171 type Value = ();
172
173 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
174 formatter.write_str("deserialize_map_bytes")
175 }
176
177 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
178 where
179 E: serde::de::Error,
180 {
181 self.0(std::borrow::Cow::Borrowed(v));
182 Ok(())
183 }
184
185 fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
186 where
187 E: serde::de::Error,
188 {
189 self.0(std::borrow::Cow::Owned(v));
190 Ok(())
191 }
192
193 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
194 where
195 A: serde::de::SeqAccess<'de>,
196 {
197 let bytes = std::iter::from_fn(|| seq.next_element::<u8>().transpose())
199 .collect::<Result<Vec<_>, A::Error>>()?;
200
201 self.0(std::borrow::Cow::Owned(bytes));
202 Ok(())
203 }
204 }
205}
206
// Per-thread flag, initially `false`. `python_object_serialize` reads it to decide
// whether to use `cloudpickle.dumps` directly instead of trying `pickle.dumps` first.
thread_local! {
    pub static USE_CLOUDPICKLE: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
210
/// Pickles `pyobj` and appends the result to `buf`.
///
/// If [`USE_CLOUDPICKLE`] is set for the current thread, `cloudpickle.dumps` is
/// used directly. Otherwise `pickle.dumps` is tried first and, if that call
/// fails, `cloudpickle.dumps` is used as a fallback.
///
/// Bytes appended to `buf`, in order:
/// 1. one byte: `1` if cloudpickle produced the payload, `0` otherwise;
/// 2. the bytes of `PYTHON3_VERSION`;
/// 3. the pickled payload.
///
/// Nothing is appended to `buf` if pickling fails.
#[cfg(feature = "python")]
pub fn python_object_serialize(
    pyobj: &pyo3::Py<pyo3::PyAny>,
    buf: &mut Vec<u8>,
) -> PolarsResult<()> {
    use pyo3::Python;
    use pyo3::pybacked::PyBackedBytes;
    use pyo3::types::{PyAnyMethods, PyModule};

    use crate::python_function::PYTHON3_VERSION;

    let mut use_cloudpickle = USE_CLOUDPICKLE.get();

    let payload = Python::with_gil(|py| {
        // Looks up `<module>.dumps`; import and attribute errors always propagate.
        let dumps_of = |module: &str| PyModule::import(py, module)?.getattr("dumps");

        let pickled = if use_cloudpickle {
            dumps_of("cloudpickle")?.call1((pyobj.clone_ref(py),))?
        } else {
            match dumps_of("pickle")?.call1((pyobj.clone_ref(py),)) {
                Ok(pickled) => pickled,
                Err(_) => {
                    // Only a failing `pickle.dumps` call triggers the fallback; record
                    // it so the flag byte written below reflects the pickler used.
                    use_cloudpickle = true;
                    dumps_of("cloudpickle")?.call1((pyobj.clone_ref(py),))?
                },
            }
        };

        pickled.extract::<PyBackedBytes>()
    })?;

    buf.push(use_cloudpickle as u8);
    buf.extend_from_slice(&*PYTHON3_VERSION);
    buf.extend_from_slice(&payload);

    Ok(())
}
250
/// Reconstructs a Python object from bytes produced by [`python_object_serialize`].
///
/// Expected layout of `buf`:
/// * byte 0: non-zero if the payload was produced by cloudpickle;
/// * bytes 1..3: the Python version bytes recorded at serialization time;
/// * remaining bytes: the pickled payload, loaded with `pickle.loads`.
///
/// # Errors
///
/// * `ComputeError` if `buf` is shorter than the 3-byte header.
/// * `InvalidOperation` if the payload was produced by cloudpickle and the
///   recorded Python version differs from `PYTHON3_VERSION`.
/// * Any Python error raised while importing `pickle` or calling `pickle.loads`.
///
/// # Security
///
/// `pickle.loads` can execute arbitrary code. Only call this on bytes from a
/// trusted source.
#[cfg(feature = "python")]
pub fn python_object_deserialize(buf: &[u8]) -> PolarsResult<pyo3::Py<pyo3::PyAny>> {
    use polars_error::polars_ensure;
    use pyo3::Python;
    use pyo3::types::{PyAnyMethods, PyBytes, PyModule};

    use crate::python_function::PYTHON3_VERSION;

    // One flag byte followed by two version bytes.
    const HEADER_LEN: usize = 3;

    // Reject truncated input up front instead of panicking on the slice accesses below.
    polars_ensure!(
        buf.len() >= HEADER_LEN,
        ComputeError:
        "serialized Python object is truncated: expected at least {} header bytes, got {}",
        HEADER_LEN,
        buf.len()
    );
    let (header, payload) = buf.split_at(HEADER_LEN);

    let use_cloudpickle = header[0] != 0;
    if use_cloudpickle {
        // The version is only enforced for cloudpickle payloads.
        let ser_py_version = &header[1..3];
        let cur_py_version = *PYTHON3_VERSION;
        polars_ensure!(
            ser_py_version == cur_py_version,
            InvalidOperation:
            "current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
            (3, cur_py_version[0], cur_py_version[1]),
            (3, ser_py_version[0], ser_py_version[1] )
        );
    }

    Python::with_gil(|py| {
        // NOTE(security): `pickle.loads` runs arbitrary code embedded in the payload.
        // Both pickle and cloudpickle payloads are loaded through `pickle.loads`.
        let loads = PyModule::import(py, "pickle")?.getattr("loads")?;
        let arg = (PyBytes::new(py, payload),);
        let python_function = loads.call1(arg)?;
        Ok(python_function.into())
    })
}
281
#[cfg(test)]
mod tests {
    /// Round-trips an enum that has a `#[serde(skip)]` variant through the bincode
    /// path (`FC = false`), using both the free functions and `SerializeOptions`.
    #[test]
    fn test_serde_skip_enum() {
        // `MyType` implements neither `Serialize` nor `Deserialize`; it only
        // appears inside the skipped variant.
        #[derive(Default, Debug, PartialEq)]
        struct MyType(Option<usize>);

        #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
        enum Enum {
            A,
            #[serde(skip)]
            B(MyType),
        }

        impl Default for Enum {
            fn default() -> Self {
                Self::B(MyType(None))
            }
        }

        // Module-level helpers.
        let original = Enum::A;
        let bytes = super::serialize_to_bytes::<_, false>(&original).unwrap();
        let decoded: Enum =
            super::deserialize_from_reader::<_, _, false>(bytes.as_slice()).unwrap();

        assert_eq!(decoded, original);

        // Same round trip through the default (uncompressed) options.
        let original = Enum::A;
        let options = super::SerializeOptions::default();
        let bytes = options.serialize_to_bytes::<_, false>(&original).unwrap();
        let decoded: Enum = options
            .deserialize_from_reader::<_, _, false>(bytes.as_slice())
            .unwrap();

        assert_eq!(decoded, original);
    }
}