polars_io/parquet/write/
key_value_metadata.rs

1use std::fmt::Debug;
2use std::hash::Hash;
3use std::sync::Arc;
4
5use polars_error::PolarsResult;
6use polars_parquet::write::KeyValue;
7#[cfg(feature = "python")]
8use polars_utils::python_function::PythonObject;
9#[cfg(feature = "python")]
10use pyo3::PyObject;
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize, de, ser};
13
/// Information made available to callbacks that build custom file-level key/value metadata
/// for a Parquet file.
pub struct ParquetMetadataContext<'a> {
    /// The Arrow schema, borrowed as a string for the lifetime `'a`.
    pub arrow_schema: &'a str,
}
18
19/// Key/value pairs that can be attached to a Parquet file as file-level metadtaa.
20#[derive(Clone, Debug, PartialEq, Eq, Hash)]
21#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
23pub enum KeyValueMetadata {
24    /// Static key value metadata.
25    Static(
26        #[cfg_attr(
27            feature = "serde",
28            serde(
29                serialize_with = "serialize_vec_key_value",
30                deserialize_with = "deserialize_vec_key_value"
31            )
32        )]
33        #[cfg_attr(
34            feature = "dsl-schema",
35            schemars(with = "Vec<(String, Option<String>)>")
36        )]
37        Vec<KeyValue>,
38    ),
39    /// Rust function to dynamically compute key value metadata.
40    #[cfg_attr(feature = "dsl-schema", schemars(skip))]
41    DynamicRust(RustKeyValueMetadataFunction),
42    /// Python function to dynamically compute key value metadata.
43    #[cfg(feature = "python")]
44    DynamicPython(python_impl::PythonKeyValueMetadataFunction),
45}
46
/// Serializes a slice of [`KeyValue`] as a sequence of `(key, optional value)` tuples.
///
/// The tuples only borrow from `kv`; no string is copied.
#[cfg(feature = "serde")]
fn serialize_vec_key_value<S>(kv: &[KeyValue], serializer: S) -> Result<S::Ok, S::Error>
where
    S: ser::Serializer,
{
    let pairs: Vec<(&String, Option<&String>)> = kv
        .iter()
        .map(|entry| (&entry.key, entry.value.as_ref()))
        .collect();
    pairs.serialize(serializer)
}
57
/// Deserializes a sequence of `(key, optional value)` tuples into a list of [`KeyValue`].
///
/// Any error reported by the deserializer is returned unchanged.
#[cfg(feature = "serde")]
fn deserialize_vec_key_value<'de, D>(deserializer: D) -> Result<Vec<KeyValue>, D::Error>
where
    D: de::Deserializer<'de>,
{
    let pairs = <Vec<(String, Option<String>)> as Deserialize>::deserialize(deserializer)?;
    Ok(pairs
        .into_iter()
        .map(|(key, value)| KeyValue { key, value })
        .collect())
}
70
71impl KeyValueMetadata {
72    /// Create a key value metadata object from a static key value mapping.
73    pub fn from_static(kv: Vec<(String, String)>) -> Self {
74        Self::Static(
75            kv.into_iter()
76                .map(|(key, value)| KeyValue {
77                    key,
78                    value: Some(value),
79                })
80                .collect(),
81        )
82    }
83
84    /// Create a key value metadata object from a Python function.
85    #[cfg(feature = "python")]
86    pub fn from_py_function(py_object: PyObject) -> Self {
87        Self::DynamicPython(python_impl::PythonKeyValueMetadataFunction(Arc::new(
88            PythonObject(py_object),
89        )))
90    }
91
92    /// Turn the metadata into the key/value pairs to write to the Parquet file.
93    /// The context is used to dynamically construct key/value pairs.
94    pub fn collect(&self, ctx: ParquetMetadataContext) -> PolarsResult<Vec<KeyValue>> {
95        match self {
96            Self::Static(kv) => Ok(kv.clone()),
97            Self::DynamicRust(func) => Ok(func.0(ctx)),
98            #[cfg(feature = "python")]
99            Self::DynamicPython(py_func) => py_func.call(ctx),
100        }
101    }
102}
103
104#[derive(Clone)]
105pub struct RustKeyValueMetadataFunction(
106    Arc<dyn Fn(ParquetMetadataContext) -> Vec<KeyValue> + Send + Sync>,
107);
108
109impl Debug for RustKeyValueMetadataFunction {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        write!(
112            f,
113            "key value metadata function at 0x{:016x}",
114            self.0.as_ref() as *const _ as *const () as usize
115        )
116    }
117}
118
119impl Eq for RustKeyValueMetadataFunction {}
120
121impl PartialEq for RustKeyValueMetadataFunction {
122    fn eq(&self, other: &Self) -> bool {
123        Arc::ptr_eq(&self.0, &other.0)
124    }
125}
126
127impl Hash for RustKeyValueMetadataFunction {
128    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
129        state.write_usize(Arc::as_ptr(&self.0) as *const () as usize);
130    }
131}
132
/// A Rust closure cannot be serialized, so this always fails with a custom serializer error
/// that includes the `Debug` rendering of the value.
#[cfg(feature = "serde")]
impl Serialize for RustKeyValueMetadataFunction {
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let message = format!("cannot serialize {self:?}");
        Err(<S::Error as ser::Error>::custom(message))
    }
}
143
/// A Rust closure cannot be reconstructed from serialized data, so this always fails with a
/// custom deserializer error; the input is never read.
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for RustKeyValueMetadataFunction {
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        // NOTE(review): the message says `RustKeyValueMetadataFn`, not the actual type name;
        // kept as-is because callers may match on the text.
        let message = "cannot deserialize RustKeyValueMetadataFn";
        Err(<D::Error as de::Error>::custom(message))
    }
}
156
/// Python-backed key/value metadata support; compiled only with the `python` feature.
#[cfg(feature = "python")]
mod python_impl {
    use std::hash::Hash;
    use std::sync::Arc;

