1use std::ops::{Add, Sub};
2
3use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions};
4use polars_core::prelude::{
5 DataType, ExplodeOptions, PolarsResult, QuantileMethod, Schema, TimeUnit, polars_bail,
6 polars_err,
7};
8use polars_lazy::dsl::Expr;
9#[cfg(feature = "rank")]
10use polars_lazy::prelude::{RankMethod, RankOptions};
11use polars_ops::chunked_array::UnicodeForm;
12use polars_ops::series::RoundMode;
13use polars_plan::dsl::{
14 as_struct, coalesce, concat_str, element, int_range, len, max_horizontal, min_horizontal, when,
15};
16use polars_plan::plans::{DynLiteralValue, LiteralValue, typed_lit};
17use polars_plan::prelude::{StrptimeOptions, col, cols, lit};
18use polars_utils::pl_str::PlSmallStr;
19use sqlparser::ast::helpers::attached_token::AttachedToken;
20use sqlparser::ast::{
21 DateTimeField, DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg,
22 FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, Ident,
23 OrderByExpr, Value as SQLValue, ValueWithSpan, WindowFrame, WindowFrameBound, WindowFrameUnits,
24 WindowSpec, WindowType,
25};
26use sqlparser::tokenizer::Span;
27
28use crate::SQLContext;
29use crate::sql_expr::{adjust_one_indexed_param, parse_extract_date_part, parse_sql_expr};
30
31pub(crate) struct SQLFunctionVisitor<'a> {
32 pub(crate) func: &'a SQLFunction,
33 pub(crate) ctx: &'a mut SQLContext,
34 pub(crate) active_schema: Option<&'a Schema>,
35}
36
/// Every SQL function the polars SQL interface can translate, plus registered UDFs.
///
/// SQL names (and aliases) are resolved to these variants case-insensitively.
pub(crate) enum PolarsSQLFunctions {
    // ---- Bitwise ----
    /// `BIT_AND(a, b)` (alias `BITAND`).
    BitAnd,
    /// `BIT_COUNT(x)` (alias `BITCOUNT`); only available with the `bitwise` feature.
    #[cfg(feature = "bitwise")]
    BitCount,
    /// `BIT_NOT(x)` (alias `BITNOT`).
    BitNot,
    /// `BIT_OR(a, b)` (alias `BITOR`).
    BitOr,
    /// `BIT_XOR(a, b)` (aliases `BITXOR`, `XOR`).
    BitXor,

    // ---- Math ----
    /// `ABS(x)`.
    Abs,
    /// `CEIL(x)` / `CEILING(x)`.
    Ceil,
    /// `DIV(a, b)`: floor division returning Int64.
    Div,
    /// `EXP(x)`.
    Exp,
    /// `FLOOR(x)`.
    Floor,
    /// `PI()`.
    Pi,
    /// `LN(x)`: natural logarithm.
    Ln,
    /// `LOG2(x)`.
    Log2,
    /// `LOG10(x)`.
    Log10,
    /// `LOG(x, base)`.
    Log,
    /// `LOG1P(x)`.
    Log1p,
    /// `POW(x, y)` / `POWER(x, y)`.
    Pow,
    /// `MOD(a, b)`.
    Mod,
    /// `SQRT(x)`.
    Sqrt,
    /// `CBRT(x)`.
    Cbrt,
    /// `ROUND(x[, decimals])`.
    Round,
    /// `SIGN(x)`.
    Sign,

    // ---- Trigonometry ----
    /// `COS(x)` (radians).
    Cos,
    /// `COT(x)` (radians).
    Cot,
    /// `SIN(x)` (radians).
    Sin,
    /// `TAN(x)` (radians).
    Tan,
    /// `COSD(x)` (degrees).
    CosD,
    /// `COTD(x)` (degrees).
    CotD,
    /// `SIND(x)` (degrees).
    SinD,
    /// `TAND(x)` (degrees).
    TanD,
    /// `ACOS(x)` (result in radians).
    Acos,
    /// `ASIN(x)` (result in radians).
    Asin,
    /// `ATAN(x)` (result in radians).
    Atan,
    /// `ATAN2(y, x)` (result in radians).
    Atan2,
    /// `ACOSD(x)` (result in degrees).
    AcosD,
    /// `ASIND(x)` (result in degrees).
    AsinD,
    /// `ATAND(x)` (result in degrees).
    AtanD,
    /// `ATAN2D(y, x)` (result in degrees).
    Atan2D,
    /// `DEGREES(x)`: radians to degrees.
    Degrees,
    /// `RADIANS(x)`: degrees to radians.
    Radians,

    // ---- Temporal ----
    /// `DATE_PART(part, x)`.
    DatePart,
    /// `STRFTIME(x, format)`.
    Strftime,

    // ---- String ----
    /// `BIT_LENGTH(s)`.
    BitLength,
    /// `CONCAT(a, ...)`.
    Concat,
    /// `CONCAT_WS(sep, a, ...)`.
    ConcatWS,
    /// `DATE(s[, format])`.
    Date,
    /// `ENDS_WITH(s, suffix)`.
    EndsWith,
    /// `INITCAP(s)`; only available with the `nightly` feature.
    #[cfg(feature = "nightly")]
    InitCap,
    /// `LEFT(s, n)`.
    Left,
    /// `LENGTH(s)` / `CHAR_LENGTH(s)` / `CHARACTER_LENGTH(s)`.
    Length,
    /// `LOWER(s)`.
    Lower,
    /// `LTRIM(s[, chars])`.
    LTrim,
    /// `NORMALIZE(s[, form])`.
    Normalize,
    /// `OCTET_LENGTH(s)`.
    OctetLength,
    /// `REGEXP_LIKE(s, pattern[, flags])`.
    RegexpLike,
    /// `REPLACE(s, old, new)`.
    Replace,
    /// `REVERSE(s)`.
    Reverse,
    /// `RIGHT(s, n)`.
    Right,
    /// `RTRIM(s[, chars])`.
    RTrim,
    /// `SPLIT_PART(s, sep, idx)`.
    SplitPart,
    /// `STARTS_WITH(s, prefix)`.
    StartsWith,
    /// `STRPOS(s, substring)`.
    StrPos,
    /// `SUBSTR(s, start[, length])`.
    Substring,
    /// `STRING_TO_ARRAY(s, sep)`.
    StringToArray,
    /// `STRPTIME(s, format)`.
    Strptime,
    /// `TIME(s[, format])`.
    Time,
    /// `TIMESTAMP(s[, format])` / `DATETIME(s[, format])`.
    Timestamp,
    /// `UPPER(s)`.
    Upper,

    // ---- Conditional ----
    /// `COALESCE(a, ...)`.
    Coalesce,
    /// `GREATEST(a, ...)`.
    Greatest,
    /// `IF(cond, then, otherwise)`.
    If,
    /// `IFNULL(a, b)`.
    IfNull,
    /// `LEAST(a, ...)`.
    Least,
    /// `NULLIF(a, b)`.
    NullIf,

    // ---- Aggregate ----
    /// `AVG(x)`.
    Avg,
    /// `CORR(a, b)`.
    Corr,
    /// `COUNT(...)`.
    Count,
    /// `COVAR_POP(a, b)`.
    CovarPop,
    /// `COVAR(a, b)` / `COVAR_SAMP(a, b)`.
    CovarSamp,
    /// `FIRST(x)`.
    First,
    /// `LAST(x)`.
    Last,
    /// `MAX(x)`.
    Max,
    /// `MEDIAN(x)`.
    Median,
    /// `QUANTILE_CONT(x, q)`.
    QuantileCont,
    /// `QUANTILE_DISC(x, q)`.
    QuantileDisc,
    /// `MIN(x)`.
    Min,
    /// `STDDEV(x)` and aliases (`STDEV`, `STDEV_SAMP`, `STDDEV_SAMP`).
    StdDev,
    /// `SUM(x)`.
    Sum,
    /// `VARIANCE(x)` and aliases (`VAR`, `VAR_SAMP`).
    Variance,

