Skip to main content

polars_core/schema/
mod.rs

1use std::fmt::Debug;
2
3use arrow::bitmap::Bitmap;
4use polars_utils::pl_str::PlSmallStr;
5
6use crate::prelude::*;
7use crate::utils::try_get_supertype;
8
9pub mod iceberg;
10
/// A [`Schema`] wrapped in an `Arc`, so one schema can be shared between several owners.
pub type SchemaRef = Arc<Schema>;
/// `polars_schema::Schema` specialised with [`DataType`] as the dtype stored per column name and
/// `()` as its second type parameter.
pub type Schema = polars_schema::Schema<DataType, ()>;
13
/// Extension methods for schemas: conversion to and from `ArrowSchema`, owned [`Field`] lookup and
/// iteration, supertype unification, bitmap projection and dtype search.
///
/// The behaviour notes below describe the implementation for [`Schema`] in this file.
pub trait SchemaExt {
    /// Builds `Self` from an `ArrowSchema`.
    fn from_arrow_schema(value: &ArrowSchema) -> Self;

    /// Looks up `name` and returns the entry as an owned [`Field`], or `None` if it is absent.
    fn get_field(&self, name: &str) -> Option<Field>;

    /// Looks up `name` and returns the entry as an owned [`Field`], or an error if it is absent.
    fn try_get_field(&self, name: &str) -> PolarsResult<Field>;

    /// Converts `self` into an `ArrowSchema`, passing `compat_level` on to the per-dtype conversion.
    fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema;

    /// Iterates over all entries of `self` as owned [`Field`]s.
    fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_;

    /// Replaces each dtype in `self` by the supertype of it and the dtype at the same position in
    /// `other`.
    ///
    /// Returns `Ok(true)` if, for at least one column, the supertype differs from the dtype in
    /// `self` or from the dtype in `other`.
    fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool>;

    /// Select fields using a bitmap.
    ///
    /// The implementation for [`Schema`] asserts that `select` has exactly as many bits as `self`
    /// has fields, and keeps the fields whose bit is set.
    fn project_select(&self, select: &Bitmap) -> Self;

    /// Returns `true` if any column dtype matches `dtype`.
    ///
    /// With `recursive == false` the column dtype itself is compared for equality; with
    /// `recursive == true` the check is delegated to `DataType::contains_dtype_recursive`.
    fn contains_dtype(&self, dtype: &DataType, recursive: bool) -> bool;
}
32
33impl SchemaExt for Schema {
34    fn from_arrow_schema(value: &ArrowSchema) -> Self {
35        value
36            .iter_values()
37            .map(|x| (x.name.clone(), DataType::from_arrow_field(x)))
38            .collect()
39    }
40
41    /// Look up the name in the schema and return an owned [`Field`] by cloning the data.
42    ///
43    /// Returns `None` if the field does not exist.
44    ///
45    /// This method constructs the `Field` by cloning the name and dtype. For a version that returns references, see
46    /// [`get`][Self::get] or [`get_full`][Self::get_full].
47    fn get_field(&self, name: &str) -> Option<Field> {
48        self.get_full(name)
49            .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone()))
50    }
51
52    /// Look up the name in the schema and return an owned [`Field`] by cloning the data.
53    ///
54    /// Returns `Err(PolarsErr)` if the field does not exist.
55    ///
56    /// This method constructs the `Field` by cloning the name and dtype. For a version that returns references, see
57    /// [`get`][Self::get] or [`get_full`][Self::get_full].
58    fn try_get_field(&self, name: &str) -> PolarsResult<Field> {
59        self.get_full(name)
60            .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{name}"))
61            .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone()))
62    }
63
64    /// Convert self to `ArrowSchema` by cloning the fields.
65    fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema {
66        self.iter()
67            .map(|(name, dtype)| {
68                (
69                    name.clone(),
70                    dtype.to_arrow_field(name.clone(), compat_level),
71                )
72            })
73            .collect()
74    }
75
76    /// Iterates the [`Field`]s in this schema, constructing them anew by cloning each `(&name, &dtype)` pair.
77    ///
78    /// Note that this clones each name and dtype in order to form an owned [`Field`]. For a clone-free version, use
79    /// [`iter`][Self::iter], which returns `(&name, &dtype)`.
80    fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_ {
81        self.iter()
82            .map(|(name, dtype)| Field::new(name.clone(), dtype.clone()))
83    }
84
85    /// Take another [`Schema`] and try to find the supertypes between them.
86    fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool> {
87        polars_ensure!(self.len() == other.len(), ComputeError: "schema lengths differ");
88
89        let mut changed = false;
90        for ((k, dt), (other_k, other_dt)) in self.iter_mut().zip(other.iter()) {
91            polars_ensure!(k == other_k, ComputeError: "schema names differ: got {}, expected {}", k, other_k);
92
93            let st = try_get_supertype(dt, other_dt)?;
94            changed |= (&st != dt) || (&st != other_dt);
95            *dt = st
96        }
97        Ok(changed)
98    }
99
100    fn project_select(&self, select: &Bitmap) -> Self {
101        assert_eq!(self.len(), select.len());
102        self.iter()
103            .zip(select.iter())
104            .filter(|(_, select)| *select)
105            .map(|((n, dt), _)| (n.clone(), dt.clone()))
106            .collect()
107    }
108
109    fn contains_dtype(&self, dtype: &DataType, recursive: bool) -> bool {
110        if !recursive {
111            self.iter_values().any(|dt| dt == dtype)
112        } else {
113            self.iter_values()
114                .any(|dt| dt.contains_dtype_recursive(dtype))
115        }
116    }
117}
118
/// Read-only view over a schema's `(name, dtype)` pairs.
///
/// Implemented in this file for both `ArrowSchema` and [`Schema`], which lets generic code such as
/// [`ensure_matching_schema`] handle either kind of schema.
pub trait SchemaNamesAndDtypes {
    /// Whether the dtypes of this schema are arrow dtypes.
    ///
    /// [`ensure_matching_schema`] reinterprets `&Self::DataType` as `&ArrowDataType` when this is
    /// `true`, so it must only be set to `true` when `DataType` is `ArrowDataType`.
    const IS_ARROW: bool;
    /// The dtype stored per column; it must be printable, clonable and comparable for equality.
    type DataType: Debug + Clone + PartialEq;

