Skip to main content

polars_sql/
sql_visitors.rs

1//! SQLVisitor helper implementations for traversing SQL AST expressions.
2//!
3//! This module provides visitor implementations used throughout the SQL interface
4//! to analyze and check SQL expressions for various properties.
5
6use std::ops::ControlFlow;
7
8use polars_core::prelude::*;
9use sqlparser::ast::{Expr as SQLExpr, ObjectName, Query, SetExpr, Visit, Visitor as SQLVisitor};
10use sqlparser::keywords::ALL_KEYWORDS;
11
12// ---------------------------------------------------------------------------
13// FindTableIdentifier
14// ---------------------------------------------------------------------------
15
/// Visitor that searches an expression tree for a compound identifier
/// qualified by a particular table name.
pub(crate) struct FindTableIdentifier<'a> {
    /// Table name compared against the leading part of compound identifiers.
    table_name: &'a str,
    /// Becomes `true` once a matching compound identifier has been visited.
    found: bool,
}

impl<'a> FindTableIdentifier<'a> {
    /// Build a finder for `table_name` that has not matched anything yet.
    fn new(table_name: &'a str) -> Self {
        let found = false;
        Self { table_name, found }
    }
}
30
31impl<'a> SQLVisitor for FindTableIdentifier<'a> {
32    type Break = ();
33
34    fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<Self::Break> {
35        if let SQLExpr::CompoundIdentifier(idents) = expr {
36            if idents.len() >= 2 && idents[0].value.as_str() == self.table_name {
37                self.found = true; // return immediately on first match
38                return ControlFlow::Break(());
39            }
40        }
41        ControlFlow::Continue(())
42    }
43}
44
45/// Check if a SQL expression contains a reference to a specific table.
46pub(crate) fn expr_refers_to_table(expr: &SQLExpr, table_name: &str) -> bool {
47    let mut table_finder = FindTableIdentifier::new(table_name);
48    let _ = expr.visit(&mut table_finder);
49    table_finder.found
50}
51
52// ---------------------------------------------------------------------------
53// QualifyExpression
54// ---------------------------------------------------------------------------
55
56/// Visitor used to check a SQL expression used in a QUALIFY clause.
57/// (Confirms window functions are present and collects column refs in one pass).
58pub(crate) struct QualifyExpression {
59    has_window_functions: bool,
60    column_refs: PlHashSet<String>,
61}
62
63impl QualifyExpression {
64    fn new() -> Self {
65        Self {
66            has_window_functions: false,
67            column_refs: PlHashSet::new(),
68        }
69    }
70
71    pub(crate) fn analyze(expr: &SQLExpr) -> (bool, PlHashSet<String>) {
72        let mut analyzer = Self::new();
73        let _ = expr.visit(&mut analyzer);
74        (analyzer.has_window_functions, analyzer.column_refs)
75    }
76}
77
78impl SQLVisitor for QualifyExpression {
79    type Break = ();
80
81    fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<Self::Break> {
82        match expr {
83            SQLExpr::Function(func) if func.over.is_some() => {
84                self.has_window_functions = true;
85            },
86            SQLExpr::Identifier(ident) => {
87                self.column_refs.insert(ident.value.clone());
88            },
89            SQLExpr::CompoundIdentifier(idents) if !idents.is_empty() => {
90                self.column_refs
91                    .insert(idents.last().unwrap().value.clone());
92            },
93            _ => {},
94        }
95        ControlFlow::Continue(())
96    }
97}
98
99// ---------------------------------------------------------------------------
100// AmbiguousColumnVisitor
101// ---------------------------------------------------------------------------
102
103/// Format an identifier, quoting only if necessary (or `force` is true).
104fn maybe_quote(s: &str, force: bool) -> String {
105    let needs_quoting = force
106        || s.is_empty()
107        || s.starts_with(|c: char| c.is_ascii_digit())
108        || !s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
109        || ALL_KEYWORDS.contains(&s.to_ascii_uppercase().as_str());
110    if needs_quoting {
111        format!("\"{s}\"")
112    } else {
113        s.to_string()
114    }
115}
116
117/// Visitor that checks for unqualified references to columns that exist in
118/// multiple tables (columns appearing in a USING clause are excluded from
119/// the check as they are implicitly coalesced).
120struct AmbiguousColumnVisitor<'a> {
121    joined_aliases: &'a PlHashMap<String, PlHashMap<String, String>>,
122    base_table_name: &'a str,
123    using_cols: &'a PlHashSet<String>,
124}
125
126impl SQLVisitor for AmbiguousColumnVisitor<'_> {
127    type Break = PolarsError;
128
129    fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<Self::Break> {
130        if let SQLExpr::Identifier(ident) = expr {
131            let col = &ident.value;
132            if self.using_cols.contains(col) {
133                return ControlFlow::Continue(());
134            }
135            let mut tables: Vec<_> = self
136                .joined_aliases
137                .iter()
138                .filter_map(|(t, cols)| cols.contains_key(col).then_some(t.as_str()))
139                .collect();
140
141            if !tables.is_empty() {
142                tables.push(self.base_table_name);
143                tables.sort();
144                let col_hint = maybe_quote(col, false);
145                let hints = tables
146                    .iter()
147                    .map(|t| format!("{}.{}", maybe_quote(t, false), col_hint));
148                return ControlFlow::Break(polars_err!(
149                    SQLInterface: "ambiguous reference to column {} (use one of: {})",
150                    maybe_quote(col, true), hints.collect::<Vec<_>>().join(", ")
151                ));
152            }
153        }
154        ControlFlow::Continue(())
155    }
156}
157
158/// Check a SQL expression for unqualified references to columns that
159/// exist in multiple tables (columns appearing in a USING clause are
160/// excluded from the check as they are implicitly coalesced).
161pub(crate) fn check_for_ambiguous_column_refs(
162    expr: &SQLExpr,
163    joined_aliases: &PlHashMap<String, PlHashMap<String, String>>,
164    base_table_name: &str,
165    using_cols: &PlHashSet<String>,
166) -> PolarsResult<()> {
167    match expr.visit(&mut AmbiguousColumnVisitor {
168        joined_aliases,
169        base_table_name,
170        using_cols,
171    }) {
172        ControlFlow::Break(err) => Err(err),
173        ControlFlow::Continue(()) => Ok(()),
174    }
175}
176
177// ---------------------------------------------------------------------------
178// TableIdentifierCollector
179// ---------------------------------------------------------------------------
180
181/// Visitor that collects all table identifiers referenced in a SQL query.
182#[derive(Default)]
183pub(crate) struct TableIdentifierCollector {
184    pub(crate) tables: Vec<String>,
185    pub(crate) include_schema: bool,
186}
187
188impl TableIdentifierCollector {
189    pub(crate) fn collect_from_set_expr(&mut self, set_expr: &SetExpr) {
190        // Recursively collect table identifiers from SetExpr nodes
191        match set_expr {
192            SetExpr::Table(tbl) => {
193                self.tables.extend(if self.include_schema {
194                    match (&tbl.schema_name, &tbl.table_name) {
195                        (Some(schema), Some(table)) => Some(format!("{schema}.{table}")),
196                        (None, Some(table)) => Some(table.clone()),
197                        _ => None,
198                    }
199                } else {
200                    tbl.table_name.clone()
201                });
202            },
203            SetExpr::SetOperation { left, right, .. } => {
204                self.collect_from_set_expr(left);
205                self.collect_from_set_expr(right);
206            },
207            SetExpr::Query(query) => self.collect_from_set_expr(&query.body),
208            _ => {},
209        }
210    }
211}
212
213impl SQLVisitor for TableIdentifierCollector {
214    type Break = ();
215
216    fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
217        // Collect from SetExpr nodes in the query body
218        self.collect_from_set_expr(&query.body);
219        ControlFlow::Continue(())
220    }
221
222    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
223        // Table relation (eg: appearing in FROM clause)
224        self.tables.extend(if self.include_schema {
225            let parts: Vec<_> = relation
226                .0
227                .iter()
228                .filter_map(|p| p.as_ident().map(|i| i.value.as_str()))
229                .collect();
230            (!parts.is_empty()).then(|| parts.join("."))
231        } else {
232            relation
233                .0
234                .last()
235                .and_then(|p| p.as_ident())
236                .map(|i| i.value.clone())
237        });
238        ControlFlow::Continue(())
239    }
240}
241
242// ---------------------------------------------------------------------------
243// WindowFunctionFinder
244// ---------------------------------------------------------------------------
245
246/// Visitor that checks if a SQL expression contains explicit window functions.
247/// Uses early-exit for efficiency when only the boolean result is needed.
248struct WindowFunctionFinder;
249
250impl SQLVisitor for WindowFunctionFinder {
251    type Break = ();
252
253    fn pre_visit_expr(&mut self, expr: &SQLExpr) -> ControlFlow<()> {
254        if matches!(expr, SQLExpr::Function(f) if f.over.is_some()) {
255            ControlFlow::Break(())
256        } else {
257            ControlFlow::Continue(())
258        }
259    }
260}
261
262/// Check if a SQL expression contains explicit window functions.
263pub(crate) fn expr_has_window_functions(expr: &SQLExpr) -> bool {
264    expr.visit(&mut WindowFunctionFinder).is_break()
265}