    use polars_error::{PolarsResult, to_compute_err};
    use polars_parquet::write::KeyValue;
    // `PythonObject` and the serde traits are referenced only by the `serde`-gated attributes
    // below, so their imports are gated identically. This keeps a `python`-without-`serde` build
    // free of unused or unresolved imports.
    #[cfg(feature = "serde")]
    use polars_utils::python_function::PythonObject;
    use pyo3::types::PyAnyMethods;
    use pyo3::{PyResult, Python, pyclass};
    #[cfg(feature = "serde")]
    use serde::{Deserialize, Serialize};

    use super::ParquetMetadataContext;

    /// A shared Python callable that produces key/value metadata.
    ///
    /// With `serde` enabled the callable is (de)serialized through
    /// `PythonObject::serialize_with_pyversion` / `PythonObject::deserialize_with_pyversion`;
    /// the schema describes it as `Vec<u8>`.
    #[derive(Clone, Debug, PartialEq, Eq)]
    #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
    #[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
    pub struct PythonKeyValueMetadataFunction(
        // No extra `python` gate on the field: the enclosing module already carries it.
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<u8>"))]
        pub Arc<polars_utils::python_function::PythonFunction>,
    );

    impl PythonKeyValueMetadataFunction {
        /// Calls the Python function with the given context and converts its result.
        ///
        /// The context is copied into a [`PythonParquetMetadataContext`] and passed as the only
        /// positional argument while the GIL is held. The return value is extracted as a list
        /// of `(String, String)` pairs, and every value is stored as `Some(value)`.
        ///
        /// # Errors
        ///
        /// A failure of the Python call or of the extraction is converted with
        /// `to_compute_err` and returned.
        pub fn call(&self, ctx: ParquetMetadataContext) -> PolarsResult<Vec<KeyValue>> {
            let ctx = PythonParquetMetadataContext::from_key_value_metadata_context(ctx);
            Python::with_gil(|py| {
                let args = (ctx,);
                let out: Vec<(String, String)> =
                    self.0.call1(py, args)?.into_bound(py).extract()?;
                let result = out
                    .into_iter()
                    .map(|(key, value)| KeyValue {
                        key,
                        value: Some(value),
                    })
                    .collect::<Vec<_>>();
                PyResult::Ok(result)
            })
            .map_err(to_compute_err)
        }
    }

    impl Hash for PythonKeyValueMetadataFunction {
        /// Hashes the address of the shared allocation.
        // NOTE(review): `PartialEq` is derived and therefore compares the wrapped values, while
        // this hash uses the `Arc` address. Whether equal values always hash equally depends on
        // how `PythonFunction` implements `PartialEq`, which is not visible here -- confirm.
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            state.write_usize(Arc::as_ptr(&self.0) as *const () as usize);
        }
    }

    /// Python-visible counterpart of [`ParquetMetadataContext`] that owns its data.
    #[pyclass]
    pub struct PythonParquetMetadataContext {
        /// Owned copy of the Arrow schema string, readable from Python.
        #[pyo3(get)]
        arrow_schema: String,
    }

    impl PythonParquetMetadataContext {
        /// Builds the Python-facing context by copying the borrowed schema string.
        pub fn from_key_value_metadata_context(ctx: ParquetMetadataContext) -> Self {
            Self {
                arrow_schema: ctx.arrow_schema.to_string(),
            }
        }
    }
}