    /// Iterates over the `(name, dtype)` pairs of the schema, with a known length.
    fn iter_names_and_dtypes(
        &self,
    ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)>;
}
127
128impl SchemaNamesAndDtypes for ArrowSchema {
129    const IS_ARROW: bool = true;
130    type DataType = ArrowDataType;
131
132    fn iter_names_and_dtypes(
133        &self,
134    ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
135        self.iter_values().map(|x| (&x.name, &x.dtype))
136    }
137}
138
impl SchemaNamesAndDtypes for Schema {
    // Polars schema: the dtypes are `DataType`s, not arrow dtypes.
    const IS_ARROW: bool = false;
    type DataType = DataType;

    /// Yields the `(name, dtype)` pairs of the schema by delegating directly to `iter`.
    fn iter_names_and_dtypes(
        &self,
    ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
        self.iter()
    }
}
149
150pub fn ensure_matching_schema<F, M>(
151    lhs: &polars_schema::Schema<F, M>,
152    rhs: &polars_schema::Schema<F, M>,
153) -> PolarsResult<()>
154where
155    polars_schema::Schema<F, M>: SchemaNamesAndDtypes,
156{
157    let lhs = lhs.iter_names_and_dtypes();
158    let rhs = rhs.iter_names_and_dtypes();
159
160    if lhs.len() != rhs.len() {
161        polars_bail!(
162            SchemaMismatch:
163            "schemas contained differing number of columns: {} != {}",
164            lhs.len(), rhs.len(),
165        );
166    }
167
168    for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.zip(rhs).enumerate() {
169        if l_name != r_name {
170            polars_bail!(
171                SchemaMismatch:
172                "schema names differ at index {}: {} != {}",
173                i, l_name, r_name
174            )
175        }
176        if l_dtype != r_dtype
177            && (!polars_schema::Schema::<F, M>::IS_ARROW
178                || unsafe {
179                    // For timezone normalization. Easier than writing out the entire PartialEq.
180                    DataType::from_arrow_dtype(std::mem::transmute::<
181                        &<polars_schema::Schema<F, M> as SchemaNamesAndDtypes>::DataType,
182                        &ArrowDataType,
183                    >(l_dtype))
184                        != DataType::from_arrow_dtype(std::mem::transmute::<
185                            &<polars_schema::Schema<F, M> as SchemaNamesAndDtypes>::DataType,
186                            &ArrowDataType,
187                        >(r_dtype))
188                })
189        {
190            polars_bail!(
191                SchemaMismatch:
192                "schema dtypes differ at index {} for column {}: {:?} != {:?}",
193                i, l_name, l_dtype, r_dtype
194            )
195        }
196    }
197
198    Ok(())
199}