// polars_sql/sql_visitors.rs
use std::ops::ControlFlow;

use polars_core::prelude::*;
use sqlparser::ast::{Expr as SQLExpr, ObjectName, Query, SetExpr, Visit, Visitor as SQLVisitor};
use sqlparser::keywords::ALL_KEYWORDS;
11
12pub(crate) struct FindTableIdentifier<'a> {
18 table_name: &'a str,
19 found: bool,
20}
21
22impl<'a> FindTableIdentifier<'a> {
23 fn new(table_name: &'a str) -> Self {
24 Self {
25 table_name,
26 found: false,
27 }
28 }
29}
30
31impl<'a> SQLVisitor for FindTableIdentifier<'a> {
32 type Break = ();
33
34 fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<Self::Break> {
35 if let SQLExpr::CompoundIdentifier(idents) = expr {
36 if idents.len() >= 2 && idents[0].value.as_str() == self.table_name {
37 self.found = true; return ControlFlow::Break(());
39 }
40 }
41 ControlFlow::Continue(())
42 }
43}
44
45pub(crate) fn expr_refers_to_table(expr: &SQLExpr, table_name: &str) -> bool {
47 let mut table_finder = FindTableIdentifier::new(table_name);
48 let _ = expr.visit(&mut table_finder);
49 table_finder.found
50}
51
52pub(crate) struct QualifyExpression {
59 has_window_functions: bool,
60 column_refs: PlHashSet<String>,
61}
62
63impl QualifyExpression {
64 fn new() -> Self {
65 Self {
66 has_window_functions: false,
67 column_refs: PlHashSet::new(),
68 }
69 }
70
71 pub(crate) fn analyze(expr: &SQLExpr) -> (bool, PlHashSet<String>) {
72 let mut analyzer = Self::new();
73 let _ = expr.visit(&mut analyzer);
74 (analyzer.has_window_functions, analyzer.column_refs)
75 }
76}
77
78impl SQLVisitor for QualifyExpression {
79 type Break = ();
80
81 fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<Self::Break> {
82 match expr {
83 SQLExpr::Function(func) if func.over.is_some() => {
84 self.has_window_functions = true;
85 },
86 SQLExpr::Identifier(ident) => {
87 self.column_refs.insert(ident.value.clone());
88 },
89 SQLExpr::CompoundIdentifier(idents) if !idents.is_empty() => {
90 self.column_refs
91 .insert(idents.last().unwrap().value.clone());
92 },
93 _ => {},
94 }
95 ControlFlow::Continue(())
96 }
97}
98
99fn maybe_quote(s: &str, force: bool) -> String {
105 let needs_quoting = force
106 || s.is_empty()
107 || s.starts_with(|c: char| c.is_ascii_digit())
108 || !s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
109 || ALL_KEYWORDS.contains(&s.to_ascii_uppercase().as_str());
110 if needs_quoting {
111 format!("\"{s}\"")
112 } else {
113 s.to_string()
114 }
115}
116
117struct AmbiguousColumnVisitor<'a> {
121 joined_aliases: &'a PlHashMap<String, PlHashMap<String, String>>,
122 base_table_name: &'a str,
123 using_cols: &'a PlHashSet<String>,
124}
125
126impl SQLVisitor for AmbiguousColumnVisitor<'_> {
127 type Break = PolarsError;
128
129 fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<Self::Break> {
130 if let SQLExpr::Identifier(ident) = expr {
131 let col = &ident.value;
132 if self.using_cols.contains(col) {
133 return ControlFlow::Continue(());
134 }
135 let mut tables: Vec<_> = self
136 .joined_aliases
137 .iter()
138 .filter_map(|(t, cols)| cols.contains_key(col).then_some(t.as_str()))
139 .collect();
140
141 if !tables.is_empty() {
142 tables.push(self.base_table_name);
143 tables.sort();
144 let col_hint = maybe_quote(col, false);
145 let hints = tables
146 .iter()
147 .map(|t| format!("{}.{}", maybe_quote(t, false), col_hint));
148 return ControlFlow::Break(polars_err!(
149 SQLInterface: "ambiguous reference to column {} (use one of: {})",
150 maybe_quote(col, true), hints.collect::<Vec<_>>().join(", ")
151 ));
152 }
153 }
154 ControlFlow::Continue(())
155 }
156}
157
158pub(crate) fn check_for_ambiguous_column_refs(
162 expr: &SQLExpr,
163 joined_aliases: &PlHashMap<String, PlHashMap<String, String>>,
164 base_table_name: &str,
165 using_cols: &PlHashSet<String>,
166) -> PolarsResult<()> {
167 match expr.visit(&mut AmbiguousColumnVisitor {
168 joined_aliases,
169 base_table_name,
170 using_cols,
171 }) {
172 ControlFlow::Break(err) => Err(err),
173 ControlFlow::Continue(()) => Ok(()),
174 }
175}
176
177#[derive(Default)]
183pub(crate) struct TableIdentifierCollector {
184 pub(crate) tables: Vec<String>,
185 pub(crate) include_schema: bool,
186}
187
188impl TableIdentifierCollector {
189 pub(crate) fn collect_from_set_expr(&mut self, set_expr: &SetExpr) {
190 match set_expr {
192 SetExpr::Table(tbl) => {
193 self.tables.extend(if self.include_schema {
194 match (&tbl.schema_name, &tbl.table_name) {
195 (Some(schema), Some(table)) => Some(format!("{schema}.{table}")),
196 (None, Some(table)) => Some(table.clone()),
197 _ => None,
198 }
199 } else {
200 tbl.table_name.clone()
201 });
202 },
203 SetExpr::SetOperation { left, right, .. } => {
204 self.collect_from_set_expr(left);
205 self.collect_from_set_expr(right);
206 },
207 SetExpr::Query(query) => self.collect_from_set_expr(&query.body),
208 _ => {},
209 }
210 }
211}
212
213impl SQLVisitor for TableIdentifierCollector {
214 type Break = ();
215
216 fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
217 self.collect_from_set_expr(&query.body);
219 ControlFlow::Continue(())
220 }
221
222 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
223 self.tables.extend(if self.include_schema {
225 let parts: Vec<_> = relation
226 .0
227 .iter()
228 .filter_map(|p| p.as_ident().map(|i| i.value.as_str()))
229 .collect();
230 (!parts.is_empty()).then(|| parts.join("."))
231 } else {
232 relation
233 .0
234 .last()
235 .and_then(|p| p.as_ident())
236 .map(|i| i.value.clone())
237 });
238 ControlFlow::Continue(())
239 }
240}
241
242struct WindowFunctionFinder;
249
250impl SQLVisitor for WindowFunctionFinder {
251 type Break = ();
252
253 fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<()> {
254 if matches!(expr, SQLExpr::Function(f) if f.over.is_some()) {
255 ControlFlow::Break(())
256 } else {
257 ControlFlow::Continue(())
258 }
259 }
260}
261
262pub(crate) fn expr_has_window_functions(expr: &SQLExpr) -> bool {
264 expr.visit(&mut WindowFunctionFinder).is_break()
265}