    // ---- Array ----
    /// `ARRAY_LENGTH(arr)`.
    ArrayLength,
    /// `ARRAY_LOWER(arr)`.
    ArrayMin,
    /// `ARRAY_UPPER(arr)`.
    ArrayMax,
    /// `ARRAY_SUM(arr)`.
    ArraySum,
    /// `ARRAY_MEAN(arr)`.
    ArrayMean,
    /// `ARRAY_REVERSE(arr)`.
    ArrayReverse,
    /// `ARRAY_UNIQUE(arr)`.
    ArrayUnique,
    /// `UNNEST(arr)`.
    Explode,
    /// `ARRAY_AGG(x ...)`.
    ArrayAgg,
    /// `ARRAY_TO_STRING(arr, sep ...)`.
    ArrayToString,
    /// `ARRAY_GET(arr, idx)`.
    ArrayGet,
    /// `ARRAY_CONTAINS(arr, value)`.
    ArrayContains,

    // ---- Window ----
    /// `FIRST_VALUE(x)`.
    FirstValue,
    /// `LAST_VALUE(x)`.
    LastValue,
    /// `LAG(x[, offset]) OVER (...)`.
    Lag,
    /// `LEAD(x[, offset]) OVER (...)`.
    Lead,
    /// `ROW_NUMBER() OVER (...)`.
    RowNumber,
    /// `RANK() OVER (ORDER BY ...)`; only available with the `rank` feature.
    #[cfg(feature = "rank")]
    Rank,
    /// `DENSE_RANK() OVER (ORDER BY ...)`; only available with the `rank` feature.
    #[cfg(feature = "rank")]
    DenseRank,

    // ---- Column selection ----
    /// `COLUMNS('<regex>')`.
    Columns,

