1use std::fmt::Debug;
2
3use arrow::bitmap::Bitmap;
4use polars_utils::pl_str::PlSmallStr;
5
6use crate::prelude::*;
7use crate::utils::try_get_supertype;
8
9pub type SchemaRef = Arc<Schema>;
10pub type Schema = polars_schema::Schema<DataType>;
11
12pub trait SchemaExt {
13 fn from_arrow_schema(value: &ArrowSchema) -> Self;
14
15 fn get_field(&self, name: &str) -> Option<Field>;
16
17 fn try_get_field(&self, name: &str) -> PolarsResult<Field>;
18
19 fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema;
20
21 fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_;
22
23 fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool>;
24
25 fn project_select(&self, select: &Bitmap) -> Self;
27}
28
29impl SchemaExt for Schema {
30 fn from_arrow_schema(value: &ArrowSchema) -> Self {
31 value
32 .iter_values()
33 .map(|x| (x.name.clone(), DataType::from_arrow_field(x)))
34 .collect()
35 }
36
37 fn get_field(&self, name: &str) -> Option<Field> {
44 self.get_full(name)
45 .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone()))
46 }
47
48 fn try_get_field(&self, name: &str) -> PolarsResult<Field> {
55 self.get_full(name)
56 .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name))
57 .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone()))
58 }
59
60 fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema {
62 self.iter()
63 .map(|(name, dtype)| {
64 (
65 name.clone(),
66 dtype.to_arrow_field(name.clone(), compat_level),
67 )
68 })
69 .collect()
70 }
71
72 fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_ {
77 self.iter()
78 .map(|(name, dtype)| Field::new(name.clone(), dtype.clone()))
79 }
80
81 fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool> {
83 polars_ensure!(self.len() == other.len(), ComputeError: "schema lengths differ");
84
85 let mut changed = false;
86 for ((k, dt), (other_k, other_dt)) in self.iter_mut().zip(other.iter()) {
87 polars_ensure!(k == other_k, ComputeError: "schema names differ: got {}, expected {}", k, other_k);
88
89 let st = try_get_supertype(dt, other_dt)?;
90 changed |= (&st != dt) || (&st != other_dt);
91 *dt = st
92 }
93 Ok(changed)
94 }
95
96 fn project_select(&self, select: &Bitmap) -> Self {
97 assert_eq!(self.len(), select.len());
98 self.iter()
99 .zip(select.iter())
100 .filter(|(_, select)| *select)
101 .map(|((n, dt), _)| (n.clone(), dt.clone()))
102 .collect()
103 }
104}
105
106pub trait SchemaNamesAndDtypes {
107 const IS_ARROW: bool;
108 type DataType: Debug + Clone + Default + PartialEq;
109
110 fn iter_names_and_dtypes(
111 &self,
112 ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)>;
113}
114
115impl SchemaNamesAndDtypes for ArrowSchema {
116 const IS_ARROW: bool = true;
117 type DataType = ArrowDataType;
118
119 fn iter_names_and_dtypes(
120 &self,
121 ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
122 self.iter_values().map(|x| (&x.name, &x.dtype))
123 }
124}
125
/// Polars schemas already iterate as `(name, dtype)` pairs, so this simply delegates.
impl SchemaNamesAndDtypes for Schema {
    // Polars dtypes are not Arrow dtypes; `ensure_matching_schema` must not reinterpret them.
    const IS_ARROW: bool = false;
    type DataType = DataType;

    /// Yields the schema's `(name, dtype)` pairs via `Schema::iter`.
    fn iter_names_and_dtypes(
        &self,
    ) -> impl ExactSizeIterator<Item = (&PlSmallStr, &Self::DataType)> {
        self.iter()
    }
}
136
137pub fn ensure_matching_schema<D>(
138 lhs: &polars_schema::Schema<D>,
139 rhs: &polars_schema::Schema<D>,
140) -> PolarsResult<()>
141where
142 polars_schema::Schema<D>: SchemaNamesAndDtypes,
143{
144 let lhs = lhs.iter_names_and_dtypes();
145 let rhs = rhs.iter_names_and_dtypes();
146
147 if lhs.len() != rhs.len() {
148 polars_bail!(
149 SchemaMismatch:
150 "schemas contained differing number of columns: {} != {}",
151 lhs.len(), rhs.len(),
152 );
153 }
154
155 for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.zip(rhs).enumerate() {
156 if l_name != r_name {
157 polars_bail!(
158 SchemaMismatch:
159 "schema names differ at index {}: {} != {}",
160 i, l_name, r_name
161 )
162 }
163 if l_dtype != r_dtype
164 && (!polars_schema::Schema::<D>::IS_ARROW
165 || unsafe {
166 DataType::from_arrow_dtype(std::mem::transmute::<
168 &<polars_schema::Schema<D> as SchemaNamesAndDtypes>::DataType,
169 &ArrowDataType,
170 >(l_dtype))
171 != DataType::from_arrow_dtype(std::mem::transmute::<
172 &<polars_schema::Schema<D> as SchemaNamesAndDtypes>::DataType,
173 &ArrowDataType,
174 >(r_dtype))
175 })
176 {
177 polars_bail!(
178 SchemaMismatch:
179 "schema dtypes differ at index {} for column {}: {:?} != {:?}",
180 i, l_name, l_dtype, r_dtype
181 )
182 }
183 }
184
185 Ok(())
186}