polars_core/
schema.rs

1use std::fmt::Debug;
2
3use arrow::bitmap::Bitmap;
4use polars_utils::pl_str::PlSmallStr;
5
6use crate::prelude::*;
7use crate::utils::try_get_supertype;
8
9pub type SchemaRef = Arc<Schema>;
10pub type Schema = polars_schema::Schema<DataType>;
11
12pub trait SchemaExt {
13    fn from_arrow_schema(value: &ArrowSchema) -> Self;
14
15    fn get_field(&self, name: &str) -> Option<Field>;
16
17    fn try_get_field(&self, name: &str) -> PolarsResult<Field>;
18
19    fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema;
20
21    fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_;
22
23    fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool>;
24
25    /// Select fields using a bitmap.
26    fn project_select(&self, select: &Bitmap) -> Self;
27}
28
29impl SchemaExt for Schema {
30    fn from_arrow_schema(value: &ArrowSchema) -> Self {
31        value
32            .iter_values()
33            .map(|x| (x.name.clone(), DataType::from_arrow_field(x)))
34            .collect()
35    }
36
37    /// Look up the name in the schema and return an owned [`Field`] by cloning the data.
38    ///
39    /// Returns `None` if the field does not exist.
40    ///
41    /// This method constructs the `Field` by cloning the name and dtype. For a version that returns references, see
42    /// [`get`][Self::get] or [`get_full`][Self::get_full].
43    fn get_field(&self, name: &str) -> Option<Field> {
44        self.get_full(name)
45            .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone()))
46    }
47
48    /// Look up the name in the schema and return an owned [`Field`] by cloning the data.
49    ///
50    /// Returns `Err(PolarsErr)` if the field does not exist.
51    ///
52    /// This method constructs the `Field` by cloning the name and dtype. For a version that returns references, see
53    /// [`get`][Self::get] or [`get_full`][Self::get_full].
54    fn try_get_field(&self, name: &str) -> PolarsResult<Field> {
55        self.get_full(name)
56            .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
57            .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone()))
58    }
59
60    /// Convert self to `ArrowSchema` by cloning the fields.
61    fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema {
62        self.iter()
63            .map(|(name, dtype)| {
64                (
65                    name.clone(),
66                    dtype.to_arrow_field(name.clone(), compat_level),
67                )
68            })
69            .collect()
70    }
71
72    /// Iterates the [`Field`]s in this schema, constructing them anew by cloning each `(&name, &dtype)` pair.
73    ///
74    /// Note that this clones each name and dtype in order to form an owned [`Field`]. For a clone-free version, use
75    /// [`iter`][Self::iter], which returns `(&name, &dtype)`.
76    fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_ {
77        self.iter()
78            .map(|(name, dtype)| Field::new(name.clone(), dtype.clone()))
79    }
80
81    /// Take another [`Schema`] and try to find the supertypes between them.
82    fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool> {
83        polars_ensure!(self.len() == other.len(), ComputeError: "schema lengths differ");
84
85        let mut changed = false;
86        for ((k, dt), (other_k, other_dt)) in self.iter_mut().zip(other.iter()) {
87            polars_ensure!(k == other_k, ComputeError: "schema names differ: got {}, expected {}", k, other_k);
88
89            let st = try_get_supertype(dt, other_dt)?;
90            changed |= (&st != dt) || (&st != other_dt);
91            *dt = st
92        }
93        Ok(changed)
94    }
95
96    fn project_select(&self, select: &Bitmap) -> Self {
97        assert_eq!(self.len(), select.len());
98        self.iter()
99            .zip(select.iter())
100            .filter(|(_, select)| *select)
101            .map(|((n, dt), _)| (n.clone(), dt.clone()))
102            .collect()
103    }
104}
105
106pub trait SchemaNamesAndDtypes {
107    const IS_ARROW: bool;
108    type DataType: Debug + Clone + Default + PartialEq;
109
110    fn iter_names_and_dtypes(
111        &self,
112    ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)>;
113}
114
115impl SchemaNamesAndDtypes for ArrowSchema {
116    const IS_ARROW: bool = true;
117    type DataType = ArrowDataType;
118
119    fn iter_names_and_dtypes(
120        &self,
121    ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
122        self.iter_values().map(|x| (&x.name, &x.dtype))
123    }
124}
125
126impl SchemaNamesAndDtypes for Schema {
127    const IS_ARROW: bool = false;
128    type DataType = DataType;
129
130    fn iter_names_and_dtypes(
131        &self,
132    ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
133        self.iter()
134    }
135}
136
137pub fn ensure_matching_schema<D>(
138    lhs: &polars_schema::Schema<D>,
139    rhs: &polars_schema::Schema<D>,
140) -> PolarsResult<()>
141where
142    polars_schema::Schema<D>: SchemaNamesAndDtypes,
143{
144    let lhs = lhs.iter_names_and_dtypes();
145    let rhs = rhs.iter_names_and_dtypes();
146
147    if lhs.len() != rhs.len() {
148        polars_bail!(
149            SchemaMismatch:
150            "schemas contained differing number of columns: {} != {}",
151            lhs.len(), rhs.len(),
152        );
153    }
154
155    for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.zip(rhs).enumerate() {
156        if l_name != r_name {
157            polars_bail!(
158                SchemaMismatch:
159                "schema names differ at index {}: {} != {}",
160                i, l_name, r_name
161            )
162        }
163        if l_dtype != r_dtype
164            && (!polars_schema::Schema::<D>::IS_ARROW
165                || unsafe {
166                    // For timezone normalization. Easier than writing out the entire PartialEq.
167                    DataType::from_arrow_dtype(std::mem::transmute::<
168                        &<polars_schema::Schema<D> as SchemaNamesAndDtypes>::DataType,
169                        &ArrowDataType,
170                    >(l_dtype))
171                        != DataType::from_arrow_dtype(std::mem::transmute::<
172                            &<polars_schema::Schema<D> as SchemaNamesAndDtypes>::DataType,
173                            &ArrowDataType,
174                        >(r_dtype))
175                })
176        {
177            polars_bail!(
178                SchemaMismatch:
179                "schema dtypes differ at index {} for column {}: {:?} != {:?}",
180                i, l_name, l_dtype, r_dtype
181            )
182        }
183    }
184
185    Ok(())
186}