polars_io/parquet/write/
key_value_metadata.rs1use std::fmt::Debug;
2use std::hash::Hash;
3use std::sync::Arc;
4
5use polars_error::PolarsResult;
6use polars_parquet::write::KeyValue;
7#[cfg(feature = "python")]
8use polars_utils::python_function::PythonObject;
9#[cfg(feature = "python")]
10use pyo3::PyObject;
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize, de, ser};
13
/// Borrowed context handed to the metadata producers when key/value metadata is collected.
pub struct ParquetMetadataContext<'a> {
    /// The Arrow schema, as a string slice borrowed from the caller.
    // NOTE(review): the exact encoding of this string is not visible in this file — confirm
    // against the code that constructs the context.
    pub arrow_schema: &'a str,
}
18
19#[derive(Clone, Debug, PartialEq, Eq, Hash)]
21#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
23pub enum KeyValueMetadata {
24 Static(
26 #[cfg_attr(
27 feature = "serde",
28 serde(
29 serialize_with = "serialize_vec_key_value",
30 deserialize_with = "deserialize_vec_key_value"
31 )
32 )]
33 #[cfg_attr(
34 feature = "dsl-schema",
35 schemars(with = "Vec<(String, Option<String>)>")
36 )]
37 Vec<KeyValue>,
38 ),
39 #[cfg_attr(feature = "dsl-schema", schemars(skip))]
41 DynamicRust(RustKeyValueMetadataFunction),
42 #[cfg(feature = "python")]
44 DynamicPython(python_impl::PythonKeyValueMetadataFunction),
45}
46
/// Serializes key/value pairs as a sequence of `(key, optional value)` tuples.
///
/// The pairs are streamed straight into the serializer, so no temporary `Vec` is allocated
/// just to be serialized and dropped.
#[cfg(feature = "serde")]
fn serialize_vec_key_value<S>(kv: &[KeyValue], serializer: S) -> Result<S::Ok, S::Error>
where
    S: ser::Serializer,
{
    // A mapped slice iterator reports an exact size hint, so `collect_seq` can still announce
    // the sequence length up front, matching what serializing a `Vec` would do.
    serializer.collect_seq(kv.iter().map(|item| (&item.key, item.value.as_ref())))
}
57
/// Reads a sequence of `(key, optional value)` tuples and rebuilds the `KeyValue` list from it.
#[cfg(feature = "serde")]
fn deserialize_vec_key_value<'de, D>(deserializer: D) -> Result<Vec<KeyValue>, D::Error>
where
    D: de::Deserializer<'de>,
{
    let pairs = Vec::<(String, Option<String>)>::deserialize(deserializer)?;

    // Convert each tuple into a `KeyValue`, keeping the order of the input sequence.
    let mut entries = Vec::with_capacity(pairs.len());
    for (key, value) in pairs {
        entries.push(KeyValue { key, value });
    }
    Ok(entries)
}
70
71impl KeyValueMetadata {
72 pub fn from_static(kv: Vec<(String, String)>) -> Self {
74 Self::Static(
75 kv.into_iter()
76 .map(|(key, value)| KeyValue {
77 key,
78 value: Some(value),
79 })
80 .collect(),
81 )
82 }
83
84 #[cfg(feature = "python")]
86 pub fn from_py_function(py_object: PyObject) -> Self {
87 Self::DynamicPython(python_impl::PythonKeyValueMetadataFunction(Arc::new(
88 PythonObject(py_object),
89 )))
90 }
91
92 pub fn collect(&self, ctx: ParquetMetadataContext) -> PolarsResult<Vec<KeyValue>> {
95 match self {
96 Self::Static(kv) => Ok(kv.clone()),
97 Self::DynamicRust(func) => Ok(func.0(ctx)),
98 #[cfg(feature = "python")]
99 Self::DynamicPython(py_func) => py_func.call(ctx),
100 }
101 }
102}
103
104#[derive(Clone)]
105pub struct RustKeyValueMetadataFunction(
106 Arc<dyn Fn(ParquetMetadataContext) -> Vec<KeyValue> + Send + Sync>,
107);
108
109impl Debug for RustKeyValueMetadataFunction {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 write!(
112 f,
113 "key value metadata function at 0x{:016x}",
114 self.0.as_ref() as *const _ as *const () as usize
115 )
116 }
117}
118
119impl Eq for RustKeyValueMetadataFunction {}
120
121impl PartialEq for RustKeyValueMetadataFunction {
122 fn eq(&self, other: &Self) -> bool {
123 Arc::ptr_eq(&self.0, &other.0)
124 }
125}
126
127impl Hash for RustKeyValueMetadataFunction {
128 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
129 state.write_usize(Arc::as_ptr(&self.0) as *const () as usize);
130 }
131}
132
/// Serialization always fails for the Rust callback; the error carries its `Debug` rendering.
#[cfg(feature = "serde")]
impl Serialize for RustKeyValueMetadataFunction {
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let message = format!("cannot serialize {self:?}");
        Err(<S::Error as serde::ser::Error>::custom(message))
    }
}
143
/// Deserialization always fails: no deserialization is implemented for the wrapped callback.
///
/// The impl exists so that types containing this one can still derive `Deserialize`.
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for RustKeyValueMetadataFunction {
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;
        // Name the real type so the message can be matched to the code.
        Err(D::Error::custom(
            "cannot deserialize RustKeyValueMetadataFunction",
        ))
    }
}
156
/// Python-backed key/value metadata support; only compiled with the `python` feature.
#[cfg(feature = "python")]
mod python_impl {
    use std::hash::Hash;
    use std::sync::Arc;