    // ---- User-defined ----
    /// A function registered in the context's function registry (lowercased name).
    Udf(String),
}
768
769impl PolarsSQLFunctions {
770 pub(crate) fn keywords() -> &'static [&'static str] {
771 &[
772 "abs",
773 "acos",
774 "acosd",
775 "array_contains",
776 "array_get",
777 "array_length",
778 "array_lower",
779 "array_mean",
780 "array_reverse",
781 "array_sum",
782 "array_to_string",
783 "array_unique",
784 "array_upper",
785 "asin",
786 "asind",
787 "atan",
788 "atan2",
789 "atan2d",
790 "atand",
791 "avg",
792 "bit_and",
793 "bit_count",
794 "bit_length",
795 "bit_or",
796 "bit_xor",
797 "cbrt",
798 "ceil",
799 "ceiling",
800 "char_length",
801 "character_length",
802 "coalesce",
803 "columns",
804 "concat",
805 "concat_ws",
806 "corr",
807 "cos",
808 "cosd",
809 "cot",
810 "cotd",
811 "count",
812 "covar",
813 "covar_pop",
814 "covar_samp",
815 "date",
816 "date_part",
817 "degrees",
818 "dense_rank",
819 "ends_with",
820 "exp",
821 "first",
822 "first_value",
823 "floor",
824 "greatest",
825 "if",
826 "ifnull",
827 "initcap",
828 "lag",
829 "last",
830 "last_value",
831 "lead",
832 "least",
833 "left",
834 "length",
835 "ln",
836 "log",
837 "log10",
838 "log1p",
839 "log2",
840 "lower",
841 "ltrim",
842 "max",
843 "median",
844 "quantile_disc",
845 "min",
846 "mod",
847 "nullif",
848 "octet_length",
849 "pi",
850 "pow",
851 "power",
852 "quantile_cont",
853 "quantile_disc",
854 "radians",
855 "rank",
856 "regexp_like",
857 "replace",
858 "reverse",
859 "right",
860 "round",
861 "row_number",
862 "rtrim",
863 "sign",
864 "sin",
865 "sind",
866 "sqrt",
867 "starts_with",
868 "stddev",
869 "stddev_samp",
870 "stdev",
871 "stdev_samp",
872 "strftime",
873 "strpos",
874 "strptime",
875 "substr",
876 "sum",
877 "tan",
878 "tand",
879 "unnest",
880 "upper",
881 "var",
882 "var_samp",
883 "variance",
884 ]
885 }
886}
887
888impl PolarsSQLFunctions {
889 fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult<Self> {
890 let function_name = function.name.0[0].as_ident().unwrap().value.to_lowercase();
891 Ok(match function_name.as_str() {
892 "bit_and" | "bitand" => Self::BitAnd,
896 #[cfg(feature = "bitwise")]
897 "bit_count" | "bitcount" => Self::BitCount,
898 "bit_not" | "bitnot" => Self::BitNot,
899 "bit_or" | "bitor" => Self::BitOr,
900 "bit_xor" | "bitxor" | "xor" => Self::BitXor,
901
902 "abs" => Self::Abs,
906 "cbrt" => Self::Cbrt,
907 "ceil" | "ceiling" => Self::Ceil,
908 "div" => Self::Div,
909 "exp" => Self::Exp,
910 "floor" => Self::Floor,
911 "ln" => Self::Ln,
912 "log" => Self::Log,
913 "log10" => Self::Log10,
914 "log1p" => Self::Log1p,
915 "log2" => Self::Log2,
916 "mod" => Self::Mod,
917 "pi" => Self::Pi,
918 "pow" | "power" => Self::Pow,
919 "round" => Self::Round,
920 "sign" => Self::Sign,
921 "sqrt" => Self::Sqrt,
922
923 "cos" => Self::Cos,
927 "cot" => Self::Cot,
928 "sin" => Self::Sin,
929 "tan" => Self::Tan,
930 "cosd" => Self::CosD,
931 "cotd" => Self::CotD,
932 "sind" => Self::SinD,
933 "tand" => Self::TanD,
934 "acos" => Self::Acos,
935 "asin" => Self::Asin,
936 "atan" => Self::Atan,
937 "atan2" => Self::Atan2,
938 "acosd" => Self::AcosD,
939 "asind" => Self::AsinD,
940 "atand" => Self::AtanD,
941 "atan2d" => Self::Atan2D,
942 "degrees" => Self::Degrees,
943 "radians" => Self::Radians,
944
945 "coalesce" => Self::Coalesce,
949 "greatest" => Self::Greatest,
950 "if" => Self::If,
951 "ifnull" => Self::IfNull,
952 "least" => Self::Least,
953 "nullif" => Self::NullIf,
954
955 "date_part" => Self::DatePart,
959 "strftime" => Self::Strftime,
960
961 "bit_length" => Self::BitLength,
965 "concat" => Self::Concat,
966 "concat_ws" => Self::ConcatWS,
967 "date" => Self::Date,
968 "timestamp" | "datetime" => Self::Timestamp,
969 "ends_with" => Self::EndsWith,
970 #[cfg(feature = "nightly")]
971 "initcap" => Self::InitCap,
972 "length" | "char_length" | "character_length" => Self::Length,
973 "left" => Self::Left,
974 "lower" => Self::Lower,
975 "ltrim" => Self::LTrim,
976 "normalize" => Self::Normalize,
977 "octet_length" => Self::OctetLength,
978 "strpos" => Self::StrPos,
979 "regexp_like" => Self::RegexpLike,
980 "replace" => Self::Replace,
981 "reverse" => Self::Reverse,
982 "right" => Self::Right,
983 "rtrim" => Self::RTrim,
984 "split_part" => Self::SplitPart,
985 "starts_with" => Self::StartsWith,
986 "string_to_array" => Self::StringToArray,
987 "strptime" => Self::Strptime,
988 "substr" => Self::Substring,
989 "time" => Self::Time,
990 "upper" => Self::Upper,
991
992 "avg" => Self::Avg,
996 "corr" => Self::Corr,
997 "count" => Self::Count,
998 "covar_pop" => Self::CovarPop,
999 "covar" | "covar_samp" => Self::CovarSamp,
1000 "first" => Self::First,
1001 "last" => Self::Last,
1002 "max" => Self::Max,
1003 "median" => Self::Median,
1004 "quantile_cont" => Self::QuantileCont,
1005 "quantile_disc" => Self::QuantileDisc,
1006 "min" => Self::Min,
1007 "stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev,
1008 "sum" => Self::Sum,
1009 "var" | "variance" | "var_samp" => Self::Variance,
1010
1011 "array_agg" => Self::ArrayAgg,
1015 "array_contains" => Self::ArrayContains,
1016 "array_get" => Self::ArrayGet,
1017 "array_length" => Self::ArrayLength,
1018 "array_lower" => Self::ArrayMin,
1019 "array_mean" => Self::ArrayMean,
1020 "array_reverse" => Self::ArrayReverse,
1021 "array_sum" => Self::ArraySum,
1022 "array_to_string" => Self::ArrayToString,
1023 "array_unique" => Self::ArrayUnique,
1024 "array_upper" => Self::ArrayMax,
1025 "unnest" => Self::Explode,
1026
1027 #[cfg(feature = "rank")]
1031 "dense_rank" => Self::DenseRank,
1032 "first_value" => Self::FirstValue,
1033 "last_value" => Self::LastValue,
1034 "lag" => Self::Lag,
1035 "lead" => Self::Lead,
1036 #[cfg(feature = "rank")]
1037 "rank" => Self::Rank,
1038 "row_number" => Self::RowNumber,
1039
1040 "columns" => Self::Columns,
1044
1045 other => {
1046 if ctx.function_registry.contains(other) {
1047 Self::Udf(other.to_string())
1048 } else {
1049 polars_bail!(SQLInterface: "unsupported function '{}'", other);
1050 }
1051 },
1052 })
1053 }
1054}
1055
1056impl SQLFunctionVisitor<'_> {
1057 pub(crate) fn visit_function(&mut self) -> PolarsResult<Expr> {
1058 use PolarsSQLFunctions::*;
1059 use polars_lazy::prelude::Literal;
1060
1061 let function_name = PolarsSQLFunctions::try_from_sql(self.func, self.ctx)?;
1062 let function = self.func;
1063
1064 if !function.within_group.is_empty() {
1066 polars_bail!(SQLInterface: "'WITHIN GROUP' is not currently supported")
1067 }
1068 if function.filter.is_some() {
1069 polars_bail!(SQLInterface: "'FILTER' is not currently supported")
1070 }
1071 if function.null_treatment.is_some() {
1072 polars_bail!(SQLInterface: "'IGNORE|RESPECT NULLS' is not currently supported")
1073 }
1074
1075 let log_with_base =
1076 |e: Expr, base: f64| e.log(LiteralValue::Dyn(DynLiteralValue::Float(base)).lit());
1077
1078 match function_name {
1079 BitAnd => self.visit_binary::<Expr>(Expr::and),
1083 #[cfg(feature = "bitwise")]
1084 BitCount => self.visit_unary(Expr::bitwise_count_ones),
1085 BitNot => self.visit_unary(Expr::not),
1086 BitOr => self.visit_binary::<Expr>(Expr::or),
1087 BitXor => self.visit_binary::<Expr>(Expr::xor),
1088
1089 Abs => self.visit_unary(Expr::abs),
1093 Cbrt => self.visit_unary(Expr::cbrt),
1094 Ceil => self.visit_unary(Expr::ceil),
1095 Div => self.visit_binary(|e, d| e.floor_div(d).cast(DataType::Int64)),
1096 Exp => self.visit_unary(Expr::exp),
1097 Floor => self.visit_unary(Expr::floor),
1098 Ln => self.visit_unary(|e| log_with_base(e, std::f64::consts::E)),
1099 Log => self.visit_binary(Expr::log),
1100 Log10 => self.visit_unary(|e| log_with_base(e, 10.0)),
1101 Log1p => self.visit_unary(Expr::log1p),
1102 Log2 => self.visit_unary(|e| log_with_base(e, 2.0)),
1103 Pi => self.visit_nullary(Expr::pi),
1104 Mod => self.visit_binary(|e1, e2| e1 % e2),
1105 Pow => self.visit_binary::<Expr>(Expr::pow),
1106 Round => {
1107 let args = extract_args(function)?;
1108 match args.len() {
1109 1 => self.visit_unary(|e| e.round(0, RoundMode::default())),
1110 2 => self.try_visit_binary(|e, decimals| {
1111 Ok(e.round(match decimals {
1112 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1113 if n >= 0 { n as u32 } else {
1114 polars_bail!(SQLInterface: "ROUND does not currently support negative decimals value ({})", args[1])
1115 }
1116 },
1117 _ => polars_bail!(SQLSyntax: "invalid value for ROUND decimals ({})", args[1]),
1118 }, RoundMode::default()))
1119 }),
1120 _ => polars_bail!(SQLSyntax: "ROUND expects 1-2 arguments (found {})", args.len()),
1121 }
1122 },
1123 Sign => self.visit_unary(Expr::sign),
1124 Sqrt => self.visit_unary(Expr::sqrt),
1125
1126 Acos => self.visit_unary(Expr::arccos),
1130 AcosD => self.visit_unary(|e| e.arccos().degrees()),
1131 Asin => self.visit_unary(Expr::arcsin),
1132 AsinD => self.visit_unary(|e| e.arcsin().degrees()),
1133 Atan => self.visit_unary(Expr::arctan),
1134 Atan2 => self.visit_binary(Expr::arctan2),
1135 Atan2D => self.visit_binary(|e, s| e.arctan2(s).degrees()),
1136 AtanD => self.visit_unary(|e| e.arctan().degrees()),
1137 Cos => self.visit_unary(Expr::cos),
1138 CosD => self.visit_unary(|e| e.radians().cos()),
1139 Cot => self.visit_unary(Expr::cot),
1140 CotD => self.visit_unary(|e| e.radians().cot()),
1141 Degrees => self.visit_unary(Expr::degrees),
1142 Radians => self.visit_unary(Expr::radians),
1143 Sin => self.visit_unary(Expr::sin),
1144 SinD => self.visit_unary(|e| e.radians().sin()),
1145 Tan => self.visit_unary(Expr::tan),
1146 TanD => self.visit_unary(|e| e.radians().tan()),
1147
1148 Coalesce => self.visit_variadic(coalesce),
1152 Greatest => self.visit_variadic(|exprs: &[Expr]| max_horizontal(exprs).unwrap()),
1153 If => {
1154 let args = extract_args(function)?;
1155 match args.len() {
1156 3 => self.try_visit_ternary(|cond: Expr, expr1: Expr, expr2: Expr| {
1157 Ok(when(cond).then(expr1).otherwise(expr2))
1158 }),
1159 _ => {
1160 polars_bail!(SQLSyntax: "IF expects 3 arguments (found {})", args.len()
1161 )
1162 },
1163 }
1164 },
1165 IfNull => {
1166 let args = extract_args(function)?;
1167 match args.len() {
1168 2 => self.visit_variadic(coalesce),
1169 _ => {
1170 polars_bail!(SQLSyntax: "IFNULL expects 2 arguments (found {})", args.len())
1171 },
1172 }
1173 },
1174 Least => self.visit_variadic(|exprs: &[Expr]| min_horizontal(exprs).unwrap()),
1175 NullIf => {
1176 let args = extract_args(function)?;
1177 match args.len() {
1178 2 => self.visit_binary(|l: Expr, r: Expr| {
1179 when(l.clone().eq(r))
1180 .then(lit(LiteralValue::untyped_null()))
1181 .otherwise(l)
1182 }),
1183 _ => {
1184 polars_bail!(SQLSyntax: "NULLIF expects 2 arguments (found {})", args.len())
1185 },
1186 }
1187 },
1188
1189 DatePart => self.try_visit_binary(|part, e| {
1193 match part {
1194 Expr::Literal(p) if p.extract_str().is_some() => {
1195 let p = p.extract_str().unwrap();
1196 parse_extract_date_part(
1199 e,
1200 &DateTimeField::Custom(Ident {
1201 value: p.to_string(),
1202 quote_style: None,
1203 span: Span::empty(),
1204 }),
1205 )
1206 },
1207 _ => {
1208 polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART ({})", part);
1209 },
1210 }
1211 }),
1212 Strftime => {
1213 let args = extract_args(function)?;
1214 match args.len() {
1215 2 => self.visit_binary(|e, fmt: String| e.dt().strftime(fmt.as_str())),
1216 _ => {
1217 polars_bail!(SQLSyntax: "STRFTIME expects 2 arguments (found {})", args.len())
1218 },
1219 }
1220 },
1221
1222 BitLength => self.visit_unary(|e| e.str().len_bytes() * lit(8)),
1226 Concat => {
1227 let args = extract_args(function)?;
1228 if args.is_empty() {
1229 polars_bail!(SQLSyntax: "CONCAT expects at least 1 argument (found 0)");
1230 } else {
1231 self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, "", true))
1232 }
1233 },
1234 ConcatWS => {
1235 let args = extract_args(function)?;
1236 if args.len() < 2 {
1237 polars_bail!(SQLSyntax: "CONCAT_WS expects at least 2 arguments (found {})", args.len());
1238 } else {
1239 self.try_visit_variadic(|exprs: &[Expr]| {
1240 match &exprs[0] {
1241 Expr::Literal(lv) if lv.extract_str().is_some() => Ok(concat_str(&exprs[1..], lv.extract_str().unwrap(), true)),
1242 _ => polars_bail!(SQLSyntax: "CONCAT_WS 'separator' must be a literal string (found {:?})", exprs[0]),
1243 }
1244 })
1245 }
1246 },
1247 Date => {
1248 let args = extract_args(function)?;
1249 match args.len() {
1250 1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
1251 2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
1252 _ => {
1253 polars_bail!(SQLSyntax: "DATE expects 1-2 arguments (found {})", args.len())
1254 },
1255 }
1256 },
1257 EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
1258 #[cfg(feature = "nightly")]
1259 InitCap => self.visit_unary(|e| e.str().to_titlecase()),
1260 Left => self.try_visit_binary(|e, length| {
1261 Ok(match length {
1262 Expr::Literal(lv) if lv.is_null() => lit(lv),
1263 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => lit(""),
1264 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1265 let len = if n > 0 {
1266 lit(n)
1267 } else {
1268 (e.clone().str().len_chars() + lit(n)).clip_min(lit(0))
1269 };
1270 e.str().slice(lit(0), len)
1271 },
1272 Expr::Literal(v) => {
1273 polars_bail!(SQLSyntax: "invalid 'n_chars' for LEFT ({:?})", v)
1274 },
1275 _ => when(length.clone().gt_eq(lit(0)))
1276 .then(e.clone().str().slice(lit(0), length.clone().abs()))
1277 .otherwise(e.clone().str().slice(
1278 lit(0),
1279 (e.str().len_chars() + length.clone()).clip_min(lit(0)),
1280 )),
1281 })
1282 }),
1283 Length => self.visit_unary(|e| e.str().len_chars()),
1284 Lower => self.visit_unary(|e| e.str().to_lowercase()),
1285 LTrim => {
1286 let args = extract_args(function)?;
1287 match args.len() {
1288 1 => self.visit_unary(|e| {
1289 e.str().strip_chars_start(lit(LiteralValue::untyped_null()))
1290 }),
1291 2 => self.visit_binary(|e, s| e.str().strip_chars_start(s)),
1292 _ => {
1293 polars_bail!(SQLSyntax: "LTRIM expects 1-2 arguments (found {})", args.len())
1294 },
1295 }
1296 },
1297 Normalize => {
1298 let args = extract_args(function)?;
1299 match args.len() {
1300 1 => self.visit_unary(|e| e.str().normalize(UnicodeForm::NFC)),
1301 2 => {
1302 let form = if let FunctionArgExpr::Expr(SQLExpr::Identifier(Ident {
1303 value: s,
1304 quote_style: None,
1305 span: _,
1306 })) = args[1]
1307 {
1308 match s.to_uppercase().as_str() {
1309 "NFC" => UnicodeForm::NFC,
1310 "NFD" => UnicodeForm::NFD,
1311 "NFKC" => UnicodeForm::NFKC,
1312 "NFKD" => UnicodeForm::NFKD,
1313 _ => {
1314 polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", s)
1315 },
1316 }
1317 } else {
1318 polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", args[1])
1319 };
1320 self.try_visit_binary(|e, _form: Expr| Ok(e.str().normalize(form.clone())))
1321 },
1322 _ => {
1323 polars_bail!(SQLSyntax: "NORMALIZE expects 1-2 arguments (found {})", args.len())
1324 },
1325 }
1326 },
1327 OctetLength => self.visit_unary(|e| e.str().len_bytes()),
1328 StrPos => {
1329 self.visit_binary(|expr, substring| {
1331 (expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32))
1332 })
1333 },
1334 RegexpLike => {
1335 let args = extract_args(function)?;
1336 match args.len() {
1337 2 => self.visit_binary(|e, s| e.str().contains(s, true)),
1338 3 => self.try_visit_ternary(|e, pat, flags| {
1339 Ok(e.str().contains(
1340 match (pat, flags) {
1341 (Expr::Literal(s_lv), Expr::Literal(f_lv)) if s_lv.extract_str().is_some() && f_lv.extract_str().is_some() => {
1342 let s = s_lv.extract_str().unwrap();
1343 let f = f_lv.extract_str().unwrap();
1344 if f.is_empty() {
1345 polars_bail!(SQLSyntax: "invalid/empty 'flags' for REGEXP_LIKE ({})", args[2]);
1346 };
1347 lit(format!("(?{f}){s}"))
1348 },
1349 _ => {
1350 polars_bail!(SQLSyntax: "invalid arguments for REGEXP_LIKE ({}, {})", args[1], args[2]);
1351 },
1352 },
1353 true))
1354 }),
1355 _ => polars_bail!(SQLSyntax: "REGEXP_LIKE expects 2-3 arguments (found {})",args.len()),
1356 }
1357 },
1358 Replace => {
1359 let args = extract_args(function)?;
1360 match args.len() {
1361 3 => self
1362 .try_visit_ternary(|e, old, new| Ok(e.str().replace_all(old, new, true))),
1363 _ => {
1364 polars_bail!(SQLSyntax: "REPLACE expects 3 arguments (found {})", args.len())
1365 },
1366 }
1367 },
1368 Reverse => self.visit_unary(|e| e.str().reverse()),
1369 Right => self.try_visit_binary(|e, length| {
1370 Ok(match length {
1371 Expr::Literal(lv) if lv.is_null() => lit(lv),
1372 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => typed_lit(""),
1373 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1374 let n: i64 = n.try_into().unwrap();
1375 let offset = if n < 0 {
1376 lit(n.abs())
1377 } else {
1378 e.clone().str().len_chars().cast(DataType::Int32) - lit(n)
1379 };
1380 e.str().slice(offset, lit(LiteralValue::untyped_null()))
1381 },
1382 Expr::Literal(v) => {
1383 polars_bail!(SQLSyntax: "invalid 'n_chars' for RIGHT ({:?})", v)
1384 },
1385 _ => when(length.clone().lt(lit(0)))
1386 .then(
1387 e.clone()
1388 .str()
1389 .slice(length.clone().abs(), lit(LiteralValue::untyped_null())),
1390 )
1391 .otherwise(e.clone().str().slice(
1392 e.str().len_chars().cast(DataType::Int32) - length.clone(),
1393 lit(LiteralValue::untyped_null()),
1394 )),
1395 })
1396 }),
1397 RTrim => {
1398 let args = extract_args(function)?;
1399 match args.len() {
1400 1 => self.visit_unary(|e| {
1401 e.str().strip_chars_end(lit(LiteralValue::untyped_null()))
1402 }),
1403 2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)),
1404 _ => {
1405 polars_bail!(SQLSyntax: "RTRIM expects 1-2 arguments (found {})", args.len())
1406 },
1407 }
1408 },
1409 SplitPart => {
1410 let args = extract_args(function)?;
1411 match args.len() {
1412 3 => self.try_visit_ternary(|e, sep, idx| {
1413 let idx = adjust_one_indexed_param(idx, true);
1414 Ok(when(e.clone().is_not_null())
1415 .then(
1416 e.clone()
1417 .str()
1418 .split(sep)
1419 .list()
1420 .get(idx, true)
1421 .fill_null(lit("")),
1422 )
1423 .otherwise(e))
1424 }),
1425 _ => {
1426 polars_bail!(SQLSyntax: "SPLIT_PART expects 3 arguments (found {})", args.len())
1427 },
1428 }
1429 },
1430 StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
1431 StringToArray => {
1432 let args = extract_args(function)?;
1433 match args.len() {
1434 2 => self.visit_binary(|e, sep| e.str().split(sep)),
1435 _ => {
1436 polars_bail!(SQLSyntax: "STRING_TO_ARRAY expects 2 arguments (found {})", args.len())
1437 },
1438 }
1439 },
1440 Strptime => {
1441 let args = extract_args(function)?;
1442 match args.len() {
1443 2 => self.visit_binary(|e, fmt: String| {
1444 e.str().strptime(
1445 DataType::Datetime(TimeUnit::Microseconds, None),
1446 StrptimeOptions {
1447 format: Some(fmt.into()),
1448 ..Default::default()
1449 },
1450 lit("latest"),
1451 )
1452 }),
1453 _ => {
1454 polars_bail!(SQLSyntax: "STRPTIME expects 2 arguments (found {})", args.len())
1455 },
1456 }
1457 },
1458 Time => {
1459 let args = extract_args(function)?;
1460 match args.len() {
1461 1 => self.visit_unary(|e| e.str().to_time(StrptimeOptions::default())),
1462 2 => self.visit_binary(|e, fmt| e.str().to_time(fmt)),
1463 _ => {
1464 polars_bail!(SQLSyntax: "TIME expects 1-2 arguments (found {})", args.len())
1465 },
1466 }
1467 },
1468 Timestamp => {
1469 let args = extract_args(function)?;
1470 match args.len() {
1471 1 => self.visit_unary(|e| {
1472 e.str()
1473 .to_datetime(None, None, StrptimeOptions::default(), lit("latest"))
1474 }),
1475 2 => self
1476 .visit_binary(|e, fmt| e.str().to_datetime(None, None, fmt, lit("latest"))),
1477 _ => {
1478 polars_bail!(SQLSyntax: "DATETIME expects 1-2 arguments (found {})", args.len())
1479 },
1480 }
1481 },
1482 Substring => {
1483 let args = extract_args(function)?;
1484 match args.len() {
1485 2 => self.try_visit_binary(|e, start| {
1487 Ok(match start {
1488 Expr::Literal(lv) if lv.is_null() => lit(lv),
1489 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1490 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null())),
1491 Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1492 _ => start.clone() + lit(1),
1493 })
1494 }),
1495 3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| {
1496 Ok(match (start.clone(), length.clone()) {
1497 (Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1498 (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1499 polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", args[2])
1500 },
1501 (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => e.str().slice(lit(n - 1), length),
1502 (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => {
1503 e.str().slice(lit(0), (length + lit(n - 1)).clip_min(lit(0)))
1504 },
1505 (Expr::Literal(_), _) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1506 (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1507 polars_bail!(SQLSyntax: "invalid 'length' for SUBSTR ({})", args[1])
1508 },
1509 _ => {
1510 let adjusted_start = start - lit(1);
1511 when(adjusted_start.clone().lt(lit(0)))
1512 .then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0))))
1513 .otherwise(e.str().slice(adjusted_start, length))
1514 }
1515 })
1516 }),
1517 _ => polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found {})", args.len()),
1518 }
1519 },
1520 Upper => self.visit_unary(|e| e.str().to_uppercase()),
1521
1522 Avg => self.visit_unary(Expr::mean),
1526 Corr => self.visit_binary(polars_lazy::dsl::pearson_corr),
1527 Count => self.visit_count(),
1528 CovarPop => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 0)),
1529 CovarSamp => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 1)),
1530 First => self.visit_unary(Expr::first),
1531 Last => self.visit_unary(Expr::last),
1532 Max => self.visit_unary_with_opt_cumulative(Expr::max, Expr::cum_max),
1533 Median => self.visit_unary(Expr::median),
1534 QuantileCont => {
1535 let args = extract_args(function)?;
1536 match args.len() {
1537 2 => self.try_visit_binary(|e, q| {
1538 let value = match q {
1539 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1540 if (0.0..=1.0).contains(&f) {
1541 Expr::from(f)
1542 } else {
1543 polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
1544 }
1545 },
1546 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1547 if (0..=1).contains(&n) {
1548 Expr::from(n as f64)
1549 } else {
1550 polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
1551 }
1552 },
1553 _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1])
1554 };
1555 Ok(e.quantile(value, QuantileMethod::Linear))
1556 }),
1557 _ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()),
1558 }
1559 },
1560 QuantileDisc => {
1561 let args = extract_args(function)?;
1562 match args.len() {
1563 2 => self.try_visit_binary(|e, q| {
1564 let value = match q {
1565 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1566 if (0.0..=1.0).contains(&f) {
1567 Expr::from(f)
1568 } else {
1569 polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
1570 }
1571 },
1572 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1573 if (0..=1).contains(&n) {
1574 Expr::from(n as f64)
1575 } else {
1576 polars_bail!(SQLSyntax: "QUANTILE_DISC value must be between 0 and 1 ({})", args[1])
1577 }
1578 },
1579 _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_DISC ({})", args[1])
1580 };
1581 Ok(e.quantile(value, QuantileMethod::Equiprobable))
1582 }),
1583 _ => polars_bail!(SQLSyntax: "QUANTILE_DISC expects 2 arguments (found {})", args.len()),
1584 }
1585 },
1586 Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min),
1587 StdDev => self.visit_unary(|e| e.std(1)),
1588 Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum),
1589 Variance => self.visit_unary(|e| e.var(1)),
1590
1591 ArrayAgg => self.visit_arr_agg(),
1595 ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s, true)),
1596 ArrayGet => {
1597 self.visit_binary(|e, idx: Expr| {
1599 let idx = adjust_one_indexed_param(idx, true);
1600 e.list().get(idx, true)
1601 })
1602 },
1603 ArrayLength => self.visit_unary(|e| e.list().len()),
1604 ArrayMax => self.visit_unary(|e| e.list().max()),
1605 ArrayMean => self.visit_unary(|e| e.list().mean()),
1606 ArrayMin => self.visit_unary(|e| e.list().min()),
1607 ArrayReverse => self.visit_unary(|e| e.list().reverse()),
1608 ArraySum => self.visit_unary(|e| e.list().sum()),
1609 ArrayToString => self.visit_arr_to_string(),
1610 ArrayUnique => self.visit_unary(|e| e.list().unique_stable()),
1611 Explode => self.visit_unary(|e| {
1612 e.explode(ExplodeOptions {
1613 empty_as_null: true,
1614 keep_nulls: true,
1615 })
1616 }),
1617
1618 Columns => {
1622 let active_schema = self.active_schema;
1623 self.try_visit_unary(|e: Expr| match e {
1624 Expr::Literal(lv) if lv.extract_str().is_some() => {
1625 let pat = lv.extract_str().unwrap();
1626 if pat == "*" {
1627 polars_bail!(
1628 SQLSyntax: "COLUMNS('*') is not a valid regex; \
1629 did you mean COLUMNS(*)?"
1630 )
1631 };
1632 let pat = match pat {
1633 _ if pat.starts_with('^') && pat.ends_with('$') => pat.to_string(),
1634 _ if pat.starts_with('^') => format!("{pat}.*$"),
1635 _ if pat.ends_with('$') => format!("^.*{pat}"),
1636 _ => format!("^.*{pat}.*$"),
1637 };
1638 if let Some(active_schema) = &active_schema {
1639 let rx = polars_utils::regex_cache::compile_regex(&pat).unwrap();
1640 let col_names = active_schema
1641 .iter_names()
1642 .filter(|name| rx.is_match(name))
1643 .cloned()
1644 .collect::<Vec<_>>();
1645
1646 Ok(if col_names.len() == 1 {
1647 col(col_names.into_iter().next().unwrap())
1648 } else {
1649 cols(col_names).as_expr()
1650 })
1651 } else {
1652 Ok(col(pat.as_str()))
1653 }
1654 },
1655 Expr::Selector(s) => Ok(s.as_expr()),
1656 _ => polars_bail!(SQLSyntax: "COLUMNS expects a regex; found {:?}", e),
1657 })
1658 },
1659
1660 FirstValue => self.visit_unary(Expr::first),
1664 LastValue => {
1665 let args = extract_args(function)?;
1669 match args.as_slice() {
1670 [FunctionArgExpr::Expr(sql_expr)] => {
1671 parse_sql_expr(sql_expr, self.ctx, self.active_schema)
1672 },
1673 _ => polars_bail!(
1674 SQLSyntax: "LAST_VALUE expects exactly 1 argument (found {})",
1675 args.len()
1676 ),
1677 }
1678 },
1679 Lag => self.visit_window_offset_function(1),
1680 Lead => self.visit_window_offset_function(-1),
1681 #[cfg(feature = "rank")]
1682 Rank | DenseRank => {
1683 let (func_name, rank_method) = match function_name {
1684 Rank => ("RANK", RankMethod::Min),
1685 DenseRank => ("DENSE_RANK", RankMethod::Dense),
1686 _ => unreachable!(),
1687 };
1688 let args = extract_args(function)?;
1689 if !args.is_empty() {
1690 polars_bail!(SQLSyntax: "{} expects 0 arguments (found {})", func_name, args.len());
1691 }
1692 let window_spec = match &self.func.over {
1693 Some(WindowType::WindowSpec(spec)) if !spec.order_by.is_empty() => spec,
1694 _ => {
1695 polars_bail!(SQLSyntax: "{} requires an OVER clause with ORDER BY", func_name)
1696 },
1697 };
1698 let (order_exprs, all_desc) =
1699 self.parse_order_by_in_window(&window_spec.order_by)?;
1700 let rank_expr = if order_exprs.len() == 1 {
1701 order_exprs[0].clone().rank(
1702 RankOptions {
1703 method: rank_method,
1704 descending: all_desc,
1705 },
1706 None,
1707 )
1708 } else {
1709 as_struct(order_exprs).rank(
1710 RankOptions {
1711 method: rank_method,
1712 descending: all_desc,
1713 },
1714 None,
1715 )
1716 };
1717 self.apply_window_spec(rank_expr, &self.func.over)
1718 },
1719 RowNumber => {
1720 let args = extract_args(function)?;
1721 if !args.is_empty() {
1722 polars_bail!(SQLSyntax: "ROW_NUMBER expects 0 arguments (found {})", args.len());
1723 }
1724 let row_num_expr = int_range(lit(0i64), len(), 1, DataType::UInt32) + lit(1u32);
1726 self.apply_window_spec(row_num_expr, &self.func.over)
1727 },
1728
1729 Udf(func_name) => self.visit_udf(&func_name),
1733 }
1734 }
1735
1736 fn visit_window_offset_function(&mut self, offset_multiplier: i64) -> PolarsResult<Expr> {
1737 if self.func.over.is_none() {
1739 polars_bail!(SQLSyntax: "{} requires an OVER clause", self.func.name);
1740 }
1741
1742 let window_type = self.func.over.as_ref().unwrap();
1744 let window_spec = self.resolve_window_spec(window_type)?;
1745 if window_spec.order_by.is_empty() {
1746 polars_bail!(SQLSyntax: "{} requires an ORDER BY in the OVER clause", self.func.name);
1747 }
1748
1749 let args = extract_args(self.func)?;
1750
1751 match args.as_slice() {
1752 [FunctionArgExpr::Expr(sql_expr)] => {
1753 let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1754 Ok(expr.shift(offset_multiplier.into()))
1755 },
1756 [FunctionArgExpr::Expr(sql_expr), FunctionArgExpr::Expr(offset_expr)] => {
1757 let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1758 let offset = parse_sql_expr(offset_expr, self.ctx, self.active_schema)?;
1759 if let Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) = offset {
1760 if n <= 0 {
1761 polars_bail!(SQLSyntax: "offset must be positive (found {})", n)
1762 }
1763 Ok(expr.shift((offset_multiplier * n as i64).into()))
1764 } else {
1765 polars_bail!(SQLSyntax: "offset must be an integer (found {:?})", offset)
1766 }
1767 },
1768 _ => polars_bail!(SQLSyntax: "{} expects 1 or 2 arguments (found {})", self.func.name, args.len()),
1769 }.and_then(|e| self.apply_window_spec(e, &self.func.over))
1770 }
1771
1772 fn visit_udf(&mut self, func_name: &str) -> PolarsResult<Expr> {
1773 let args = extract_args(self.func)?
1774 .into_iter()
1775 .map(|arg| {
1776 if let FunctionArgExpr::Expr(e) = arg {
1777 parse_sql_expr(e, self.ctx, self.active_schema)
1778 } else {
1779 polars_bail!(SQLInterface: "only expressions are supported in UDFs")
1780 }
1781 })
1782 .collect::<PolarsResult<Vec<_>>>()?;
1783
1784 Ok(self
1785 .ctx
1786 .function_registry
1787 .get_udf(func_name)?
1788 .ok_or_else(|| polars_err!(SQLInterface: "UDF {} not found", func_name))?
1789 .call(args))
1790 }
1791
1792 fn validate_window_frame(&self, window_frame: &Option<WindowFrame>) -> PolarsResult<()> {
1805 if let Some(frame) = window_frame {
1806 match frame.units {
1807 WindowFrameUnits::Range => {
1808 polars_bail!(
1809 SQLInterface:
1810 "RANGE-based window frames are not supported"
1811 );
1812 },
1813 WindowFrameUnits::Groups => {
1814 polars_bail!(
1815 SQLInterface:
1816 "GROUPS-based window frames are not supported"
1817 );
1818 },
1819 WindowFrameUnits::Rows => {
1820 if !matches!(
1821 (&frame.start_bound, &frame.end_bound),
1822 (
1823 WindowFrameBound::Preceding(None), None | Some(WindowFrameBound::CurrentRow) )
1826 ) {
1827 polars_bail!(
1828 SQLInterface:
1829 "only 'ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW' is currently supported; found 'ROWS BETWEEN {} AND {}'",
1830 frame.start_bound,
1831 frame.end_bound.as_ref().map_or("CURRENT ROW", |b| {
1832 match b {
1833 WindowFrameBound::CurrentRow => "CURRENT ROW",
1834 WindowFrameBound::Preceding(_) => "N PRECEDING",
1835 WindowFrameBound::Following(_) => "N FOLLOWING",
1836 }
1837 })
1838 );
1839 }
1840 },
1841 }
1842 }
1843 Ok(())
1844 }
1845
1846 fn apply_cumulative_window(
1870 &mut self,
1871 f: impl Fn(Expr) -> Expr,
1872 cumulative_fn: impl Fn(Expr, bool) -> Expr,
1873 WindowSpec {
1874 partition_by,
1875 order_by,
1876 window_frame,
1877 ..
1878 }: &WindowSpec,
1879 ) -> PolarsResult<Expr> {
1880 self.validate_window_frame(window_frame)?;
1881
1882 if !order_by.is_empty() {
1883 let (order_by_exprs, all_desc) = self.parse_order_by_in_window(order_by)?;
1885
1886 let args = extract_args(self.func)?;
1888 let base_expr = match args.as_slice() {
1889 [FunctionArgExpr::Expr(sql_expr)] => {
1890 parse_sql_expr(sql_expr, self.ctx, self.active_schema)?
1891 },
1892 _ => return self.not_supported_error(),
1893 };
1894 let partition_by_exprs = if partition_by.is_empty() {
1895 None
1896 } else {
1897 Some(
1898 partition_by
1899 .iter()
1900 .map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
1901 .collect::<PolarsResult<Vec<_>>>()?,
1902 )
1903 };
1904
1905 let cumulative_expr = cumulative_fn(base_expr, false);
1907 let sort_opts = SortOptions::default().with_order_descending(all_desc);
1908 cumulative_expr.over_with_options(
1909 partition_by_exprs,
1910 Some((order_by_exprs, sort_opts)),
1911 Default::default(),
1912 )
1913 } else {
1914 self.visit_unary(f)
1915 }
1916 }
1917
1918 fn visit_unary(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
1919 self.try_visit_unary(|e| Ok(f(e)))
1920 }
1921
1922 fn try_visit_unary(&mut self, f: impl Fn(Expr) -> PolarsResult<Expr>) -> PolarsResult<Expr> {
1923 let args = extract_args(self.func)?;
1924 match args.as_slice() {
1925 [FunctionArgExpr::Expr(sql_expr)] => {
1926 f(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?)
1927 },
1928 [FunctionArgExpr::Wildcard] => f(parse_sql_expr(
1929 &SQLExpr::Wildcard(AttachedToken::empty()),
1930 self.ctx,
1931 self.active_schema,
1932 )?),
1933 _ => self.not_supported_error(),
1934 }
1935 .and_then(|e| self.apply_window_spec(e, &self.func.over))
1936 }
1937
1938 fn resolve_window_spec(&self, window_type: &WindowType) -> PolarsResult<WindowSpec> {
1940 match window_type {
1941 WindowType::WindowSpec(spec) => Ok(spec.clone()),
1942 WindowType::NamedWindow(name) => self
1943 .ctx
1944 .named_windows
1945 .get(&name.value)
1946 .cloned()
1947 .ok_or_else(|| {
1948 polars_err!(
1949 SQLInterface:
1950 "named window '{}' was not found",
1951 name.value
1952 )
1953 }),
1954 }
1955 }
1956
1957 fn visit_unary_with_opt_cumulative(
1960 &mut self,
1961 f: impl Fn(Expr) -> Expr,
1962 cumulative_fn: impl Fn(Expr, bool) -> Expr,
1963 ) -> PolarsResult<Expr> {
1964 match self.func.over.as_ref() {
1965 Some(window_type) => {
1966 let spec = self.resolve_window_spec(window_type)?;
1967 self.apply_cumulative_window(f, cumulative_fn, &spec)
1968 },
1969 None => self.visit_unary(f),
1970 }
1971 }
1972
1973 fn visit_binary<Arg: FromSQLExpr>(
1974 &mut self,
1975 f: impl Fn(Expr, Arg) -> Expr,
1976 ) -> PolarsResult<Expr> {
1977 self.try_visit_binary(|e, a| Ok(f(e, a)))
1978 }
1979
1980 fn try_visit_binary<Arg: FromSQLExpr>(
1981 &mut self,
1982 f: impl Fn(Expr, Arg) -> PolarsResult<Expr>,
1983 ) -> PolarsResult<Expr> {
1984 let args = extract_args(self.func)?;
1985 match args.as_slice() {
1986 [
1987 FunctionArgExpr::Expr(sql_expr1),
1988 FunctionArgExpr::Expr(sql_expr2),
1989 ] => {
1990 let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?;
1991 let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
1992 f(expr1, expr2)
1993 },
1994 _ => self.not_supported_error(),
1995 }
1996 }
1997
1998 fn visit_variadic(&mut self, f: impl Fn(&[Expr]) -> Expr) -> PolarsResult<Expr> {
1999 self.try_visit_variadic(|e| Ok(f(e)))
2000 }
2001
2002 fn try_visit_variadic(
2003 &mut self,
2004 f: impl Fn(&[Expr]) -> PolarsResult<Expr>,
2005 ) -> PolarsResult<Expr> {
2006 let args = extract_args(self.func)?;
2007 let mut expr_args = vec![];
2008 for arg in args {
2009 if let FunctionArgExpr::Expr(sql_expr) = arg {
2010 expr_args.push(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?);
2011 } else {
2012 return self.not_supported_error();
2013 };
2014 }
2015 f(&expr_args)
2016 }
2017
2018 fn try_visit_ternary<Arg: FromSQLExpr>(
2019 &mut self,
2020 f: impl Fn(Expr, Arg, Arg) -> PolarsResult<Expr>,
2021 ) -> PolarsResult<Expr> {
2022 let args = extract_args(self.func)?;
2023 match args.as_slice() {
2024 [
2025 FunctionArgExpr::Expr(sql_expr1),
2026 FunctionArgExpr::Expr(sql_expr2),
2027 FunctionArgExpr::Expr(sql_expr3),
2028 ] => {
2029 let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?;
2030 let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
2031 let expr3 = Arg::from_sql_expr(sql_expr3, self.ctx)?;
2032 f(expr1, expr2, expr3)
2033 },
2034 _ => self.not_supported_error(),
2035 }
2036 }
2037
2038 fn visit_nullary(&self, f: impl Fn() -> Expr) -> PolarsResult<Expr> {
2039 let args = extract_args(self.func)?;
2040 if !args.is_empty() {
2041 return self.not_supported_error();
2042 }
2043 Ok(f())
2044 }
2045
2046 fn visit_arr_agg(&mut self) -> PolarsResult<Expr> {
2047 let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?;
2048 match args.as_slice() {
2049 [FunctionArgExpr::Expr(sql_expr)] => {
2050 let mut base = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
2051 let mut order_by_clause = None;
2052 let mut limit_clause = None;
2053 for clause in &clauses {
2054 match clause {
2055 FunctionArgumentClause::OrderBy(order_exprs) => {
2056 order_by_clause = Some(order_exprs.as_slice());
2057 },
2058 FunctionArgumentClause::Limit(limit_expr) => {
2059 limit_clause = Some(limit_expr);
2060 },
2061 _ => {},
2062 }
2063 }
2064 if !is_distinct {
2065 if let Some(order_by) = order_by_clause {
2067 base = self.apply_order_by(base, order_by)?;
2068 }
2069 } else {
2070 base = base.unique_stable();
2072 if let Some(order_by) = order_by_clause {
2073 base = self.apply_order_by_to_distinct_array(base, order_by, sql_expr)?;
2074 }
2075 }
2076 if let Some(limit_expr) = limit_clause {
2077 let limit = parse_sql_expr(limit_expr, self.ctx, self.active_schema)?;
2078 match limit {
2079 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n >= 0 => {
2080 base = base.head(Some(n as usize))
2081 },
2082 _ => {
2083 polars_bail!(SQLSyntax: "LIMIT in ARRAY_AGG must be a positive integer")
2084 },
2085 };
2086 }
2087 Ok(base.implode())
2088 },
2089 _ => {
2090 polars_bail!(SQLSyntax: "ARRAY_AGG must have exactly one argument; found {}", args.len())
2091 },
2092 }
2093 }
2094
2095 fn visit_arr_to_string(&mut self) -> PolarsResult<Expr> {
2096 let args = extract_args(self.func)?;
2097 match args.len() {
2098 2 => self.try_visit_binary(|e, sep| {
2099 Ok(e.cast(DataType::List(Box::from(DataType::String)))
2100 .list()
2101 .join(sep, true))
2102 }),
2103 #[cfg(feature = "list_eval")]
2104 3 => self.try_visit_ternary(|e, sep, null_value| match null_value {
2105 Expr::Literal(lv) if lv.extract_str().is_some() => {
2106 Ok(if lv.extract_str().unwrap().is_empty() {
2107 e.cast(DataType::List(Box::from(DataType::String)))
2108 .list()
2109 .join(sep, true)
2110 } else {
2111 e.cast(DataType::List(Box::from(DataType::String)))
2112 .list()
2113 .eval(element().fill_null(lit(lv.extract_str().unwrap())))
2114 .list()
2115 .join(sep, false)
2116 })
2117 },
2118 _ => {
2119 polars_bail!(SQLSyntax: "invalid null value for ARRAY_TO_STRING ({})", args[2])
2120 },
2121 }),
2122 _ => {
2123 polars_bail!(SQLSyntax: "ARRAY_TO_STRING expects 2-3 arguments (found {})", args.len())
2124 },
2125 }
2126 }
2127
2128 fn visit_count(&mut self) -> PolarsResult<Expr> {
2129 let (args, is_distinct) = extract_args_distinct(self.func)?;
2130
2131 let has_order_by = match &self.func.over {
2133 Some(WindowType::WindowSpec(spec)) => !spec.order_by.is_empty(),
2134 _ => false,
2135 };
2136 if has_order_by && !is_distinct {
2137 if let Some(WindowType::WindowSpec(spec)) = &self.func.over {
2138 self.validate_window_frame(&spec.window_frame)?;
2139
2140 match args.as_slice() {
2141 [FunctionArgExpr::Wildcard] | [] => {
2142 let (order_by_exprs, all_desc) =
2144 self.parse_order_by_in_window(&spec.order_by)?;
2145 let partition_by_exprs = if spec.partition_by.is_empty() {
2146 None
2147 } else {
2148 Some(
2149 spec.partition_by
2150 .iter()
2151 .map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
2152 .collect::<PolarsResult<Vec<_>>>()?,
2153 )
2154 };
2155 let sort_opts = SortOptions::default().with_order_descending(all_desc);
2156 let row_number = int_range(lit(0), len(), 1, DataType::Int64).add(lit(1)); return row_number.over_with_options(
2159 partition_by_exprs,
2160 Some((order_by_exprs, sort_opts)),
2161 Default::default(),
2162 );
2163 },
2164 [FunctionArgExpr::Expr(_)] => {
2165 return self.visit_unary_with_opt_cumulative(
2167 |e| e.count(),
2168 |e, reverse| e.cum_count(reverse),
2169 );
2170 },
2171 _ => {},
2172 }
2173 }
2174 }
2175 let count_expr = match (is_distinct, args.as_slice()) {
2176 (false, [FunctionArgExpr::Wildcard] | []) => len(),
2178 (false, [FunctionArgExpr::Expr(sql_expr)]) => {
2180 let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
2181 expr.count()
2182 },
2183 (true, [FunctionArgExpr::Expr(sql_expr)]) => {
2185 let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
2186 expr.clone().n_unique().sub(expr.null_count().gt(lit(0)))
2187 },
2188 _ => self.not_supported_error()?,
2189 };
2190 self.apply_window_spec(count_expr, &self.func.over)
2191 }
2192
2193 fn apply_order_by(&mut self, expr: Expr, order_by: &[OrderByExpr]) -> PolarsResult<Expr> {
2194 let mut by = Vec::with_capacity(order_by.len());
2195 let mut descending = Vec::with_capacity(order_by.len());
2196 let mut nulls_last = Vec::with_capacity(order_by.len());
2197
2198 for ob in order_by {
2199 let desc_order = !ob.options.asc.unwrap_or(true);
2202 by.push(parse_sql_expr(&ob.expr, self.ctx, self.active_schema)?);
2203 nulls_last.push(!ob.options.nulls_first.unwrap_or(desc_order));
2204 descending.push(desc_order);
2205 }
2206 Ok(expr.sort_by(
2207 by,
2208 SortMultipleOptions::default()
2209 .with_order_descending_multi(descending)
2210 .with_nulls_last_multi(nulls_last)
2211 .with_maintain_order(true),
2212 ))
2213 }
2214
2215 fn apply_order_by_to_distinct_array(
2216 &mut self,
2217 expr: Expr,
2218 order_by: &[OrderByExpr],
2219 base_sql_expr: &SQLExpr,
2220 ) -> PolarsResult<Expr> {
2221 if order_by.len() == 1 && order_by[0].expr == *base_sql_expr {
2223 let desc_order = !order_by[0].options.asc.unwrap_or(true);
2224 let nulls_last = !order_by[0].options.nulls_first.unwrap_or(desc_order);
2225 return Ok(expr.sort(
2226 SortOptions::default()
2227 .with_order_descending(desc_order)
2228 .with_nulls_last(nulls_last)
2229 .with_maintain_order(true),
2230 ));
2231 }
2232 self.apply_order_by(expr, order_by)
2234 }
2235
2236 fn parse_order_by_in_window(
2238 &mut self,
2239 order_by: &[OrderByExpr],
2240 ) -> PolarsResult<(Vec<Expr>, bool)> {
2241 if order_by.is_empty() {
2242 return Ok((Vec::new(), false));
2243 }
2244 let all_ascending = order_by[0].options.asc.unwrap_or(true);
2246 let mut exprs = Vec::with_capacity(order_by.len());
2247 for o in order_by {
2248 if all_ascending != o.options.asc.unwrap_or(true) {
2249 polars_bail!(
2252 SQLSyntax:
2253 "OVER does not (yet) support mixed asc/desc directions for ORDER BY"
2254 )
2255 }
2256 let expr = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?;
2257 exprs.push(expr);
2258 }
2259 Ok((exprs, !all_ascending))
2260 }
2261
2262 fn apply_window_spec(
2263 &mut self,
2264 expr: Expr,
2265 window_type: &Option<WindowType>,
2266 ) -> PolarsResult<Expr> {
2267 let Some(window_type) = window_type else {
2268 return Ok(expr);
2269 };
2270 let window_spec = self.resolve_window_spec(window_type)?;
2271 self.validate_window_frame(&window_spec.window_frame)?;
2272
2273 let partition_by = if window_spec.partition_by.is_empty() {
2274 None
2275 } else {
2276 Some(
2277 window_spec
2278 .partition_by
2279 .iter()
2280 .map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
2281 .collect::<PolarsResult<Vec<_>>>()?,
2282 )
2283 };
2284 let order_by = if window_spec.order_by.is_empty() {
2285 None
2286 } else {
2287 let (order_exprs, all_desc) = self.parse_order_by_in_window(&window_spec.order_by)?;
2288 let sort_opts = SortOptions::default().with_order_descending(all_desc);
2289 Some((order_exprs, sort_opts))
2290 };
2291
2292 Ok(match (partition_by, order_by) {
2294 (None, None) => expr,
2295 (Some(part), None) => expr.over(part),
2296 (part, Some(order)) => expr.over_with_options(part, Some(order), Default::default())?,
2297 })
2298 }
2299
2300 fn not_supported_error(&self) -> PolarsResult<Expr> {
2301 polars_bail!(
2302 SQLInterface:
2303 "no function matches the given name and arguments: `{}`",
2304 self.func.to_string()
2305 );
2306 }
2307}
2308
2309fn extract_args(func: &SQLFunction) -> PolarsResult<Vec<&FunctionArgExpr>> {
2310 let (args, _, _) = _extract_func_args(func, false, false)?;
2311 Ok(args)
2312}
2313
2314fn extract_args_distinct(func: &SQLFunction) -> PolarsResult<(Vec<&FunctionArgExpr>, bool)> {
2315 let (args, is_distinct, _) = _extract_func_args(func, true, false)?;
2316 Ok((args, is_distinct))
2317}
2318
2319fn extract_args_and_clauses(
2320 func: &SQLFunction,
2321) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
2322 _extract_func_args(func, true, true)
2323}
2324
2325fn _extract_func_args(
2326 func: &SQLFunction,
2327 get_distinct: bool,
2328 get_clauses: bool,
2329) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
2330 match &func.args {
2331 FunctionArguments::List(FunctionArgumentList {
2332 args,
2333 duplicate_treatment,
2334 clauses,
2335 }) => {
2336 let is_distinct = matches!(duplicate_treatment, Some(DuplicateTreatment::Distinct));
2337 if !(get_clauses || get_distinct) && is_distinct {
2338 polars_bail!(SQLSyntax: "unexpected use of DISTINCT found in '{}'", func.name)
2339 } else if !get_clauses && !clauses.is_empty() {
2340 polars_bail!(SQLSyntax: "unexpected clause found in '{}' ({})", func.name, clauses[0])
2341 } else {
2342 let unpacked_args = args
2343 .iter()
2344 .map(|arg| match arg {
2345 FunctionArg::Named { arg, .. } => arg,
2346 FunctionArg::ExprNamed { arg, .. } => arg,
2347 FunctionArg::Unnamed(arg) => arg,
2348 })
2349 .collect();
2350 Ok((unpacked_args, is_distinct, clauses.clone()))
2351 }
2352 },
2353 FunctionArguments::Subquery { .. } => {
2354 Err(polars_err!(SQLInterface: "subquery not expected in {}", func.name))
2355 },
2356 FunctionArguments::None => Ok((vec![], false, vec![])),
2357 }
2358}
2359
2360pub(crate) trait FromSQLExpr {
2361 fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
2362 where
2363 Self: Sized;
2364}
2365
2366impl FromSQLExpr for f64 {
2367 fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
2368 where
2369 Self: Sized,
2370 {
2371 match expr {
2372 SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2373 SQLValue::Number(s, _) => s
2374 .parse()
2375 .map_err(|_| polars_err!(SQLInterface: "cannot parse literal {:?}", s)),
2376 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2377 },
2378 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2379 }
2380 }
2381}
2382
2383impl FromSQLExpr for bool {
2384 fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
2385 where
2386 Self: Sized,
2387 {
2388 match expr {
2389 SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2390 SQLValue::Boolean(v) => Ok(*v),
2391 _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", v),
2392 },
2393 _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", expr),
2394 }
2395 }
2396}
2397
2398impl FromSQLExpr for String {
2399 fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2400 where
2401 Self: Sized,
2402 {
2403 match expr {
2404 SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2405 SQLValue::SingleQuotedString(s) => Ok(s.clone()),
2406 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2407 },
2408 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2409 }
2410 }
2411}
2412
2413impl FromSQLExpr for StrptimeOptions {
2414 fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2415 where
2416 Self: Sized,
2417 {
2418 match expr {
2419 SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2420 SQLValue::SingleQuotedString(s) => Ok(StrptimeOptions {
2421 format: Some(PlSmallStr::from_str(s)),
2422 ..StrptimeOptions::default()
2423 }),
2424 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2425 },
2426 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2427 }
2428 }
2429}
2430
2431impl FromSQLExpr for Expr {
2432 fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
2433 where
2434 Self: Sized,
2435 {
2436 parse_sql_expr(expr, ctx, None)
2437 }
2438}