polars_sql/
sql_expr.rs

1//! Expressions that are supported by the Polars SQL interface.
2//!
3//! This is useful for syntax highlighting
4//!
5//! This module defines:
6//! - all Polars SQL keywords [`all_keywords`]
//! - all Polars SQL functions [`all_functions`]
8
9use std::fmt::Display;
10use std::ops::Div;
11
12use polars_core::prelude::*;
13use polars_lazy::prelude::*;
14use polars_plan::plans::DynLiteralValue;
15use polars_plan::prelude::typed_lit;
16use polars_time::Duration;
17use rand::Rng;
18use rand::distr::Alphanumeric;
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21use sqlparser::ast::{
22    AccessExpr, BinaryOperator as SQLBinaryOperator, CastFormat, CastKind, DataType as SQLDataType,
23    DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, Interval, Query as Subquery,
24    SelectItem, Subscript, TimezoneInfo, TrimWhereField, TypedString, UnaryOperator,
25    Value as SQLValue, ValueWithSpan,
26};
27use sqlparser::dialect::GenericDialect;
28use sqlparser::parser::{Parser, ParserOptions};
29
30use crate::SQLContext;
31use crate::functions::SQLFunctionVisitor;
32use crate::types::{
33    bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
34};
35
36#[inline]
37#[cold]
38#[must_use]
39/// Convert a Display-able error to PolarsError::SQLInterface
40pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
41    PolarsError::SQLInterface(err.to_string().into())
42}
43
/// Categorises the type of (allowed) subquery constraint.
///
/// Only the single-column restriction exists today; further restrictions that
/// have been considered (but are not implemented) are `SingleRow`,
/// `SingleValue` and `Any`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SubqueryRestriction {
    /// Subquery must return a single column
    SingleColumn,
}
54
/// Recursively walks a SQL Expr to create a polars Expr
///
/// The visitor borrows the [`SQLContext`] mutably for the duration of the walk
/// (subqueries are executed through it) and can optionally carry a schema that
/// is consulted when a column's dtype influences the translation.
pub(crate) struct SQLExprVisitor<'a> {
    /// SQL context used to execute subqueries and resolve compound identifiers.
    ctx: &'a mut SQLContext,
    /// Optional schema used to look up column dtypes (eg: for implicit temporal
    /// string coercion and boolean-aware `NOT` handling).
    active_schema: Option<&'a Schema>,
}
60
61impl SQLExprVisitor<'_> {
62    fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
63        let mut array_elements = Vec::with_capacity(elements.len());
64        for e in elements {
65            let val = match e {
66                SQLExpr::Value(ValueWithSpan { value: v, .. }) => self.visit_any_value(v, None),
67                SQLExpr::UnaryOp { op, expr } => match expr.as_ref() {
68                    SQLExpr::Value(ValueWithSpan { value: v, .. }) => {
69                        self.visit_any_value(v, Some(op))
70                    },
71                    _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
72                },
73                SQLExpr::Array(values) => {
74                    let srs = self.array_expr_to_series(&values.elem)?;
75                    Ok(AnyValue::List(srs))
76                },
77                _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
78            }?
79            .into_static();
80            array_elements.push(val);
81        }
82        Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true)
83    }
84
85    fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
86        match expr {
87            SQLExpr::AllOp {
88                left,
89                compare_op,
90                right,
91            } => self.visit_all(left, compare_op, right),
92            SQLExpr::AnyOp {
93                left,
94                compare_op,
95                right,
96                is_some: _,
97            } => self.visit_any(left, compare_op, right),
98            SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
99            SQLExpr::Between {
100                expr,
101                negated,
102                low,
103                high,
104            } => self.visit_between(expr, *negated, low, high),
105            SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
106            SQLExpr::Cast {
107                kind,
108                expr,
109                data_type,
110                format,
111            } => self.visit_cast(expr, data_type, format, kind),
112            SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
113            SQLExpr::CompoundFieldAccess { root, access_chain } => {
114                // simple subscript access (eg: "array_col[1]")
115                if access_chain.len() == 1 {
116                    match &access_chain[0] {
117                        AccessExpr::Subscript(subscript) => {
118                            return self.visit_subscript(root, subscript);
119                        },
120                        AccessExpr::Dot(_) => {
121                            polars_bail!(SQLSyntax: "dot-notation field access is currently unsupported: {:?}", access_chain[0])
122                        },
123                    }
124                }
125                // chained dot/bracket notation (eg: "struct_col.field[2].foo[0].bar")
126                polars_bail!(SQLSyntax: "complex field access chains are currently unsupported: {:?}", access_chain[0])
127            },
128            SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
129            SQLExpr::Extract {
130                field,
131                syntax: _,
132                expr,
133            } => parse_extract_date_part(self.visit_expr(expr)?, field),
134            SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
135            SQLExpr::Function(function) => self.visit_function(function),
136            SQLExpr::Identifier(ident) => self.visit_identifier(ident),
137            SQLExpr::InList {
138                expr,
139                list,
140                negated,
141            } => {
142                let expr = self.visit_expr(expr)?;
143                let elems = self.visit_array_expr(list, true, Some(&expr))?;
144                let is_in = expr.is_in(elems, false);
145                Ok(if *negated { is_in.not() } else { is_in })
146            },
147            SQLExpr::InSubquery {
148                expr,
149                subquery,
150                negated,
151            } => self.visit_in_subquery(expr, subquery, *negated),
152            SQLExpr::Interval(interval) => Ok(lit(interval_to_duration(interval, true)?)),
153            SQLExpr::IsDistinctFrom(e1, e2) => {
154                Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
155            },
156            SQLExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
157            SQLExpr::IsNotDistinctFrom(e1, e2) => {
158                Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
159            },
160            SQLExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
161            SQLExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
162            SQLExpr::IsNotTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true)).not()),
163            SQLExpr::IsNull(expr) => Ok(self.visit_expr(expr)?.is_null()),
164            SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))),
165            SQLExpr::Like {
166                negated,
167                any,
168                expr,
169                pattern,
170                escape_char,
171            } => {
172                if *any {
173                    polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax")
174                }
175                let escape_str = escape_char.as_ref().and_then(|v| match v {
176                    SQLValue::SingleQuotedString(s) => Some(s.clone()),
177                    _ => None,
178                });
179                self.visit_like(*negated, expr, pattern, &escape_str, false)
180            },
181            SQLExpr::ILike {
182                negated,
183                any,
184                expr,
185                pattern,
186                escape_char,
187            } => {
188                if *any {
189                    polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax")
190                }
191                let escape_str = escape_char.as_ref().and_then(|v| match v {
192                    SQLValue::SingleQuotedString(s) => Some(s.clone()),
193                    _ => None,
194                });
195                self.visit_like(*negated, expr, pattern, &escape_str, true)
196            },
197            SQLExpr::Nested(expr) => self.visit_expr(expr),
198            SQLExpr::Position { expr, r#in } => Ok(
199                // note: SQL is 1-indexed
200                (self
201                    .visit_expr(r#in)?
202                    .str()
203                    .find(self.visit_expr(expr)?, true)
204                    + typed_lit(1u32))
205                .fill_null(typed_lit(0u32)),
206            ),
207            SQLExpr::RLike {
208                // note: parses both RLIKE and REGEXP
209                negated,
210                expr,
211                pattern,
212                regexp: _,
213            } => {
214                let matches = self
215                    .visit_expr(expr)?
216                    .str()
217                    .contains(self.visit_expr(pattern)?, true);
218                Ok(if *negated { matches.not() } else { matches })
219            },
220            SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
221            SQLExpr::Substring {
222                expr,
223                substring_from,
224                substring_for,
225                ..
226            } => self.visit_substring(expr, substring_from.as_deref(), substring_for.as_deref()),
227            SQLExpr::Trim {
228                expr,
229                trim_where,
230                trim_what,
231                trim_characters,
232            } => self.visit_trim(expr, trim_where, trim_what, trim_characters),
233            SQLExpr::TypedString(TypedString {
234                data_type,
235                value:
236                    ValueWithSpan {
237                        value: SQLValue::SingleQuotedString(v),
238                        ..
239                    },
240                uses_odbc_syntax: _,
241            }) => match data_type {
242                SQLDataType::Date => {
243                    if is_iso_date(v) {
244                        Ok(lit(v.as_str()).cast(DataType::Date))
245                    } else {
246                        polars_bail!(SQLSyntax: "invalid DATE literal '{}'", v)
247                    }
248                },
249                SQLDataType::Time(None, TimezoneInfo::None) => {
250                    if is_iso_time(v) {
251                        Ok(lit(v.as_str()).str().to_time(StrptimeOptions {
252                            strict: true,
253                            ..Default::default()
254                        }))
255                    } else {
256                        polars_bail!(SQLSyntax: "invalid TIME literal '{}'", v)
257                    }
258                },
259                SQLDataType::Timestamp(None, TimezoneInfo::None) | SQLDataType::Datetime(None) => {
260                    if is_iso_datetime(v) {
261                        Ok(lit(v.as_str()).str().to_datetime(
262                            None,
263                            None,
264                            StrptimeOptions {
265                                strict: true,
266                                ..Default::default()
267                            },
268                            lit("latest"),
269                        ))
270                    } else {
271                        let fn_name = match data_type {
272                            SQLDataType::Timestamp(_, _) => "TIMESTAMP",
273                            SQLDataType::Datetime(_) => "DATETIME",
274                            _ => unreachable!(),
275                        };
276                        polars_bail!(SQLSyntax: "invalid {} literal '{}'", fn_name, v)
277                    }
278                },
279                _ => {
280                    polars_bail!(SQLInterface: "typed literal should be one of DATE, DATETIME, TIME, or TIMESTAMP (found {})", data_type)
281                },
282            },
283            SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
284            SQLExpr::Value(ValueWithSpan { value, .. }) => self.visit_literal(value),
285            SQLExpr::Wildcard(_) => Ok(all().as_expr()),
286            e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
287            other => {
288                polars_bail!(SQLInterface: "expression {:?} is not currently supported", other)
289            },
290        }
291    }
292
293    fn visit_subquery(
294        &mut self,
295        subquery: &Subquery,
296        restriction: SubqueryRestriction,
297    ) -> PolarsResult<Expr> {
298        if subquery.with.is_some() {
299            polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause");
300        }
301        // note: we have to execute subqueries in an isolated scope to prevent
302        // propagating any context/arena mutation into the rest of the query
303        let (mut lf, schema) = self
304            .ctx
305            .execute_isolated(|ctx| ctx.execute_query_no_ctes(subquery))?;
306
307        if restriction == SubqueryRestriction::SingleColumn {
308            if schema.len() != 1 {
309                polars_bail!(SQLSyntax: "SQL subquery returns more than one column");
310            }
311            let rand_string: String = rand::rng()
312                .sample_iter(&Alphanumeric)
313                .take(16)
314                .map(char::from)
315                .collect();
316
317            let schema_entry = schema.get_at_index(0);
318            if let Some((old_name, _)) = schema_entry {
319                let new_name = String::from(old_name.as_str()) + rand_string.as_str();
320                lf = lf.rename([old_name.to_string()], [new_name.clone()], true);
321                return Ok(Expr::SubPlan(
322                    SpecialEq::new(Arc::new(lf.logical_plan)),
323                    vec![new_name],
324                ));
325            }
326        };
327        polars_bail!(SQLInterface: "subquery type not supported");
328    }
329
330    /// Visit a single SQL identifier.
331    ///
332    /// e.g. column
333    fn visit_identifier(&self, ident: &Ident) -> PolarsResult<Expr> {
334        Ok(col(ident.value.as_str()))
335    }
336
337    /// Visit a compound SQL identifier
338    ///
339    /// e.g. tbl.column, struct.field, tbl.struct.field (inc. nested struct fields)
340    fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult<Expr> {
341        Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone())
342    }
343
344    fn visit_like(
345        &mut self,
346        negated: bool,
347        expr: &SQLExpr,
348        pattern: &SQLExpr,
349        escape_char: &Option<String>,
350        case_insensitive: bool,
351    ) -> PolarsResult<Expr> {
352        if escape_char.is_some() {
353            polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap());
354        }
355        let pat = match self.visit_expr(pattern) {
356            Ok(Expr::Literal(lv)) if lv.extract_str().is_some() => {
357                PlSmallStr::from_str(lv.extract_str().unwrap())
358            },
359            _ => {
360                polars_bail!(SQLSyntax: "LIKE/ILIKE pattern must be a string literal; found {}", pattern)
361            },
362        };
363        if pat.is_empty() || (!case_insensitive && pat.chars().all(|c| !matches!(c, '%' | '_'))) {
364            // empty string or other exact literal match (eg: no wildcard chars)
365            let op = if negated {
366                SQLBinaryOperator::NotEq
367            } else {
368                SQLBinaryOperator::Eq
369            };
370            self.visit_binary_op(expr, &op, pattern)
371        } else {
372            // create regex from pattern containing SQL wildcard chars ('%' => '.*', '_' => '.')
373            let mut rx = regex::escape(pat.as_str())
374                .replace('%', ".*")
375                .replace('_', ".");
376
377            rx = format!(
378                "^{}{}$",
379                if case_insensitive { "(?is)" } else { "(?s)" },
380                rx
381            );
382
383            let expr = self.visit_expr(expr)?;
384            let matches = expr.str().contains(lit(rx), true);
385            Ok(if negated { matches.not() } else { matches })
386        }
387    }
388
389    fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
390        let expr = self.visit_expr(expr)?;
391        Ok(match subscript {
392            Subscript::Index { index } => {
393                let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
394                expr.list().get(idx, true)
395            },
396            Subscript::Slice { .. } => {
397                polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
398            },
399        })
400    }
401
402    /// Handle implicit temporal string comparisons.
403    ///
404    /// eg: clauses such as -
405    ///   "dt >= '2024-04-30'"
406    ///   "dt = '2077-10-10'::date"
407    ///   "dtm::date = '2077-10-10'
408    fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
409        if let (Some(name), Some(s), expr_dtype) = match (left, right) {
410            // identify "col <op> string" expressions
411            (Expr::Column(name), Expr::Literal(lv)) if lv.extract_str().is_some() => {
412                (Some(name.clone()), Some(lv.extract_str().unwrap()), None)
413            },
414            // identify "CAST(expr AS type) <op> string" and/or "expr::type <op> string" expressions
415            (Expr::Cast { expr, dtype, .. }, Expr::Literal(lv)) if lv.extract_str().is_some() => {
416                let s = lv.extract_str().unwrap();
417                match &**expr {
418                    Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
419                    _ => (None, Some(s), Some(dtype)),
420                }
421            },
422            _ => (None, None, None),
423        } {
424            if expr_dtype.is_none() && self.active_schema.is_none() {
425                right.clone()
426            } else {
427                let left_dtype = expr_dtype.map_or_else(
428                    || {
429                        self.active_schema
430                            .as_ref()
431                            .and_then(|schema| schema.get(&name))
432                    },
433                    |dt| dt.as_literal(),
434                );
435                match left_dtype {
436                    Some(DataType::Time) if is_iso_time(s) => {
437                        right.clone().str().to_time(StrptimeOptions {
438                            strict: true,
439                            ..Default::default()
440                        })
441                    },
442                    Some(DataType::Date) if is_iso_date(s) => {
443                        right.clone().str().to_date(StrptimeOptions {
444                            strict: true,
445                            ..Default::default()
446                        })
447                    },
448                    Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
449                        if s.len() == 10 {
450                            // handle upcast from ISO date string (10 chars) to datetime
451                            lit(format!("{s}T00:00:00"))
452                        } else {
453                            lit(s.replacen(' ', "T", 1))
454                        }
455                        .str()
456                        .to_datetime(
457                            Some(*tu),
458                            tz.clone(),
459                            StrptimeOptions {
460                                strict: true,
461                                ..Default::default()
462                            },
463                            lit("latest"),
464                        )
465                    },
466                    _ => right.clone(),
467                }
468            }
469        } else {
470            right.clone()
471        }
472    }
473
474    fn struct_field_access_expr(
475        &mut self,
476        expr: &Expr,
477        path: &str,
478        infer_index: bool,
479    ) -> PolarsResult<Expr> {
480        let path_elems = if path.starts_with('{') && path.ends_with('}') {
481            path.trim_matches(|c| c == '{' || c == '}')
482        } else {
483            path
484        }
485        .split(',');
486
487        let mut expr = expr.clone();
488        for p in path_elems {
489            let p = p.trim();
490            expr = if infer_index {
491                match p.parse::<i64>() {
492                    Ok(idx) => expr.list().get(lit(idx), true),
493                    Err(_) => expr.struct_().field_by_name(p),
494                }
495            } else {
496                expr.struct_().field_by_name(p)
497            }
498        }
499        Ok(expr)
500    }
501
502    /// Visit a SQL binary operator.
503    ///
504    /// e.g. "column + 1", "column1 <= column2"
505    fn visit_binary_op(
506        &mut self,
507        left: &SQLExpr,
508        op: &SQLBinaryOperator,
509        right: &SQLExpr,
510    ) -> PolarsResult<Expr> {
511        // check for (unsupported) scalar subquery comparisons
512        if matches!(left, SQLExpr::Subquery(_)) || matches!(right, SQLExpr::Subquery(_)) {
513            let (suggestion, str_op) = match op {
514                SQLBinaryOperator::NotEq => ("; use 'NOT IN' instead", "!=".to_string()),
515                SQLBinaryOperator::Eq => ("; use 'IN' instead", format!("{op}")),
516                _ => ("", format!("{op}")),
517            };
518            polars_bail!(
519                SQLSyntax: "subquery comparisons with '{str_op}' are not supported{suggestion}"
520            );
521        }
522
523        // need special handling for interval offsets and comparisons
524        let (lhs, mut rhs) = match (left, op, right) {
525            (_, SQLBinaryOperator::Minus, SQLExpr::Interval(v)) => {
526                let duration = interval_to_duration(v, false)?;
527                return Ok(self
528                    .visit_expr(left)?
529                    .dt()
530                    .offset_by(lit(format!("-{duration}"))));
531            },
532            (_, SQLBinaryOperator::Plus, SQLExpr::Interval(v)) => {
533                let duration = interval_to_duration(v, false)?;
534                return Ok(self
535                    .visit_expr(left)?
536                    .dt()
537                    .offset_by(lit(format!("{duration}"))));
538            },
539            (SQLExpr::Interval(v1), _, SQLExpr::Interval(v2)) => {
540                // shortcut interval comparison evaluation (-> bool)
541                let d1 = interval_to_duration(v1, false)?;
542                let d2 = interval_to_duration(v2, false)?;
543                let res = match op {
544                    SQLBinaryOperator::Gt => Ok(lit(d1 > d2)),
545                    SQLBinaryOperator::Lt => Ok(lit(d1 < d2)),
546                    SQLBinaryOperator::GtEq => Ok(lit(d1 >= d2)),
547                    SQLBinaryOperator::LtEq => Ok(lit(d1 <= d2)),
548                    SQLBinaryOperator::NotEq => Ok(lit(d1 != d2)),
549                    SQLBinaryOperator::Eq | SQLBinaryOperator::Spaceship => Ok(lit(d1 == d2)),
550                    _ => polars_bail!(SQLInterface: "invalid interval comparison operator"),
551                };
552                if res.is_ok() {
553                    return res;
554                }
555                (self.visit_expr(left)?, self.visit_expr(right)?)
556            },
557            _ => (self.visit_expr(left)?, self.visit_expr(right)?),
558        };
559        rhs = self.convert_temporal_strings(&lhs, &rhs);
560
561        Ok(match op {
562            // ----
563            // Bitwise operators
564            // ----
565            SQLBinaryOperator::BitwiseAnd => lhs.and(rhs),  // "x & y"
566            SQLBinaryOperator::BitwiseOr => lhs.or(rhs),  // "x | y"
567            SQLBinaryOperator::Xor => lhs.xor(rhs),  // "x XOR y"
568
569            // ----
570            // General operators
571            // ----
572            SQLBinaryOperator::And => lhs.and(rhs),  // "x AND y"
573            SQLBinaryOperator::Divide => lhs / rhs,  // "x / y"
574            SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64),  // "x // y"
575            SQLBinaryOperator::Eq => lhs.eq(rhs),  // "x = y"
576            SQLBinaryOperator::Gt => lhs.gt(rhs),  // "x > y"
577            SQLBinaryOperator::GtEq => lhs.gt_eq(rhs),  // "x >= y"
578            SQLBinaryOperator::Lt => lhs.lt(rhs),  // "x < y"
579            SQLBinaryOperator::LtEq => lhs.lt_eq(rhs),  // "x <= y"
580            SQLBinaryOperator::Minus => lhs - rhs,  // "x - y"
581            SQLBinaryOperator::Modulo => lhs % rhs,  // "x % y"
582            SQLBinaryOperator::Multiply => lhs * rhs,  // "x * y"
583            SQLBinaryOperator::NotEq => lhs.eq(rhs).not(),  // "x != y"
584            SQLBinaryOperator::Or => lhs.or(rhs),  // "x OR y"
585            SQLBinaryOperator::Plus => lhs + rhs,  // "x + y"
586            SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs),  // "x <=> y"
587            SQLBinaryOperator::StringConcat => {  // "x || y"
588                lhs.cast(DataType::String) + rhs.cast(DataType::String)
589            },
590            SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs),  // "x ^@ y"
591            // ----
592            // Regular expression operators
593            // ----
594            SQLBinaryOperator::PGRegexMatch => match rhs {  // "x ~ y"
595                Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true),
596                _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
597            },
598            SQLBinaryOperator::PGRegexNotMatch => match rhs {  // "x !~ y"
599                Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true).not(),
600                _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
601            },
602            SQLBinaryOperator::PGRegexIMatch => match rhs {  // "x ~* y"
603                Expr::Literal(ref lv) if lv.extract_str().is_some() => {
604                    let pat = lv.extract_str().unwrap();
605                    lhs.str().contains(lit(format!("(?i){pat}")), true)
606                },
607                _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
608            },
609            SQLBinaryOperator::PGRegexNotIMatch => match rhs {  // "x !~* y"
610                Expr::Literal(ref lv) if lv.extract_str().is_some() => {
611                    let pat = lv.extract_str().unwrap();
612                    lhs.str().contains(lit(format!("(?i){pat}")), true).not()
613                },
614                _ => {
615                    polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
616                },
617            },
618            // ----
619            // LIKE/ILIKE operators
620            // ----
621            SQLBinaryOperator::PGLikeMatch  // "x ~~ y"
622            | SQLBinaryOperator::PGNotLikeMatch  // "x !~~ y"
623            | SQLBinaryOperator::PGILikeMatch  // "x ~~* y"
624            | SQLBinaryOperator::PGNotILikeMatch => {  // "x !~~* y"
625                let expr = if matches!(
626                    op,
627                    SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
628                ) {
629                    SQLExpr::Like {
630                        negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
631                        any: false,
632                        expr: Box::new(left.clone()),
633                        pattern: Box::new(right.clone()),
634                        escape_char: None,
635                    }
636                } else {
637                    SQLExpr::ILike {
638                        negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
639                        any: false,
640                        expr: Box::new(left.clone()),
641                        pattern: Box::new(right.clone()),
642                        escape_char: None,
643                    }
644                };
645                self.visit_expr(&expr)?
646            },
647            // ----
648            // JSON/Struct field access operators
649            // ----
650            SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs {  // "x -> y", "x ->> y"
651                Expr::Literal(lv) if lv.extract_str().is_some() => {
652                    let path = lv.extract_str().unwrap();
653                    let mut expr = self.struct_field_access_expr(&lhs, path, false)?;
654                    if let SQLBinaryOperator::LongArrow = op {
655                        expr = expr.cast(DataType::String);
656                    }
657                    expr
658                },
659                Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(idx))) => {
660                    let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
661                    if let SQLBinaryOperator::LongArrow = op {
662                        expr = expr.cast(DataType::String);
663                    }
664                    expr
665                },
666                _ => {
667                    polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
668                },
669            },
670            SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => {  // "x #> y", "x #>> y"
671                match rhs {
672                    Expr::Literal(lv) if lv.extract_str().is_some() => {
673                        let path = lv.extract_str().unwrap();
674                        let mut expr = self.struct_field_access_expr(&lhs, path, true)?;
675                        if let SQLBinaryOperator::HashLongArrow = op {
676                            expr = expr.cast(DataType::String);
677                        }
678                        expr
679                    },
680                    _ => {
681                        polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
682                    }
683                }
684            },
685            other => {
686                polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
687            },
688        })
689    }
690
691    /// Visit a SQL unary operator.
692    ///
693    /// e.g. +column or -column
694    fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
695        let expr = self.visit_expr(expr)?;
696        Ok(match (op, expr.clone()) {
697            // simplify the parse tree by special-casing common unary +/- ops
698            (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
699                lit(n)
700            },
701            (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
702                lit(n)
703            },
704            (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
705                lit(-n)
706            },
707            (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
708                lit(-n)
709            },
710            // general case
711            (UnaryOperator::Plus, _) => lit(0) + expr,
712            (UnaryOperator::Minus, _) => lit(0) - expr,
713            (UnaryOperator::Not, _) => match &expr {
714                Expr::Column(name)
715                    if self
716                        .active_schema
717                        .and_then(|schema| schema.get(name))
718                        .is_some_and(|dtype| matches!(dtype, DataType::Boolean)) =>
719                {
720                    // if already boolean, can operate bitwise
721                    expr.not()
722                },
723                // otherwise SQL "NOT" expects logical, not bitwise, behaviour (eg: on integers)
724                _ => expr.strict_cast(DataType::Boolean).not(),
725            },
726            other => polars_bail!(SQLInterface: "unary operator {:?} is not supported", other),
727        })
728    }
729
730    /// Visit a SQL function.
731    ///
732    /// e.g. SUM(column) or COUNT(*)
733    ///
734    /// See [SQLFunctionVisitor] for more details
735    fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult<Expr> {
736        let mut visitor = SQLFunctionVisitor {
737            func: function,
738            ctx: self.ctx,
739            active_schema: self.active_schema,
740        };
741        visitor.visit_function()
742    }
743
744    /// Visit a SQL `ALL` expression.
745    ///
746    /// e.g. `a > ALL(y)`
747    fn visit_all(
748        &mut self,
749        left: &SQLExpr,
750        compare_op: &SQLBinaryOperator,
751        right: &SQLExpr,
752    ) -> PolarsResult<Expr> {
753        let left = self.visit_expr(left)?;
754        let right = self.visit_expr(right)?;
755
756        match compare_op {
757            SQLBinaryOperator::Gt => Ok(left.gt(right.max())),
758            SQLBinaryOperator::Lt => Ok(left.lt(right.min())),
759            SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.max())),
760            SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.min())),
761            SQLBinaryOperator::Eq => polars_bail!(SQLSyntax: "ALL cannot be used with ="),
762            SQLBinaryOperator::NotEq => polars_bail!(SQLSyntax: "ALL cannot be used with !="),
763            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
764        }
765    }
766
767    /// Visit a SQL `ANY` expression.
768    ///
769    /// e.g. `a != ANY(y)`
770    fn visit_any(
771        &mut self,
772        left: &SQLExpr,
773        compare_op: &SQLBinaryOperator,
774        right: &SQLExpr,
775    ) -> PolarsResult<Expr> {
776        let left = self.visit_expr(left)?;
777        let right = self.visit_expr(right)?;
778
779        match compare_op {
780            SQLBinaryOperator::Gt => Ok(left.gt(right.min())),
781            SQLBinaryOperator::Lt => Ok(left.lt(right.max())),
782            SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.min())),
783            SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.max())),
784            SQLBinaryOperator::Eq => Ok(left.is_in(right, false)),
785            SQLBinaryOperator::NotEq => Ok(left.is_in(right, false).not()),
786            _ => polars_bail!(SQLInterface: "invalid comparison operator"),
787        }
788    }
789
790    /// Visit a SQL `ARRAY` list (including `IN` values).
791    fn visit_array_expr(
792        &mut self,
793        elements: &[SQLExpr],
794        result_as_element: bool,
795        dtype_expr_match: Option<&Expr>,
796    ) -> PolarsResult<Expr> {
797        let mut elems = self.array_expr_to_series(elements)?;
798
799        // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
800        // (not yet as versatile as the temporal string conversions in visit_binary_op)
801        if let (Some(Expr::Column(name)), Some(schema)) =
802            (dtype_expr_match, self.active_schema.as_ref())
803        {
804            if elems.dtype() == &DataType::String {
805                if let Some(dtype) = schema.get(name) {
806                    if matches!(
807                        dtype,
808                        DataType::Date | DataType::Time | DataType::Datetime(_, _)
809                    ) {
810                        elems = elems.strict_cast(dtype)?;
811                    }
812                }
813            }
814        }
815
816        // if we are parsing the list as an element in a series, implode.
817        // otherwise, return the series as-is.
818        let res = if result_as_element {
819            elems.implode()?.into_series()
820        } else {
821            elems
822        };
823        Ok(lit(res))
824    }
825
826    /// Visit a SQL `CAST` or `TRY_CAST` expression.
827    ///
828    /// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`,
829    fn visit_cast(
830        &mut self,
831        expr: &SQLExpr,
832        dtype: &SQLDataType,
833        format: &Option<CastFormat>,
834        cast_kind: &CastKind,
835    ) -> PolarsResult<Expr> {
836        if format.is_some() {
837            return Err(
838                polars_err!(SQLInterface: "use of FORMAT is not currently supported in CAST"),
839            );
840        }
841        let expr = self.visit_expr(expr)?;
842
843        #[cfg(feature = "json")]
844        if dtype == &SQLDataType::JSON {
845            // @BROKEN: we cannot handle this.
846            return Ok(expr.str().json_decode(DataType::Struct(Vec::new())));
847        }
848        let polars_type = map_sql_dtype_to_polars(dtype)?;
849        Ok(match cast_kind {
850            CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
851            CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
852        })
853    }
854
855    /// Visit a SQL literal.
856    ///
857    /// e.g. 1, 'foo', 1.0, NULL
858    ///
859    /// See [SQLValue] and [LiteralValue] for more details
860    fn visit_literal(&self, value: &SQLValue) -> PolarsResult<Expr> {
861        // note: double-quoted strings will be parsed as identifiers, not literals
862        Ok(match value {
863            SQLValue::Boolean(b) => lit(*b),
864            SQLValue::DollarQuotedString(s) => lit(s.value.clone()),
865            #[cfg(feature = "binary_encoding")]
866            SQLValue::HexStringLiteral(x) => {
867                if x.len() % 2 != 0 {
868                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
869                };
870                lit(hex::decode(x.clone()).unwrap())
871            },
872            SQLValue::Null => Expr::Literal(LiteralValue::untyped_null()),
873            SQLValue::Number(s, _) => {
874                // Check for existence of decimal separator dot
875                if s.contains('.') {
876                    s.parse::<f64>().map(lit).map_err(|_| ())
877                } else {
878                    s.parse::<i64>().map(lit).map_err(|_| ())
879                }
880                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
881            },
882            SQLValue::SingleQuotedByteStringLiteral(b) => {
883                // note: for PostgreSQL this represents a BIT string literal (eg: b'10101') not a BYTE string
884                // literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser-rs
885                // patterned the token name after BigQuery (where b'str' really IS a byte string)
886                bitstring_to_bytes_literal(b)?
887            },
888            SQLValue::SingleQuotedString(s) => lit(s.clone()),
889            other => {
890                polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other)
891            },
892        })
893    }
894
895    /// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr.
896    fn visit_any_value(
897        &self,
898        value: &SQLValue,
899        op: Option<&UnaryOperator>,
900    ) -> PolarsResult<AnyValue<'_>> {
901        Ok(match value {
902            SQLValue::Boolean(b) => AnyValue::Boolean(*b),
903            SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()),
904            #[cfg(feature = "binary_encoding")]
905            SQLValue::HexStringLiteral(x) => {
906                if x.len() % 2 != 0 {
907                    polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
908                };
909                AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap())
910            },
911            SQLValue::Null => AnyValue::Null,
912            SQLValue::Number(s, _) => {
913                let negate = match op {
914                    Some(UnaryOperator::Minus) => true,
915                    // no op should be taken as plus.
916                    Some(UnaryOperator::Plus) | None => false,
917                    Some(op) => {
918                        polars_bail!(SQLInterface: "unary op {:?} not supported for numeric SQL value", op)
919                    },
920                };
921                // Check for existence of decimal separator dot
922                if s.contains('.') {
923                    s.parse::<f64>()
924                        .map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
925                        .map_err(|_| ())
926                } else {
927                    s.parse::<i64>()
928                        .map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
929                        .map_err(|_| ())
930                }
931                .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
932            },
933            SQLValue::SingleQuotedByteStringLiteral(b) => {
934                // note: for PostgreSQL this represents a BIT literal (eg: b'10101') not BYTE
935                let bytes_literal = bitstring_to_bytes_literal(b)?;
936                match bytes_literal {
937                    Expr::Literal(lv) if lv.extract_binary().is_some() => {
938                        AnyValue::BinaryOwned(lv.extract_binary().unwrap().to_vec())
939                    },
940                    _ => {
941                        polars_bail!(SQLInterface: "failed to parse bitstring literal: {:?}", b)
942                    },
943                }
944            },
945            SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()),
946            other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other),
947        })
948    }
949
950    /// Visit a SQL `BETWEEN` expression.
951    /// See [sqlparser::ast::Expr::Between] for more details
952    fn visit_between(
953        &mut self,
954        expr: &SQLExpr,
955        negated: bool,
956        low: &SQLExpr,
957        high: &SQLExpr,
958    ) -> PolarsResult<Expr> {
959        let expr = self.visit_expr(expr)?;
960        let low = self.visit_expr(low)?;
961        let high = self.visit_expr(high)?;
962
963        let low = self.convert_temporal_strings(&expr, &low);
964        let high = self.convert_temporal_strings(&expr, &high);
965        Ok(if negated {
966            expr.clone().lt(low).or(expr.gt(high))
967        } else {
968            expr.clone().gt_eq(low).and(expr.lt_eq(high))
969        })
970    }
971
972    /// Visit a SQL `TRIM` function.
973    /// See [sqlparser::ast::Expr::Trim] for more details
974    fn visit_trim(
975        &mut self,
976        expr: &SQLExpr,
977        trim_where: &Option<TrimWhereField>,
978        trim_what: &Option<Box<SQLExpr>>,
979        trim_characters: &Option<Vec<SQLExpr>>,
980    ) -> PolarsResult<Expr> {
981        if trim_characters.is_some() {
982            // TODO: allow compact snowflake/bigquery syntax?
983            return Err(polars_err!(SQLSyntax: "unsupported TRIM syntax (custom chars)"));
984        };
985        let expr = self.visit_expr(expr)?;
986        let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
987        let trim_what = match trim_what {
988            Some(Expr::Literal(lv)) if lv.extract_str().is_some() => {
989                Some(PlSmallStr::from_str(lv.extract_str().unwrap()))
990            },
991            None => None,
992            _ => return self.err(&expr),
993        };
994        Ok(match (trim_where, trim_what) {
995            (None | Some(TrimWhereField::Both), None) => {
996                expr.str().strip_chars(lit(LiteralValue::untyped_null()))
997            },
998            (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
999            (Some(TrimWhereField::Leading), None) => expr
1000                .str()
1001                .strip_chars_start(lit(LiteralValue::untyped_null())),
1002            (Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)),
1003            (Some(TrimWhereField::Trailing), None) => expr
1004                .str()
1005                .strip_chars_end(lit(LiteralValue::untyped_null())),
1006            (Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)),
1007        })
1008    }
1009
1010    fn visit_substring(
1011        &mut self,
1012        expr: &SQLExpr,
1013        substring_from: Option<&SQLExpr>,
1014        substring_for: Option<&SQLExpr>,
1015    ) -> PolarsResult<Expr> {
1016        let e = self.visit_expr(expr)?;
1017
1018        match (substring_from, substring_for) {
1019            // SUBSTRING(expr FROM start FOR length)
1020            (Some(from_expr), Some(for_expr)) => {
1021                let start = self.visit_expr(from_expr)?;
1022                let length = self.visit_expr(for_expr)?;
1023
1024                // note: SQL is 1-indexed, so we need to adjust the offsets accordingly
1025                Ok(match (start.clone(), length.clone()) {
1026                    (Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1027                    (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1028                        polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", n)
1029                    },
1030                    (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => {
1031                        e.str().slice(lit(n - 1), length)
1032                    },
1033                    (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => e
1034                        .str()
1035                        .slice(lit(0), (length + lit(n - 1)).clip_min(lit(0))),
1036                    (Expr::Literal(_), _) => {
1037                        polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1038                    },
1039                    (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1040                        polars_bail!(SQLSyntax: "invalid 'length' for SUBSTRING")
1041                    },
1042                    _ => {
1043                        let adjusted_start = start - lit(1);
1044                        when(adjusted_start.clone().lt(lit(0)))
1045                            .then(e.clone().str().slice(
1046                                lit(0),
1047                                (length.clone() + adjusted_start.clone()).clip_min(lit(0)),
1048                            ))
1049                            .otherwise(e.str().slice(adjusted_start, length))
1050                    },
1051                })
1052            },
1053            // SUBSTRING(expr FROM start)
1054            (Some(from_expr), None) => {
1055                let start = self.visit_expr(from_expr)?;
1056
1057                Ok(match start {
1058                    Expr::Literal(lv) if lv.is_null() => lit(lv),
1059                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1060                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1061                        e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null()))
1062                    },
1063                    Expr::Literal(_) => {
1064                        polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1065                    },
1066                    _ => e
1067                        .str()
1068                        .slice(start - lit(1), lit(LiteralValue::untyped_null())),
1069                })
1070            },
1071            // SUBSTRING(expr) - not valid, but handle gracefully
1072            (None, _) => {
1073                polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found 1)")
1074            },
1075        }
1076    }
1077
1078    /// Visit a SQL subquery inside an `IN` expression.
1079    fn visit_in_subquery(
1080        &mut self,
1081        expr: &SQLExpr,
1082        subquery: &Subquery,
1083        negated: bool,
1084    ) -> PolarsResult<Expr> {
1085        let subquery_result = self
1086            .visit_subquery(subquery, SubqueryRestriction::SingleColumn)?
1087            .implode();
1088        let expr = self.visit_expr(expr)?;
1089        Ok(if negated {
1090            expr.is_in(subquery_result, false).not()
1091        } else {
1092            expr.is_in(subquery_result, false)
1093        })
1094    }
1095
1096    /// Visit `CASE` control flow expression.
1097    fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
1098        if let SQLExpr::Case {
1099            case_token: _,
1100            end_token: _,
1101            operand,
1102            conditions,
1103            else_result,
1104        } = expr
1105        {
1106            polars_ensure!(
1107                !conditions.is_empty(),
1108                SQLSyntax: "WHEN and THEN expressions must have at least one element"
1109            );
1110
1111            let mut when_thens = conditions.iter();
1112            let first = when_thens.next();
1113            if first.is_none() {
1114                polars_bail!(SQLSyntax: "WHEN and THEN expressions must have at least one element");
1115            }
1116            let else_res = match else_result {
1117                Some(else_res) => self.visit_expr(else_res)?,
1118                None => lit(LiteralValue::untyped_null()), // ELSE clause is optional; when omitted, it is implicitly NULL
1119            };
1120            if let Some(operand_expr) = operand {
1121                let first_operand_expr = self.visit_expr(operand_expr)?;
1122
1123                let first = first.unwrap();
1124                let first_cond = first_operand_expr.eq(self.visit_expr(&first.condition)?);
1125                let first_then = self.visit_expr(&first.result)?;
1126                let expr = when(first_cond).then(first_then);
1127                let next = when_thens.next();
1128
1129                let mut when_then = if let Some(case_when) = next {
1130                    let second_operand_expr = self.visit_expr(operand_expr)?;
1131                    let cond = second_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1132                    let res = self.visit_expr(&case_when.result)?;
1133                    expr.when(cond).then(res)
1134                } else {
1135                    return Ok(expr.otherwise(else_res));
1136                };
1137                for case_when in when_thens {
1138                    let new_operand_expr = self.visit_expr(operand_expr)?;
1139                    let cond = new_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1140                    let res = self.visit_expr(&case_when.result)?;
1141                    when_then = when_then.when(cond).then(res);
1142                }
1143                return Ok(when_then.otherwise(else_res));
1144            }
1145
1146            let first = first.unwrap();
1147            let first_cond = self.visit_expr(&first.condition)?;
1148            let first_then = self.visit_expr(&first.result)?;
1149            let expr = when(first_cond).then(first_then);
1150            let next = when_thens.next();
1151
1152            let mut when_then = if let Some(case_when) = next {
1153                let cond = self.visit_expr(&case_when.condition)?;
1154                let res = self.visit_expr(&case_when.result)?;
1155                expr.when(cond).then(res)
1156            } else {
1157                return Ok(expr.otherwise(else_res));
1158            };
1159            for case_when in when_thens {
1160                let cond = self.visit_expr(&case_when.condition)?;
1161                let res = self.visit_expr(&case_when.result)?;
1162                when_then = when_then.when(cond).then(res);
1163            }
1164            Ok(when_then.otherwise(else_res))
1165        } else {
1166            unreachable!()
1167        }
1168    }
1169
1170    fn err(&self, expr: &Expr) -> PolarsResult<Expr> {
1171        polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr);
1172    }
1173}
1174
1175/// parse a SQL expression to a polars expression
1176/// # Example
1177/// ```rust
1178/// # use polars_sql::{SQLContext, sql_expr};
1179/// # use polars_core::prelude::*;
1180/// # use polars_lazy::prelude::*;
1181/// # fn main() {
1182///
1183/// let mut ctx = SQLContext::new();
1184/// let df = df! {
1185///    "a" =>  [1, 2, 3],
1186/// }
1187/// .unwrap();
1188/// let expr = sql_expr("MAX(a)").unwrap();
1189/// df.lazy().select(vec![expr]).collect().unwrap();
1190/// # }
1191/// ```
1192pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
1193    let mut ctx = SQLContext::new();
1194
1195    let mut parser = Parser::new(&GenericDialect);
1196    parser = parser.with_options(ParserOptions {
1197        trailing_commas: true,
1198        ..Default::default()
1199    });
1200
1201    let mut ast = parser
1202        .try_with_sql(s.as_ref())
1203        .map_err(to_sql_interface_err)?;
1204    let expr = ast.parse_select_item().map_err(to_sql_interface_err)?;
1205
1206    Ok(match &expr {
1207        SelectItem::ExprWithAlias { expr, alias } => {
1208            let expr = parse_sql_expr(expr, &mut ctx, None)?;
1209            expr.alias(alias.value.as_str())
1210        },
1211        SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
1212        _ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s.as_ref()),
1213    })
1214}
1215
1216pub(crate) fn interval_to_duration(interval: &Interval, fixed: bool) -> PolarsResult<Duration> {
1217    if interval.last_field.is_some()
1218        || interval.leading_field.is_some()
1219        || interval.leading_precision.is_some()
1220        || interval.fractional_seconds_precision.is_some()
1221    {
1222        polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval)
1223    }
1224    let s = match &*interval.value {
1225        SQLExpr::UnaryOp { .. } => {
1226            polars_bail!(SQLSyntax: "unary ops are not valid on interval strings; found {}", interval.value)
1227        },
1228        SQLExpr::Value(ValueWithSpan {
1229            value: SQLValue::SingleQuotedString(s),
1230            ..
1231        }) => Some(s),
1232        _ => None,
1233    };
1234    match s {
1235        Some(s) if s.contains('-') => {
1236            polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
1237        },
1238        Some(s) => {
1239            // years, quarters, and months do not have a fixed duration; these
1240            // interval parts can only be used with respect to a reference point
1241            let duration = Duration::parse_interval(s);
1242            if fixed && duration.months() != 0 {
1243                polars_bail!(SQLSyntax: "fixed-duration interval cannot contain years, quarters, or months; found {}", s)
1244            };
1245            Ok(duration)
1246        },
1247        None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
1248    }
1249}
1250
1251pub(crate) fn parse_sql_expr(
1252    expr: &SQLExpr,
1253    ctx: &mut SQLContext,
1254    active_schema: Option<&Schema>,
1255) -> PolarsResult<Expr> {
1256    let mut visitor = SQLExprVisitor { ctx, active_schema };
1257    visitor.visit_expr(expr)
1258}
1259
1260pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
1261    match expr {
1262        SQLExpr::Array(arr) => {
1263            let mut visitor = SQLExprVisitor {
1264                ctx,
1265                active_schema: None,
1266            };
1267            visitor.array_expr_to_series(arr.elem.as_slice())
1268        },
1269        _ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr),
1270    }
1271}
1272
1273pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
1274    let field = match field {
1275        // handle 'DATE_PART' and all valid abbreviations/alternates
1276        DateTimeField::Custom(Ident { value, .. }) => {
1277            let value = value.to_ascii_lowercase();
1278            match value.as_str() {
1279                "millennium" | "millennia" => &DateTimeField::Millennium,
1280                "century" | "centuries" => &DateTimeField::Century,
1281                "decade" | "decades" => &DateTimeField::Decade,
1282                "isoyear" => &DateTimeField::Isoyear,
1283                "year" | "years" | "y" => &DateTimeField::Year,
1284                "quarter" | "quarters" => &DateTimeField::Quarter,
1285                "month" | "months" | "mon" | "mons" => &DateTimeField::Month,
1286                "dayofyear" | "doy" => &DateTimeField::DayOfYear,
1287                "dayofweek" | "dow" => &DateTimeField::DayOfWeek,
1288                "isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
1289                "isodow" => &DateTimeField::Isodow,
1290                "day" | "days" | "d" => &DateTimeField::Day,
1291                "hour" | "hours" | "h" => &DateTimeField::Hour,
1292                "minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
1293                "second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
1294                "millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
1295                "microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
1296                "nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
1297                #[cfg(feature = "timezones")]
1298                "timezone" => &DateTimeField::Timezone,
1299                "time" => &DateTimeField::Time,
1300                "epoch" => &DateTimeField::Epoch,
1301                _ => {
1302                    polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
1303                },
1304            }
1305        },
1306        _ => field,
1307    };
1308    Ok(match field {
1309        DateTimeField::Millennium => expr.dt().millennium(),
1310        DateTimeField::Century => expr.dt().century(),
1311        DateTimeField::Decade => expr.dt().year() / typed_lit(10i32),
1312        DateTimeField::Isoyear => expr.dt().iso_year(),
1313        DateTimeField::Year | DateTimeField::Years => expr.dt().year(),
1314        DateTimeField::Quarter => expr.dt().quarter(),
1315        DateTimeField::Month | DateTimeField::Months => expr.dt().month(),
1316        DateTimeField::Week(weekday) => {
1317            if weekday.is_some() {
1318                polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1319            }
1320            expr.dt().week()
1321        },
1322        DateTimeField::IsoWeek | DateTimeField::Weeks => expr.dt().week(),
1323        DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
1324        DateTimeField::DayOfWeek | DateTimeField::Dow => {
1325            let w = expr.dt().weekday();
1326            when(w.clone().eq(typed_lit(7i8)))
1327                .then(typed_lit(0i8))
1328                .otherwise(w)
1329        },
1330        DateTimeField::Isodow => expr.dt().weekday(),
1331        DateTimeField::Day | DateTimeField::Days => expr.dt().day(),
1332        DateTimeField::Hour | DateTimeField::Hours => expr.dt().hour(),
1333        DateTimeField::Minute | DateTimeField::Minutes => expr.dt().minute(),
1334        DateTimeField::Second | DateTimeField::Seconds => expr.dt().second(),
1335        DateTimeField::Millisecond | DateTimeField::Milliseconds => {
1336            (expr.clone().dt().second() * typed_lit(1_000f64))
1337                + expr.dt().nanosecond().div(typed_lit(1_000_000f64))
1338        },
1339        DateTimeField::Microsecond | DateTimeField::Microseconds => {
1340            (expr.clone().dt().second() * typed_lit(1_000_000f64))
1341                + expr.dt().nanosecond().div(typed_lit(1_000f64))
1342        },
1343        DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
1344            (expr.clone().dt().second() * typed_lit(1_000_000_000f64)) + expr.dt().nanosecond()
1345        },
1346        DateTimeField::Time => expr.dt().time(),
1347        #[cfg(feature = "timezones")]
1348        DateTimeField::Timezone => expr.dt().base_utc_offset().dt().total_seconds(false),
1349        DateTimeField::Epoch => {
1350            expr.clone()
1351                .dt()
1352                .timestamp(TimeUnit::Nanoseconds)
1353                .div(typed_lit(1_000_000_000i64))
1354                + expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
1355        },
1356        _ => {
1357            polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1358        },
1359    })
1360}
1361
1362/// Allow an expression that represents a 1-indexed parameter to
1363/// be adjusted from 1-indexed (SQL) to 0-indexed (Rust/Polars)
1364pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
1365    match idx {
1366        Expr::Literal(sc) if sc.is_null() => lit(LiteralValue::untyped_null()),
1367        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => {
1368            if null_if_zero {
1369                lit(LiteralValue::untyped_null())
1370            } else {
1371                idx
1372            }
1373        },
1374        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n < 0 => idx,
1375        Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => lit(n - 1),
1376        // TODO: when 'saturating_sub' is available, should be able
1377        //  to streamline the when/then/otherwise block below -
1378        _ => when(idx.clone().gt(lit(0)))
1379            .then(idx.clone() - lit(1))
1380            .otherwise(if null_if_zero {
1381                when(idx.clone().eq(lit(0)))
1382                    .then(lit(LiteralValue::untyped_null()))
1383                    .otherwise(idx.clone())
1384            } else {
1385                idx.clone()
1386            }),
1387    }
1388}
1389
1390fn resolve_column<'a>(
1391    ctx: &'a mut SQLContext,
1392    ident_root: &'a Ident,
1393    name: &'a str,
1394    dtype: &'a DataType,
1395) -> PolarsResult<(Expr, Option<&'a DataType>)> {
1396    let resolved = ctx.resolve_name(&ident_root.value, name);
1397    let resolved = resolved.as_str();
1398    Ok((
1399        if name != resolved {
1400            col(resolved).alias(name)
1401        } else {
1402            col(name)
1403        },
1404        Some(dtype),
1405    ))
1406}
1407
1408pub(crate) fn resolve_compound_identifier(
1409    ctx: &mut SQLContext,
1410    idents: &[Ident],
1411    active_schema: Option<&Schema>,
1412) -> PolarsResult<Vec<Expr>> {
1413    // inference priority: table > struct > column
1414    let ident_root = &idents[0];
1415    let mut remaining_idents = idents.iter().skip(1);
1416    let mut lf = ctx.get_table_from_current_scope(&ident_root.value);
1417
1418    // get schema from table (or the active/default schema)
1419    let schema = if let Some(ref mut lf) = lf {
1420        lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena)?
1421    } else {
1422        Arc::new(active_schema.cloned().unwrap_or_default())
1423    };
1424
1425    // handle simple/unqualified column reference with no schema
1426    if lf.is_none() && schema.is_empty() {
1427        let (mut column, mut dtype): (Expr, Option<&DataType>) =
1428            (col(ident_root.value.as_str()), None);
1429
1430        // traverse the remaining struct field path (if any)
1431        for ident in remaining_idents {
1432            let name = ident.value.as_str();
1433            match dtype {
1434                Some(DataType::Struct(fields)) if name == "*" => {
1435                    return Ok(fields
1436                        .iter()
1437                        .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1438                        .collect());
1439                },
1440                Some(DataType::Struct(fields)) => {
1441                    dtype = fields
1442                        .iter()
1443                        .find(|fld| fld.name == name)
1444                        .map(|fld| &fld.dtype);
1445                },
1446                Some(dtype) if name == "*" => {
1447                    polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1448                },
1449                _ => dtype = None,
1450            }
1451            column = column.struct_().field_by_name(name);
1452        }
1453        return Ok(vec![column]);
1454    }
1455
1456    let name = &remaining_idents.next().unwrap().value;
1457
1458    // handle "table.*" wildcard expansion
1459    if lf.is_some() && name == "*" {
1460        return schema
1461            .iter_names_and_dtypes()
1462            .map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).map(|(expr, _)| expr))
1463            .collect();
1464    }
1465
1466    // resolve column/struct reference
1467    let col_dtype: PolarsResult<(Expr, Option<&DataType>)> =
1468        match (lf.is_none(), schema.get(&ident_root.value)) {
1469            // root is a column/struct in schema (no table)
1470            (true, Some(dtype)) => {
1471                remaining_idents = idents.iter().skip(1);
1472                Ok((col(ident_root.value.as_str()), Some(dtype)))
1473            },
1474            // root is not in schema and no table found
1475            (true, None) => {
1476                polars_bail!(
1477                    SQLInterface: "no table or struct column named '{}' found",
1478                    ident_root
1479                )
1480            },
1481            // root is a table, resolve column from table schema
1482            (false, _) => {
1483                if let Some((_, col_name, dtype)) = schema.get_full(name) {
1484                    resolve_column(ctx, ident_root, col_name, dtype)
1485                } else {
1486                    polars_bail!(
1487                        SQLInterface: "no column named '{}' found in table '{}'",
1488                        name, ident_root
1489                    )
1490                }
1491            },
1492        };
1493
1494    // additional ident levels index into struct fields (eg: "df.col.field.nested_field")
1495    let (mut column, mut dtype) = col_dtype?;
1496    for ident in remaining_idents {
1497        let name = ident.value.as_str();
1498        match dtype {
1499            Some(DataType::Struct(fields)) if name == "*" => {
1500                return Ok(fields
1501                    .iter()
1502                    .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1503                    .collect());
1504            },
1505            Some(DataType::Struct(fields)) => {
1506                dtype = fields
1507                    .iter()
1508                    .find(|fld| fld.name == name)
1509                    .map(|fld| &fld.dtype);
1510            },
1511            Some(dtype) if name == "*" => {
1512                polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1513            },
1514            _ => {
1515                dtype = None;
1516            },
1517        }
1518        column = column.struct_().field_by_name(name);
1519    }
1520    Ok(vec![column])
1521}