    use polars_error::{PolarsResult, to_compute_err};
    use polars_parquet::write::KeyValue;
    // Only named by the serde attribute paths below, so it is gated like them.
    #[cfg(feature = "serde")]
    use polars_utils::python_function::PythonObject;
    use pyo3::types::PyAnyMethods;
    use pyo3::{PyResult, Python, pyclass};
    // The derives that need these traits are themselves behind the `serde` feature.
    #[cfg(feature = "serde")]
    use serde::{Deserialize, Serialize};

    use super::ParquetMetadataContext;

    /// Shared handle to a Python callable that produces key/value metadata.
    ///
    /// NOTE(review): `Hash` below uses the `Arc` address while `PartialEq` is derived and so
    /// defers to the wrapped type's equality — confirm that two distinct allocations can never
    /// compare equal, otherwise the `Hash`/`Eq` contract is violated.
    #[derive(Clone, Debug, PartialEq, Eq)]
    #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
    #[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
    pub struct PythonKeyValueMetadataFunction(
        // No extra `python` gate here: the enclosing module is already gated on it.
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        #[cfg_attr(feature = "dsl-schema", schemars(with = "Vec<u8>"))]
        pub Arc<polars_utils::python_function::PythonFunction>,
    );

    impl PythonKeyValueMetadataFunction {
        /// Calls the Python function with the context as its single positional argument.
        ///
        /// The call happens while holding the GIL. The returned object must extract into a
        /// list of `(String, String)` pairs; each value is stored as `Some(value)`.
        ///
        /// # Errors
        ///
        /// Any Python error raised by the call or by the extraction is converted with
        /// `to_compute_err`.
        pub fn call(&self, ctx: ParquetMetadataContext) -> PolarsResult<Vec<KeyValue>> {
            // Copy the borrowed context into an owned object that can be passed to Python.
            let ctx = PythonParquetMetadataContext::from_key_value_metadata_context(ctx);
            Python::with_gil(|py| {
                let args = (ctx,);
                let out: Vec<(String, String)> =
                    self.0.call1(py, args)?.into_bound(py).extract()?;
                let result = out
                    .into_iter()
                    .map(|(key, value)| KeyValue {
                        key,
                        value: Some(value),
                    })
                    .collect::<Vec<_>>();
                PyResult::Ok(result)
            })
            .map_err(to_compute_err)
        }
    }

    impl Hash for PythonKeyValueMetadataFunction {
        /// Hashes the address of the shared allocation, not the Python object's contents.
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            state.write_usize(Arc::as_ptr(&self.0) as *const () as usize);
        }
    }

    /// Owned counterpart of `ParquetMetadataContext` that is handed to the Python callable.
    #[pyclass]
    pub struct PythonParquetMetadataContext {
        /// Owned copy of the context's `arrow_schema` string, readable from Python.
        #[pyo3(get)]
        arrow_schema: String,
    }

    impl PythonParquetMetadataContext {
        /// Copies the borrowed fields of `ctx` into an owned value.
        pub fn from_key_value_metadata_context(ctx: ParquetMetadataContext) -> Self {
            Self {
                arrow_schema: ctx.arrow_schema.to_string(),
            }
        }
    }
}