1use std::fmt::Display;
10use std::ops::Div;
11
12use polars_core::prelude::*;
13use polars_lazy::prelude::*;
14use polars_plan::plans::DynLiteralValue;
15use polars_plan::prelude::typed_lit;
16use polars_time::Duration;
17use rand::Rng;
18use rand::distr::Alphanumeric;
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21use sqlparser::ast::{
22 AccessExpr, BinaryOperator as SQLBinaryOperator, CastFormat, CastKind, DataType as SQLDataType,
23 DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, Interval, Query as Subquery,
24 SelectItem, Subscript, TimezoneInfo, TrimWhereField, TypedString, UnaryOperator,
25 Value as SQLValue, ValueWithSpan,
26};
27use sqlparser::dialect::GenericDialect;
28use sqlparser::parser::{Parser, ParserOptions};
29
30use crate::SQLContext;
31use crate::functions::SQLFunctionVisitor;
32use crate::types::{
33 bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars,
34};
35
36#[inline]
37#[cold]
38#[must_use]
39pub fn to_sql_interface_err(err: impl Display) -> PolarsError {
41 PolarsError::SQLInterface(err.to_string().into())
42}
43
/// Constraints that may be placed on the shape of a subquery's result.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SubqueryRestriction {
    /// The subquery must produce exactly one column.
    SingleColumn,
}
54
55pub(crate) struct SQLExprVisitor<'a> {
57 ctx: &'a mut SQLContext,
58 active_schema: Option<&'a Schema>,
59}
60
61impl SQLExprVisitor<'_> {
62 fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
63 let mut array_elements = Vec::with_capacity(elements.len());
64 for e in elements {
65 let val = match e {
66 SQLExpr::Value(ValueWithSpan { value: v, .. }) => self.visit_any_value(v, None),
67 SQLExpr::UnaryOp { op, expr } => match expr.as_ref() {
68 SQLExpr::Value(ValueWithSpan { value: v, .. }) => {
69 self.visit_any_value(v, Some(op))
70 },
71 _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
72 },
73 SQLExpr::Array(values) => {
74 let srs = self.array_expr_to_series(&values.elem)?;
75 Ok(AnyValue::List(srs))
76 },
77 _ => Err(polars_err!(SQLInterface: "array element {:?} is not supported", e)),
78 }?
79 .into_static();
80 array_elements.push(val);
81 }
82 Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true)
83 }
84
85 fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
86 match expr {
87 SQLExpr::AllOp {
88 left,
89 compare_op,
90 right,
91 } => self.visit_all(left, compare_op, right),
92 SQLExpr::AnyOp {
93 left,
94 compare_op,
95 right,
96 is_some: _,
97 } => self.visit_any(left, compare_op, right),
98 SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
99 SQLExpr::Between {
100 expr,
101 negated,
102 low,
103 high,
104 } => self.visit_between(expr, *negated, low, high),
105 SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
106 SQLExpr::Cast {
107 kind,
108 expr,
109 data_type,
110 format,
111 } => self.visit_cast(expr, data_type, format, kind),
112 SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
113 SQLExpr::CompoundFieldAccess { root, access_chain } => {
114 if access_chain.len() == 1 {
116 match &access_chain[0] {
117 AccessExpr::Subscript(subscript) => {
118 return self.visit_subscript(root, subscript);
119 },
120 AccessExpr::Dot(_) => {
121 polars_bail!(SQLSyntax: "dot-notation field access is currently unsupported: {:?}", access_chain[0])
122 },
123 }
124 }
125 polars_bail!(SQLSyntax: "complex field access chains are currently unsupported: {:?}", access_chain[0])
127 },
128 SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
129 SQLExpr::Extract {
130 field,
131 syntax: _,
132 expr,
133 } => parse_extract_date_part(self.visit_expr(expr)?, field),
134 SQLExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
135 SQLExpr::Function(function) => self.visit_function(function),
136 SQLExpr::Identifier(ident) => self.visit_identifier(ident),
137 SQLExpr::InList {
138 expr,
139 list,
140 negated,
141 } => {
142 let expr = self.visit_expr(expr)?;
143 let elems = self.visit_array_expr(list, true, Some(&expr))?;
144 let is_in = expr.is_in(elems, false);
145 Ok(if *negated { is_in.not() } else { is_in })
146 },
147 SQLExpr::InSubquery {
148 expr,
149 subquery,
150 negated,
151 } => self.visit_in_subquery(expr, subquery, *negated),
152 SQLExpr::Interval(interval) => Ok(lit(interval_to_duration(interval, true)?)),
153 SQLExpr::IsDistinctFrom(e1, e2) => {
154 Ok(self.visit_expr(e1)?.neq_missing(self.visit_expr(e2)?))
155 },
156 SQLExpr::IsFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false))),
157 SQLExpr::IsNotDistinctFrom(e1, e2) => {
158 Ok(self.visit_expr(e1)?.eq_missing(self.visit_expr(e2)?))
159 },
160 SQLExpr::IsNotFalse(expr) => Ok(self.visit_expr(expr)?.eq(lit(false)).not()),
161 SQLExpr::IsNotNull(expr) => Ok(self.visit_expr(expr)?.is_not_null()),
162 SQLExpr::IsNotTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true)).not()),
163 SQLExpr::IsNull(expr) => Ok(self.visit_expr(expr)?.is_null()),
164 SQLExpr::IsTrue(expr) => Ok(self.visit_expr(expr)?.eq(lit(true))),
165 SQLExpr::Like {
166 negated,
167 any,
168 expr,
169 pattern,
170 escape_char,
171 } => {
172 if *any {
173 polars_bail!(SQLSyntax: "LIKE ANY is not a supported syntax")
174 }
175 let escape_str = escape_char.as_ref().and_then(|v| match v {
176 SQLValue::SingleQuotedString(s) => Some(s.clone()),
177 _ => None,
178 });
179 self.visit_like(*negated, expr, pattern, &escape_str, false)
180 },
181 SQLExpr::ILike {
182 negated,
183 any,
184 expr,
185 pattern,
186 escape_char,
187 } => {
188 if *any {
189 polars_bail!(SQLSyntax: "ILIKE ANY is not a supported syntax")
190 }
191 let escape_str = escape_char.as_ref().and_then(|v| match v {
192 SQLValue::SingleQuotedString(s) => Some(s.clone()),
193 _ => None,
194 });
195 self.visit_like(*negated, expr, pattern, &escape_str, true)
196 },
197 SQLExpr::Nested(expr) => self.visit_expr(expr),
198 SQLExpr::Position { expr, r#in } => Ok(
199 (self
201 .visit_expr(r#in)?
202 .str()
203 .find(self.visit_expr(expr)?, true)
204 + typed_lit(1u32))
205 .fill_null(typed_lit(0u32)),
206 ),
207 SQLExpr::RLike {
208 negated,
210 expr,
211 pattern,
212 regexp: _,
213 } => {
214 let matches = self
215 .visit_expr(expr)?
216 .str()
217 .contains(self.visit_expr(pattern)?, true);
218 Ok(if *negated { matches.not() } else { matches })
219 },
220 SQLExpr::Subquery(_) => polars_bail!(SQLInterface: "unexpected subquery"),
221 SQLExpr::Substring {
222 expr,
223 substring_from,
224 substring_for,
225 ..
226 } => self.visit_substring(expr, substring_from.as_deref(), substring_for.as_deref()),
227 SQLExpr::Trim {
228 expr,
229 trim_where,
230 trim_what,
231 trim_characters,
232 } => self.visit_trim(expr, trim_where, trim_what, trim_characters),
233 SQLExpr::TypedString(TypedString {
234 data_type,
235 value:
236 ValueWithSpan {
237 value: SQLValue::SingleQuotedString(v),
238 ..
239 },
240 uses_odbc_syntax: _,
241 }) => match data_type {
242 SQLDataType::Date => {
243 if is_iso_date(v) {
244 Ok(lit(v.as_str()).cast(DataType::Date))
245 } else {
246 polars_bail!(SQLSyntax: "invalid DATE literal '{}'", v)
247 }
248 },
249 SQLDataType::Time(None, TimezoneInfo::None) => {
250 if is_iso_time(v) {
251 Ok(lit(v.as_str()).str().to_time(StrptimeOptions {
252 strict: true,
253 ..Default::default()
254 }))
255 } else {
256 polars_bail!(SQLSyntax: "invalid TIME literal '{}'", v)
257 }
258 },
259 SQLDataType::Timestamp(None, TimezoneInfo::None) | SQLDataType::Datetime(None) => {
260 if is_iso_datetime(v) {
261 Ok(lit(v.as_str()).str().to_datetime(
262 None,
263 None,
264 StrptimeOptions {
265 strict: true,
266 ..Default::default()
267 },
268 lit("latest"),
269 ))
270 } else {
271 let fn_name = match data_type {
272 SQLDataType::Timestamp(_, _) => "TIMESTAMP",
273 SQLDataType::Datetime(_) => "DATETIME",
274 _ => unreachable!(),
275 };
276 polars_bail!(SQLSyntax: "invalid {} literal '{}'", fn_name, v)
277 }
278 },
279 _ => {
280 polars_bail!(SQLInterface: "typed literal should be one of DATE, DATETIME, TIME, or TIMESTAMP (found {})", data_type)
281 },
282 },
283 SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
284 SQLExpr::Value(ValueWithSpan { value, .. }) => self.visit_literal(value),
285 SQLExpr::Wildcard(_) => Ok(all().as_expr()),
286 e @ SQLExpr::Case { .. } => self.visit_case_when_then(e),
287 other => {
288 polars_bail!(SQLInterface: "expression {:?} is not currently supported", other)
289 },
290 }
291 }
292
293 fn visit_subquery(
294 &mut self,
295 subquery: &Subquery,
296 restriction: SubqueryRestriction,
297 ) -> PolarsResult<Expr> {
298 if subquery.with.is_some() {
299 polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause");
300 }
301 let (mut lf, schema) = self
304 .ctx
305 .execute_isolated(|ctx| ctx.execute_query_no_ctes(subquery))?;
306
307 if restriction == SubqueryRestriction::SingleColumn {
308 if schema.len() != 1 {
309 polars_bail!(SQLSyntax: "SQL subquery returns more than one column");
310 }
311 let rand_string: String = rand::rng()
312 .sample_iter(&Alphanumeric)
313 .take(16)
314 .map(char::from)
315 .collect();
316
317 let schema_entry = schema.get_at_index(0);
318 if let Some((old_name, _)) = schema_entry {
319 let new_name = String::from(old_name.as_str()) + rand_string.as_str();
320 lf = lf.rename([old_name.to_string()], [new_name.clone()], true);
321 return Ok(Expr::SubPlan(
322 SpecialEq::new(Arc::new(lf.logical_plan)),
323 vec![new_name],
324 ));
325 }
326 };
327 polars_bail!(SQLInterface: "subquery type not supported");
328 }
329
330 fn visit_identifier(&self, ident: &Ident) -> PolarsResult<Expr> {
334 Ok(col(ident.value.as_str()))
335 }
336
337 fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult<Expr> {
341 Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone())
342 }
343
344 fn visit_like(
345 &mut self,
346 negated: bool,
347 expr: &SQLExpr,
348 pattern: &SQLExpr,
349 escape_char: &Option<String>,
350 case_insensitive: bool,
351 ) -> PolarsResult<Expr> {
352 if escape_char.is_some() {
353 polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap());
354 }
355 let pat = match self.visit_expr(pattern) {
356 Ok(Expr::Literal(lv)) if lv.extract_str().is_some() => {
357 PlSmallStr::from_str(lv.extract_str().unwrap())
358 },
359 _ => {
360 polars_bail!(SQLSyntax: "LIKE/ILIKE pattern must be a string literal; found {}", pattern)
361 },
362 };
363 if pat.is_empty() || (!case_insensitive && pat.chars().all(|c| !matches!(c, '%' | '_'))) {
364 let op = if negated {
366 SQLBinaryOperator::NotEq
367 } else {
368 SQLBinaryOperator::Eq
369 };
370 self.visit_binary_op(expr, &op, pattern)
371 } else {
372 let mut rx = regex::escape(pat.as_str())
374 .replace('%', ".*")
375 .replace('_', ".");
376
377 rx = format!(
378 "^{}{}$",
379 if case_insensitive { "(?is)" } else { "(?s)" },
380 rx
381 );
382
383 let expr = self.visit_expr(expr)?;
384 let matches = expr.str().contains(lit(rx), true);
385 Ok(if negated { matches.not() } else { matches })
386 }
387 }
388
389 fn visit_subscript(&mut self, expr: &SQLExpr, subscript: &Subscript) -> PolarsResult<Expr> {
390 let expr = self.visit_expr(expr)?;
391 Ok(match subscript {
392 Subscript::Index { index } => {
393 let idx = adjust_one_indexed_param(self.visit_expr(index)?, true);
394 expr.list().get(idx, true)
395 },
396 Subscript::Slice { .. } => {
397 polars_bail!(SQLSyntax: "array slice syntax is not currently supported")
398 },
399 })
400 }
401
402 fn convert_temporal_strings(&mut self, left: &Expr, right: &Expr) -> Expr {
409 if let (Some(name), Some(s), expr_dtype) = match (left, right) {
410 (Expr::Column(name), Expr::Literal(lv)) if lv.extract_str().is_some() => {
412 (Some(name.clone()), Some(lv.extract_str().unwrap()), None)
413 },
414 (Expr::Cast { expr, dtype, .. }, Expr::Literal(lv)) if lv.extract_str().is_some() => {
416 let s = lv.extract_str().unwrap();
417 match &**expr {
418 Expr::Column(name) => (Some(name.clone()), Some(s), Some(dtype)),
419 _ => (None, Some(s), Some(dtype)),
420 }
421 },
422 _ => (None, None, None),
423 } {
424 if expr_dtype.is_none() && self.active_schema.is_none() {
425 right.clone()
426 } else {
427 let left_dtype = expr_dtype.map_or_else(
428 || {
429 self.active_schema
430 .as_ref()
431 .and_then(|schema| schema.get(&name))
432 },
433 |dt| dt.as_literal(),
434 );
435 match left_dtype {
436 Some(DataType::Time) if is_iso_time(s) => {
437 right.clone().str().to_time(StrptimeOptions {
438 strict: true,
439 ..Default::default()
440 })
441 },
442 Some(DataType::Date) if is_iso_date(s) => {
443 right.clone().str().to_date(StrptimeOptions {
444 strict: true,
445 ..Default::default()
446 })
447 },
448 Some(DataType::Datetime(tu, tz)) if is_iso_datetime(s) || is_iso_date(s) => {
449 if s.len() == 10 {
450 lit(format!("{s}T00:00:00"))
452 } else {
453 lit(s.replacen(' ', "T", 1))
454 }
455 .str()
456 .to_datetime(
457 Some(*tu),
458 tz.clone(),
459 StrptimeOptions {
460 strict: true,
461 ..Default::default()
462 },
463 lit("latest"),
464 )
465 },
466 _ => right.clone(),
467 }
468 }
469 } else {
470 right.clone()
471 }
472 }
473
474 fn struct_field_access_expr(
475 &mut self,
476 expr: &Expr,
477 path: &str,
478 infer_index: bool,
479 ) -> PolarsResult<Expr> {
480 let path_elems = if path.starts_with('{') && path.ends_with('}') {
481 path.trim_matches(|c| c == '{' || c == '}')
482 } else {
483 path
484 }
485 .split(',');
486
487 let mut expr = expr.clone();
488 for p in path_elems {
489 let p = p.trim();
490 expr = if infer_index {
491 match p.parse::<i64>() {
492 Ok(idx) => expr.list().get(lit(idx), true),
493 Err(_) => expr.struct_().field_by_name(p),
494 }
495 } else {
496 expr.struct_().field_by_name(p)
497 }
498 }
499 Ok(expr)
500 }
501
502 fn visit_binary_op(
506 &mut self,
507 left: &SQLExpr,
508 op: &SQLBinaryOperator,
509 right: &SQLExpr,
510 ) -> PolarsResult<Expr> {
511 if matches!(left, SQLExpr::Subquery(_)) || matches!(right, SQLExpr::Subquery(_)) {
513 let (suggestion, str_op) = match op {
514 SQLBinaryOperator::NotEq => ("; use 'NOT IN' instead", "!=".to_string()),
515 SQLBinaryOperator::Eq => ("; use 'IN' instead", format!("{op}")),
516 _ => ("", format!("{op}")),
517 };
518 polars_bail!(
519 SQLSyntax: "subquery comparisons with '{str_op}' are not supported{suggestion}"
520 );
521 }
522
523 let (lhs, mut rhs) = match (left, op, right) {
525 (_, SQLBinaryOperator::Minus, SQLExpr::Interval(v)) => {
526 let duration = interval_to_duration(v, false)?;
527 return Ok(self
528 .visit_expr(left)?
529 .dt()
530 .offset_by(lit(format!("-{duration}"))));
531 },
532 (_, SQLBinaryOperator::Plus, SQLExpr::Interval(v)) => {
533 let duration = interval_to_duration(v, false)?;
534 return Ok(self
535 .visit_expr(left)?
536 .dt()
537 .offset_by(lit(format!("{duration}"))));
538 },
539 (SQLExpr::Interval(v1), _, SQLExpr::Interval(v2)) => {
540 let d1 = interval_to_duration(v1, false)?;
542 let d2 = interval_to_duration(v2, false)?;
543 let res = match op {
544 SQLBinaryOperator::Gt => Ok(lit(d1 > d2)),
545 SQLBinaryOperator::Lt => Ok(lit(d1 < d2)),
546 SQLBinaryOperator::GtEq => Ok(lit(d1 >= d2)),
547 SQLBinaryOperator::LtEq => Ok(lit(d1 <= d2)),
548 SQLBinaryOperator::NotEq => Ok(lit(d1 != d2)),
549 SQLBinaryOperator::Eq | SQLBinaryOperator::Spaceship => Ok(lit(d1 == d2)),
550 _ => polars_bail!(SQLInterface: "invalid interval comparison operator"),
551 };
552 if res.is_ok() {
553 return res;
554 }
555 (self.visit_expr(left)?, self.visit_expr(right)?)
556 },
557 _ => (self.visit_expr(left)?, self.visit_expr(right)?),
558 };
559 rhs = self.convert_temporal_strings(&lhs, &rhs);
560
561 Ok(match op {
562 SQLBinaryOperator::BitwiseAnd => lhs.and(rhs), SQLBinaryOperator::BitwiseOr => lhs.or(rhs), SQLBinaryOperator::Xor => lhs.xor(rhs), SQLBinaryOperator::And => lhs.and(rhs), SQLBinaryOperator::Divide => lhs / rhs, SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64), SQLBinaryOperator::Eq => lhs.eq(rhs), SQLBinaryOperator::Gt => lhs.gt(rhs), SQLBinaryOperator::GtEq => lhs.gt_eq(rhs), SQLBinaryOperator::Lt => lhs.lt(rhs), SQLBinaryOperator::LtEq => lhs.lt_eq(rhs), SQLBinaryOperator::Minus => lhs - rhs, SQLBinaryOperator::Modulo => lhs % rhs, SQLBinaryOperator::Multiply => lhs * rhs, SQLBinaryOperator::NotEq => lhs.eq(rhs).not(), SQLBinaryOperator::Or => lhs.or(rhs), SQLBinaryOperator::Plus => lhs + rhs, SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs), SQLBinaryOperator::StringConcat => { lhs.cast(DataType::String) + rhs.cast(DataType::String)
589 },
590 SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs), SQLBinaryOperator::PGRegexMatch => match rhs { Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true),
596 _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
597 },
598 SQLBinaryOperator::PGRegexNotMatch => match rhs { Expr::Literal(ref lv) if lv.extract_str().is_some() => lhs.str().contains(rhs, true).not(),
600 _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
601 },
602 SQLBinaryOperator::PGRegexIMatch => match rhs { Expr::Literal(ref lv) if lv.extract_str().is_some() => {
604 let pat = lv.extract_str().unwrap();
605 lhs.str().contains(lit(format!("(?i){pat}")), true)
606 },
607 _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
608 },
609 SQLBinaryOperator::PGRegexNotIMatch => match rhs { Expr::Literal(ref lv) if lv.extract_str().is_some() => {
611 let pat = lv.extract_str().unwrap();
612 lhs.str().contains(lit(format!("(?i){pat}")), true).not()
613 },
614 _ => {
615 polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs)
616 },
617 },
618 SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch | SQLBinaryOperator::PGILikeMatch | SQLBinaryOperator::PGNotILikeMatch => { let expr = if matches!(
626 op,
627 SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
628 ) {
629 SQLExpr::Like {
630 negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch),
631 any: false,
632 expr: Box::new(left.clone()),
633 pattern: Box::new(right.clone()),
634 escape_char: None,
635 }
636 } else {
637 SQLExpr::ILike {
638 negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch),
639 any: false,
640 expr: Box::new(left.clone()),
641 pattern: Box::new(right.clone()),
642 escape_char: None,
643 }
644 };
645 self.visit_expr(&expr)?
646 },
647 SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs { Expr::Literal(lv) if lv.extract_str().is_some() => {
652 let path = lv.extract_str().unwrap();
653 let mut expr = self.struct_field_access_expr(&lhs, path, false)?;
654 if let SQLBinaryOperator::LongArrow = op {
655 expr = expr.cast(DataType::String);
656 }
657 expr
658 },
659 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(idx))) => {
660 let mut expr = self.struct_field_access_expr(&lhs, &idx.to_string(), true)?;
661 if let SQLBinaryOperator::LongArrow = op {
662 expr = expr.cast(DataType::String);
663 }
664 expr
665 },
666 _ => {
667 polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
668 },
669 },
670 SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => { match rhs {
672 Expr::Literal(lv) if lv.extract_str().is_some() => {
673 let path = lv.extract_str().unwrap();
674 let mut expr = self.struct_field_access_expr(&lhs, path, true)?;
675 if let SQLBinaryOperator::HashLongArrow = op {
676 expr = expr.cast(DataType::String);
677 }
678 expr
679 },
680 _ => {
681 polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", rhs)
682 }
683 }
684 },
685 other => {
686 polars_bail!(SQLInterface: "operator {:?} is not currently supported", other)
687 },
688 })
689 }
690
691 fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
695 let expr = self.visit_expr(expr)?;
696 Ok(match (op, expr.clone()) {
697 (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
699 lit(n)
700 },
701 (UnaryOperator::Plus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
702 lit(n)
703 },
704 (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) => {
705 lit(-n)
706 },
707 (UnaryOperator::Minus, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(n)))) => {
708 lit(-n)
709 },
710 (UnaryOperator::Plus, _) => lit(0) + expr,
712 (UnaryOperator::Minus, _) => lit(0) - expr,
713 (UnaryOperator::Not, _) => match &expr {
714 Expr::Column(name)
715 if self
716 .active_schema
717 .and_then(|schema| schema.get(name))
718 .is_some_and(|dtype| matches!(dtype, DataType::Boolean)) =>
719 {
720 expr.not()
722 },
723 _ => expr.strict_cast(DataType::Boolean).not(),
725 },
726 other => polars_bail!(SQLInterface: "unary operator {:?} is not supported", other),
727 })
728 }
729
730 fn visit_function(&mut self, function: &SQLFunction) -> PolarsResult<Expr> {
736 let mut visitor = SQLFunctionVisitor {
737 func: function,
738 ctx: self.ctx,
739 active_schema: self.active_schema,
740 };
741 visitor.visit_function()
742 }
743
744 fn visit_all(
748 &mut self,
749 left: &SQLExpr,
750 compare_op: &SQLBinaryOperator,
751 right: &SQLExpr,
752 ) -> PolarsResult<Expr> {
753 let left = self.visit_expr(left)?;
754 let right = self.visit_expr(right)?;
755
756 match compare_op {
757 SQLBinaryOperator::Gt => Ok(left.gt(right.max())),
758 SQLBinaryOperator::Lt => Ok(left.lt(right.min())),
759 SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.max())),
760 SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.min())),
761 SQLBinaryOperator::Eq => polars_bail!(SQLSyntax: "ALL cannot be used with ="),
762 SQLBinaryOperator::NotEq => polars_bail!(SQLSyntax: "ALL cannot be used with !="),
763 _ => polars_bail!(SQLInterface: "invalid comparison operator"),
764 }
765 }
766
767 fn visit_any(
771 &mut self,
772 left: &SQLExpr,
773 compare_op: &SQLBinaryOperator,
774 right: &SQLExpr,
775 ) -> PolarsResult<Expr> {
776 let left = self.visit_expr(left)?;
777 let right = self.visit_expr(right)?;
778
779 match compare_op {
780 SQLBinaryOperator::Gt => Ok(left.gt(right.min())),
781 SQLBinaryOperator::Lt => Ok(left.lt(right.max())),
782 SQLBinaryOperator::GtEq => Ok(left.gt_eq(right.min())),
783 SQLBinaryOperator::LtEq => Ok(left.lt_eq(right.max())),
784 SQLBinaryOperator::Eq => Ok(left.is_in(right, false)),
785 SQLBinaryOperator::NotEq => Ok(left.is_in(right, false).not()),
786 _ => polars_bail!(SQLInterface: "invalid comparison operator"),
787 }
788 }
789
790 fn visit_array_expr(
792 &mut self,
793 elements: &[SQLExpr],
794 result_as_element: bool,
795 dtype_expr_match: Option<&Expr>,
796 ) -> PolarsResult<Expr> {
797 let mut elems = self.array_expr_to_series(elements)?;
798
799 if let (Some(Expr::Column(name)), Some(schema)) =
802 (dtype_expr_match, self.active_schema.as_ref())
803 {
804 if elems.dtype() == &DataType::String {
805 if let Some(dtype) = schema.get(name) {
806 if matches!(
807 dtype,
808 DataType::Date | DataType::Time | DataType::Datetime(_, _)
809 ) {
810 elems = elems.strict_cast(dtype)?;
811 }
812 }
813 }
814 }
815
816 let res = if result_as_element {
819 elems.implode()?.into_series()
820 } else {
821 elems
822 };
823 Ok(lit(res))
824 }
825
826 fn visit_cast(
830 &mut self,
831 expr: &SQLExpr,
832 dtype: &SQLDataType,
833 format: &Option<CastFormat>,
834 cast_kind: &CastKind,
835 ) -> PolarsResult<Expr> {
836 if format.is_some() {
837 return Err(
838 polars_err!(SQLInterface: "use of FORMAT is not currently supported in CAST"),
839 );
840 }
841 let expr = self.visit_expr(expr)?;
842
843 #[cfg(feature = "json")]
844 if dtype == &SQLDataType::JSON {
845 return Ok(expr.str().json_decode(DataType::Struct(Vec::new())));
847 }
848 let polars_type = map_sql_dtype_to_polars(dtype)?;
849 Ok(match cast_kind {
850 CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type),
851 CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type),
852 })
853 }
854
855 fn visit_literal(&self, value: &SQLValue) -> PolarsResult<Expr> {
861 Ok(match value {
863 SQLValue::Boolean(b) => lit(*b),
864 SQLValue::DollarQuotedString(s) => lit(s.value.clone()),
865 #[cfg(feature = "binary_encoding")]
866 SQLValue::HexStringLiteral(x) => {
867 if x.len() % 2 != 0 {
868 polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
869 };
870 lit(hex::decode(x.clone()).unwrap())
871 },
872 SQLValue::Null => Expr::Literal(LiteralValue::untyped_null()),
873 SQLValue::Number(s, _) => {
874 if s.contains('.') {
876 s.parse::<f64>().map(lit).map_err(|_| ())
877 } else {
878 s.parse::<i64>().map(lit).map_err(|_| ())
879 }
880 .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
881 },
882 SQLValue::SingleQuotedByteStringLiteral(b) => {
883 bitstring_to_bytes_literal(b)?
887 },
888 SQLValue::SingleQuotedString(s) => lit(s.clone()),
889 other => {
890 polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other)
891 },
892 })
893 }
894
895 fn visit_any_value(
897 &self,
898 value: &SQLValue,
899 op: Option<&UnaryOperator>,
900 ) -> PolarsResult<AnyValue<'_>> {
901 Ok(match value {
902 SQLValue::Boolean(b) => AnyValue::Boolean(*b),
903 SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()),
904 #[cfg(feature = "binary_encoding")]
905 SQLValue::HexStringLiteral(x) => {
906 if x.len() % 2 != 0 {
907 polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x)
908 };
909 AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap())
910 },
911 SQLValue::Null => AnyValue::Null,
912 SQLValue::Number(s, _) => {
913 let negate = match op {
914 Some(UnaryOperator::Minus) => true,
915 Some(UnaryOperator::Plus) | None => false,
917 Some(op) => {
918 polars_bail!(SQLInterface: "unary op {:?} not supported for numeric SQL value", op)
919 },
920 };
921 if s.contains('.') {
923 s.parse::<f64>()
924 .map(|n: f64| AnyValue::Float64(if negate { -n } else { n }))
925 .map_err(|_| ())
926 } else {
927 s.parse::<i64>()
928 .map(|n: i64| AnyValue::Int64(if negate { -n } else { n }))
929 .map_err(|_| ())
930 }
931 .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))?
932 },
933 SQLValue::SingleQuotedByteStringLiteral(b) => {
934 let bytes_literal = bitstring_to_bytes_literal(b)?;
936 match bytes_literal {
937 Expr::Literal(lv) if lv.extract_binary().is_some() => {
938 AnyValue::BinaryOwned(lv.extract_binary().unwrap().to_vec())
939 },
940 _ => {
941 polars_bail!(SQLInterface: "failed to parse bitstring literal: {:?}", b)
942 },
943 }
944 },
945 SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()),
946 other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other),
947 })
948 }
949
950 fn visit_between(
953 &mut self,
954 expr: &SQLExpr,
955 negated: bool,
956 low: &SQLExpr,
957 high: &SQLExpr,
958 ) -> PolarsResult<Expr> {
959 let expr = self.visit_expr(expr)?;
960 let low = self.visit_expr(low)?;
961 let high = self.visit_expr(high)?;
962
963 let low = self.convert_temporal_strings(&expr, &low);
964 let high = self.convert_temporal_strings(&expr, &high);
965 Ok(if negated {
966 expr.clone().lt(low).or(expr.gt(high))
967 } else {
968 expr.clone().gt_eq(low).and(expr.lt_eq(high))
969 })
970 }
971
972 fn visit_trim(
975 &mut self,
976 expr: &SQLExpr,
977 trim_where: &Option<TrimWhereField>,
978 trim_what: &Option<Box<SQLExpr>>,
979 trim_characters: &Option<Vec<SQLExpr>>,
980 ) -> PolarsResult<Expr> {
981 if trim_characters.is_some() {
982 return Err(polars_err!(SQLSyntax: "unsupported TRIM syntax (custom chars)"));
984 };
985 let expr = self.visit_expr(expr)?;
986 let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
987 let trim_what = match trim_what {
988 Some(Expr::Literal(lv)) if lv.extract_str().is_some() => {
989 Some(PlSmallStr::from_str(lv.extract_str().unwrap()))
990 },
991 None => None,
992 _ => return self.err(&expr),
993 };
994 Ok(match (trim_where, trim_what) {
995 (None | Some(TrimWhereField::Both), None) => {
996 expr.str().strip_chars(lit(LiteralValue::untyped_null()))
997 },
998 (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
999 (Some(TrimWhereField::Leading), None) => expr
1000 .str()
1001 .strip_chars_start(lit(LiteralValue::untyped_null())),
1002 (Some(TrimWhereField::Leading), Some(val)) => expr.str().strip_chars_start(lit(val)),
1003 (Some(TrimWhereField::Trailing), None) => expr
1004 .str()
1005 .strip_chars_end(lit(LiteralValue::untyped_null())),
1006 (Some(TrimWhereField::Trailing), Some(val)) => expr.str().strip_chars_end(lit(val)),
1007 })
1008 }
1009
1010 fn visit_substring(
1011 &mut self,
1012 expr: &SQLExpr,
1013 substring_from: Option<&SQLExpr>,
1014 substring_for: Option<&SQLExpr>,
1015 ) -> PolarsResult<Expr> {
1016 let e = self.visit_expr(expr)?;
1017
1018 match (substring_from, substring_for) {
1019 (Some(from_expr), Some(for_expr)) => {
1021 let start = self.visit_expr(from_expr)?;
1022 let length = self.visit_expr(for_expr)?;
1023
1024 Ok(match (start.clone(), length.clone()) {
1026 (Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1027 (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1028 polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", n)
1029 },
1030 (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => {
1031 e.str().slice(lit(n - 1), length)
1032 },
1033 (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => e
1034 .str()
1035 .slice(lit(0), (length + lit(n - 1)).clip_min(lit(0))),
1036 (Expr::Literal(_), _) => {
1037 polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1038 },
1039 (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1040 polars_bail!(SQLSyntax: "invalid 'length' for SUBSTRING")
1041 },
1042 _ => {
1043 let adjusted_start = start - lit(1);
1044 when(adjusted_start.clone().lt(lit(0)))
1045 .then(e.clone().str().slice(
1046 lit(0),
1047 (length.clone() + adjusted_start.clone()).clip_min(lit(0)),
1048 ))
1049 .otherwise(e.str().slice(adjusted_start, length))
1050 },
1051 })
1052 },
1053 (Some(from_expr), None) => {
1055 let start = self.visit_expr(from_expr)?;
1056
1057 Ok(match start {
1058 Expr::Literal(lv) if lv.is_null() => lit(lv),
1059 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1060 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1061 e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null()))
1062 },
1063 Expr::Literal(_) => {
1064 polars_bail!(SQLSyntax: "invalid 'start' for SUBSTRING")
1065 },
1066 _ => e
1067 .str()
1068 .slice(start - lit(1), lit(LiteralValue::untyped_null())),
1069 })
1070 },
1071 (None, _) => {
1073 polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found 1)")
1074 },
1075 }
1076 }
1077
1078 fn visit_in_subquery(
1080 &mut self,
1081 expr: &SQLExpr,
1082 subquery: &Subquery,
1083 negated: bool,
1084 ) -> PolarsResult<Expr> {
1085 let subquery_result = self
1086 .visit_subquery(subquery, SubqueryRestriction::SingleColumn)?
1087 .implode();
1088 let expr = self.visit_expr(expr)?;
1089 Ok(if negated {
1090 expr.is_in(subquery_result, false).not()
1091 } else {
1092 expr.is_in(subquery_result, false)
1093 })
1094 }
1095
1096 fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
1098 if let SQLExpr::Case {
1099 case_token: _,
1100 end_token: _,
1101 operand,
1102 conditions,
1103 else_result,
1104 } = expr
1105 {
1106 polars_ensure!(
1107 !conditions.is_empty(),
1108 SQLSyntax: "WHEN and THEN expressions must have at least one element"
1109 );
1110
1111 let mut when_thens = conditions.iter();
1112 let first = when_thens.next();
1113 if first.is_none() {
1114 polars_bail!(SQLSyntax: "WHEN and THEN expressions must have at least one element");
1115 }
1116 let else_res = match else_result {
1117 Some(else_res) => self.visit_expr(else_res)?,
1118 None => lit(LiteralValue::untyped_null()), };
1120 if let Some(operand_expr) = operand {
1121 let first_operand_expr = self.visit_expr(operand_expr)?;
1122
1123 let first = first.unwrap();
1124 let first_cond = first_operand_expr.eq(self.visit_expr(&first.condition)?);
1125 let first_then = self.visit_expr(&first.result)?;
1126 let expr = when(first_cond).then(first_then);
1127 let next = when_thens.next();
1128
1129 let mut when_then = if let Some(case_when) = next {
1130 let second_operand_expr = self.visit_expr(operand_expr)?;
1131 let cond = second_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1132 let res = self.visit_expr(&case_when.result)?;
1133 expr.when(cond).then(res)
1134 } else {
1135 return Ok(expr.otherwise(else_res));
1136 };
1137 for case_when in when_thens {
1138 let new_operand_expr = self.visit_expr(operand_expr)?;
1139 let cond = new_operand_expr.eq(self.visit_expr(&case_when.condition)?);
1140 let res = self.visit_expr(&case_when.result)?;
1141 when_then = when_then.when(cond).then(res);
1142 }
1143 return Ok(when_then.otherwise(else_res));
1144 }
1145
1146 let first = first.unwrap();
1147 let first_cond = self.visit_expr(&first.condition)?;
1148 let first_then = self.visit_expr(&first.result)?;
1149 let expr = when(first_cond).then(first_then);
1150 let next = when_thens.next();
1151
1152 let mut when_then = if let Some(case_when) = next {
1153 let cond = self.visit_expr(&case_when.condition)?;
1154 let res = self.visit_expr(&case_when.result)?;
1155 expr.when(cond).then(res)
1156 } else {
1157 return Ok(expr.otherwise(else_res));
1158 };
1159 for case_when in when_thens {
1160 let cond = self.visit_expr(&case_when.condition)?;
1161 let res = self.visit_expr(&case_when.result)?;
1162 when_then = when_then.when(cond).then(res);
1163 }
1164 Ok(when_then.otherwise(else_res))
1165 } else {
1166 unreachable!()
1167 }
1168 }
1169
1170 fn err(&self, expr: &Expr) -> PolarsResult<Expr> {
1171 polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr);
1172 }
1173}
1174
1175pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
1193 let mut ctx = SQLContext::new();
1194
1195 let mut parser = Parser::new(&GenericDialect);
1196 parser = parser.with_options(ParserOptions {
1197 trailing_commas: true,
1198 ..Default::default()
1199 });
1200
1201 let mut ast = parser
1202 .try_with_sql(s.as_ref())
1203 .map_err(to_sql_interface_err)?;
1204 let expr = ast.parse_select_item().map_err(to_sql_interface_err)?;
1205
1206 Ok(match &expr {
1207 SelectItem::ExprWithAlias { expr, alias } => {
1208 let expr = parse_sql_expr(expr, &mut ctx, None)?;
1209 expr.alias(alias.value.as_str())
1210 },
1211 SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?,
1212 _ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s.as_ref()),
1213 })
1214}
1215
1216pub(crate) fn interval_to_duration(interval: &Interval, fixed: bool) -> PolarsResult<Duration> {
1217 if interval.last_field.is_some()
1218 || interval.leading_field.is_some()
1219 || interval.leading_precision.is_some()
1220 || interval.fractional_seconds_precision.is_some()
1221 {
1222 polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval)
1223 }
1224 let s = match &*interval.value {
1225 SQLExpr::UnaryOp { .. } => {
1226 polars_bail!(SQLSyntax: "unary ops are not valid on interval strings; found {}", interval.value)
1227 },
1228 SQLExpr::Value(ValueWithSpan {
1229 value: SQLValue::SingleQuotedString(s),
1230 ..
1231 }) => Some(s),
1232 _ => None,
1233 };
1234 match s {
1235 Some(s) if s.contains('-') => {
1236 polars_bail!(SQLInterface: "minus signs are not yet supported in interval strings; found '{}'", s)
1237 },
1238 Some(s) => {
1239 let duration = Duration::parse_interval(s);
1242 if fixed && duration.months() != 0 {
1243 polars_bail!(SQLSyntax: "fixed-duration interval cannot contain years, quarters, or months; found {}", s)
1244 };
1245 Ok(duration)
1246 },
1247 None => polars_bail!(SQLSyntax: "invalid interval {:?}", interval),
1248 }
1249}
1250
1251pub(crate) fn parse_sql_expr(
1252 expr: &SQLExpr,
1253 ctx: &mut SQLContext,
1254 active_schema: Option<&Schema>,
1255) -> PolarsResult<Expr> {
1256 let mut visitor = SQLExprVisitor { ctx, active_schema };
1257 visitor.visit_expr(expr)
1258}
1259
1260pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
1261 match expr {
1262 SQLExpr::Array(arr) => {
1263 let mut visitor = SQLExprVisitor {
1264 ctx,
1265 active_schema: None,
1266 };
1267 visitor.array_expr_to_series(arr.elem.as_slice())
1268 },
1269 _ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr),
1270 }
1271}
1272
1273pub(crate) fn parse_extract_date_part(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
1274 let field = match field {
1275 DateTimeField::Custom(Ident { value, .. }) => {
1277 let value = value.to_ascii_lowercase();
1278 match value.as_str() {
1279 "millennium" | "millennia" => &DateTimeField::Millennium,
1280 "century" | "centuries" => &DateTimeField::Century,
1281 "decade" | "decades" => &DateTimeField::Decade,
1282 "isoyear" => &DateTimeField::Isoyear,
1283 "year" | "years" | "y" => &DateTimeField::Year,
1284 "quarter" | "quarters" => &DateTimeField::Quarter,
1285 "month" | "months" | "mon" | "mons" => &DateTimeField::Month,
1286 "dayofyear" | "doy" => &DateTimeField::DayOfYear,
1287 "dayofweek" | "dow" => &DateTimeField::DayOfWeek,
1288 "isoweek" | "week" | "weeks" => &DateTimeField::IsoWeek,
1289 "isodow" => &DateTimeField::Isodow,
1290 "day" | "days" | "d" => &DateTimeField::Day,
1291 "hour" | "hours" | "h" => &DateTimeField::Hour,
1292 "minute" | "minutes" | "mins" | "min" | "m" => &DateTimeField::Minute,
1293 "second" | "seconds" | "sec" | "secs" | "s" => &DateTimeField::Second,
1294 "millisecond" | "milliseconds" | "ms" => &DateTimeField::Millisecond,
1295 "microsecond" | "microseconds" | "us" => &DateTimeField::Microsecond,
1296 "nanosecond" | "nanoseconds" | "ns" => &DateTimeField::Nanosecond,
1297 #[cfg(feature = "timezones")]
1298 "timezone" => &DateTimeField::Timezone,
1299 "time" => &DateTimeField::Time,
1300 "epoch" => &DateTimeField::Epoch,
1301 _ => {
1302 polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", value)
1303 },
1304 }
1305 },
1306 _ => field,
1307 };
1308 Ok(match field {
1309 DateTimeField::Millennium => expr.dt().millennium(),
1310 DateTimeField::Century => expr.dt().century(),
1311 DateTimeField::Decade => expr.dt().year() / typed_lit(10i32),
1312 DateTimeField::Isoyear => expr.dt().iso_year(),
1313 DateTimeField::Year | DateTimeField::Years => expr.dt().year(),
1314 DateTimeField::Quarter => expr.dt().quarter(),
1315 DateTimeField::Month | DateTimeField::Months => expr.dt().month(),
1316 DateTimeField::Week(weekday) => {
1317 if weekday.is_some() {
1318 polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1319 }
1320 expr.dt().week()
1321 },
1322 DateTimeField::IsoWeek | DateTimeField::Weeks => expr.dt().week(),
1323 DateTimeField::DayOfYear | DateTimeField::Doy => expr.dt().ordinal_day(),
1324 DateTimeField::DayOfWeek | DateTimeField::Dow => {
1325 let w = expr.dt().weekday();
1326 when(w.clone().eq(typed_lit(7i8)))
1327 .then(typed_lit(0i8))
1328 .otherwise(w)
1329 },
1330 DateTimeField::Isodow => expr.dt().weekday(),
1331 DateTimeField::Day | DateTimeField::Days => expr.dt().day(),
1332 DateTimeField::Hour | DateTimeField::Hours => expr.dt().hour(),
1333 DateTimeField::Minute | DateTimeField::Minutes => expr.dt().minute(),
1334 DateTimeField::Second | DateTimeField::Seconds => expr.dt().second(),
1335 DateTimeField::Millisecond | DateTimeField::Milliseconds => {
1336 (expr.clone().dt().second() * typed_lit(1_000f64))
1337 + expr.dt().nanosecond().div(typed_lit(1_000_000f64))
1338 },
1339 DateTimeField::Microsecond | DateTimeField::Microseconds => {
1340 (expr.clone().dt().second() * typed_lit(1_000_000f64))
1341 + expr.dt().nanosecond().div(typed_lit(1_000f64))
1342 },
1343 DateTimeField::Nanosecond | DateTimeField::Nanoseconds => {
1344 (expr.clone().dt().second() * typed_lit(1_000_000_000f64)) + expr.dt().nanosecond()
1345 },
1346 DateTimeField::Time => expr.dt().time(),
1347 #[cfg(feature = "timezones")]
1348 DateTimeField::Timezone => expr.dt().base_utc_offset().dt().total_seconds(false),
1349 DateTimeField::Epoch => {
1350 expr.clone()
1351 .dt()
1352 .timestamp(TimeUnit::Nanoseconds)
1353 .div(typed_lit(1_000_000_000i64))
1354 + expr.dt().nanosecond().div(typed_lit(1_000_000_000f64))
1355 },
1356 _ => {
1357 polars_bail!(SQLSyntax: "EXTRACT/DATE_PART does not support '{}' part", field)
1358 },
1359 })
1360}
1361
1362pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr {
1365 match idx {
1366 Expr::Literal(sc) if sc.is_null() => lit(LiteralValue::untyped_null()),
1367 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => {
1368 if null_if_zero {
1369 lit(LiteralValue::untyped_null())
1370 } else {
1371 idx
1372 }
1373 },
1374 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n < 0 => idx,
1375 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => lit(n - 1),
1376 _ => when(idx.clone().gt(lit(0)))
1379 .then(idx.clone() - lit(1))
1380 .otherwise(if null_if_zero {
1381 when(idx.clone().eq(lit(0)))
1382 .then(lit(LiteralValue::untyped_null()))
1383 .otherwise(idx.clone())
1384 } else {
1385 idx.clone()
1386 }),
1387 }
1388}
1389
1390fn resolve_column<'a>(
1391 ctx: &'a mut SQLContext,
1392 ident_root: &'a Ident,
1393 name: &'a str,
1394 dtype: &'a DataType,
1395) -> PolarsResult<(Expr, Option<&'a DataType>)> {
1396 let resolved = ctx.resolve_name(&ident_root.value, name);
1397 let resolved = resolved.as_str();
1398 Ok((
1399 if name != resolved {
1400 col(resolved).alias(name)
1401 } else {
1402 col(name)
1403 },
1404 Some(dtype),
1405 ))
1406}
1407
1408pub(crate) fn resolve_compound_identifier(
1409 ctx: &mut SQLContext,
1410 idents: &[Ident],
1411 active_schema: Option<&Schema>,
1412) -> PolarsResult<Vec<Expr>> {
1413 let ident_root = &idents[0];
1415 let mut remaining_idents = idents.iter().skip(1);
1416 let mut lf = ctx.get_table_from_current_scope(&ident_root.value);
1417
1418 let schema = if let Some(ref mut lf) = lf {
1420 lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena)?
1421 } else {
1422 Arc::new(active_schema.cloned().unwrap_or_default())
1423 };
1424
1425 if lf.is_none() && schema.is_empty() {
1427 let (mut column, mut dtype): (Expr, Option<&DataType>) =
1428 (col(ident_root.value.as_str()), None);
1429
1430 for ident in remaining_idents {
1432 let name = ident.value.as_str();
1433 match dtype {
1434 Some(DataType::Struct(fields)) if name == "*" => {
1435 return Ok(fields
1436 .iter()
1437 .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1438 .collect());
1439 },
1440 Some(DataType::Struct(fields)) => {
1441 dtype = fields
1442 .iter()
1443 .find(|fld| fld.name == name)
1444 .map(|fld| &fld.dtype);
1445 },
1446 Some(dtype) if name == "*" => {
1447 polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1448 },
1449 _ => dtype = None,
1450 }
1451 column = column.struct_().field_by_name(name);
1452 }
1453 return Ok(vec![column]);
1454 }
1455
1456 let name = &remaining_idents.next().unwrap().value;
1457
1458 if lf.is_some() && name == "*" {
1460 return schema
1461 .iter_names_and_dtypes()
1462 .map(|(name, dtype)| resolve_column(ctx, ident_root, name, dtype).map(|(expr, _)| expr))
1463 .collect();
1464 }
1465
1466 let col_dtype: PolarsResult<(Expr, Option<&DataType>)> =
1468 match (lf.is_none(), schema.get(&ident_root.value)) {
1469 (true, Some(dtype)) => {
1471 remaining_idents = idents.iter().skip(1);
1472 Ok((col(ident_root.value.as_str()), Some(dtype)))
1473 },
1474 (true, None) => {
1476 polars_bail!(
1477 SQLInterface: "no table or struct column named '{}' found",
1478 ident_root
1479 )
1480 },
1481 (false, _) => {
1483 if let Some((_, col_name, dtype)) = schema.get_full(name) {
1484 resolve_column(ctx, ident_root, col_name, dtype)
1485 } else {
1486 polars_bail!(
1487 SQLInterface: "no column named '{}' found in table '{}'",
1488 name, ident_root
1489 )
1490 }
1491 },
1492 };
1493
1494 let (mut column, mut dtype) = col_dtype?;
1496 for ident in remaining_idents {
1497 let name = ident.value.as_str();
1498 match dtype {
1499 Some(DataType::Struct(fields)) if name == "*" => {
1500 return Ok(fields
1501 .iter()
1502 .map(|fld| column.clone().struct_().field_by_name(&fld.name))
1503 .collect());
1504 },
1505 Some(DataType::Struct(fields)) => {
1506 dtype = fields
1507 .iter()
1508 .find(|fld| fld.name == name)
1509 .map(|fld| &fld.dtype);
1510 },
1511 Some(dtype) if name == "*" => {
1512 polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype)
1513 },
1514 _ => {
1515 dtype = None;
1516 },
1517 }
1518 column = column.struct_().field_by_name(name);
1519 }
1520 Ok(vec![column])
1521}