polars_core/schema/
mod.rs1use std::fmt::Debug;
2
3use arrow::bitmap::Bitmap;
4use polars_utils::pl_str::PlSmallStr;
5
6use crate::prelude::*;
7use crate::utils::try_get_supertype;
8
9pub mod iceberg;
10
11pub type SchemaRef = Arc<Schema>;
12pub type Schema = polars_schema::Schema<DataType, ()>;
13
14pub trait SchemaExt {
15 fn from_arrow_schema(value: &ArrowSchema) -> Self;
16
17 fn get_field(&self, name: &str) -> Option<Field>;
18
19 fn try_get_field(&self, name: &str) -> PolarsResult<Field>;
20
21 fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema;
22
23 fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_;
24
25 fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool>;
26
27 fn project_select(&self, select: &Bitmap) -> Self;
29
30 fn contains_dtype(&self, dtype: &DataType, recursive: bool) -> bool;
31}
32
33impl SchemaExt for Schema {
34 fn from_arrow_schema(value: &ArrowSchema) -> Self {
35 value
36 .iter_values()
37 .map(|x| (x.name.clone(), DataType::from_arrow_field(x)))
38 .collect()
39 }
40
41 fn get_field(&self, name: &str) -> Option<Field> {
48 self.get_full(name)
49 .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone()))
50 }
51
52 fn try_get_field(&self, name: &str) -> PolarsResult<Field> {
59 self.get_full(name)
60 .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{name}"))
61 .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone()))
62 }
63
64 fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema {
66 self.iter()
67 .map(|(name, dtype)| {
68 (
69 name.clone(),
70 dtype.to_arrow_field(name.clone(), compat_level),
71 )
72 })
73 .collect()
74 }
75
76 fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_ {
81 self.iter()
82 .map(|(name, dtype)| Field::new(name.clone(), dtype.clone()))
83 }
84
85 fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool> {
87 polars_ensure!(self.len() == other.len(), ComputeError: "schema lengths differ");
88
89 let mut changed = false;
90 for ((k, dt), (other_k, other_dt)) in self.iter_mut().zip(other.iter()) {
91 polars_ensure!(k == other_k, ComputeError: "schema names differ: got {}, expected {}", k, other_k);
92
93 let st = try_get_supertype(dt, other_dt)?;
94 changed |= (&st != dt) || (&st != other_dt);
95 *dt = st
96 }
97 Ok(changed)
98 }
99
100 fn project_select(&self, select: &Bitmap) -> Self {
101 assert_eq!(self.len(), select.len());
102 self.iter()
103 .zip(select.iter())
104 .filter(|(_, select)| *select)
105 .map(|((n, dt), _)| (n.clone(), dt.clone()))
106 .collect()
107 }
108
109 fn contains_dtype(&self, dtype: &DataType, recursive: bool) -> bool {
110 if !recursive {
111 self.iter_values().any(|dt| dt == dtype)
112 } else {
113 self.iter_values()
114 .any(|dt| dt.contains_dtype_recursive(dtype))
115 }
116 }
117}
118
119pub trait SchemaNamesAndDtypes {
120 const IS_ARROW: bool;
121 type DataType: Debug + Clone + PartialEq;
122
123 fn iter_names_and_dtypes(
124 &self,
125 ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)>;
126}
127
128impl SchemaNamesAndDtypes for ArrowSchema {
129 const IS_ARROW: bool = true;
130 type DataType = ArrowDataType;
131
132 fn iter_names_and_dtypes(
133 &self,
134 ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
135 self.iter_values().map(|x| (&x.name, &x.dtype))
136 }
137}
138
139impl SchemaNamesAndDtypes for Schema {
140 const IS_ARROW: bool = false;
141 type DataType = DataType;
142
143 fn iter_names_and_dtypes(
144 &self,
145 ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
146 self.iter()
147 }
148}
149
150pub fn ensure_matching_schema<F, M>(
151 lhs: &polars_schema::Schema<F, M>,
152 rhs: &polars_schema::Schema<F, M>,
153) -> PolarsResult<()>
154where
155 polars_schema::Schema<F, M>: SchemaNamesAndDtypes,
156{
157 let lhs = lhs.iter_names_and_dtypes();
158 let rhs = rhs.iter_names_and_dtypes();
159
160 if lhs.len() != rhs.len() {
161 polars_bail!(
162 SchemaMismatch:
163 "schemas contained differing number of columns: {} != {}",
164 lhs.len(), rhs.len(),
165 );
166 }
167
168 for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.zip(rhs).enumerate() {
169 if l_name != r_name {
170 polars_bail!(
171 SchemaMismatch:
172 "schema names differ at index {}: {} != {}",
173 i, l_name, r_name
174 )
175 }
176 if l_dtype != r_dtype
177 && (!polars_schema::Schema::<F, M>::IS_ARROW
178 || unsafe {
179 DataType::from_arrow_dtype(std::mem::transmute::<
181 &<polars_schema::Schema<F, M> as SchemaNamesAndDtypes>::DataType,
182 &ArrowDataType,
183 >(l_dtype))
184 != DataType::from_arrow_dtype(std::mem::transmute::<
185 &<polars_schema::Schema<F, M> as SchemaNamesAndDtypes>::DataType,
186 &ArrowDataType,
187 >(r_dtype))
188 })
189 {
190 polars_bail!(
191 SchemaMismatch:
192 "schema dtypes differ at index {} for column {}: {:?} != {:?}",
193 i, l_name, l_dtype, r_dtype
194 )
195 }
196 }
197
198 Ok(())
199}