1use std::ops::{Add, Sub};
2
3use polars_core::chunked_array::ops::{FillNullStrategy, SortMultipleOptions, SortOptions};
4use polars_core::prelude::{
5 DataType, ExplodeOptions, PolarsResult, QuantileMethod, Schema, TimeUnit, polars_bail,
6 polars_err,
7};
8use polars_lazy::dsl::Expr;
9#[cfg(feature = "rank")]
10use polars_lazy::prelude::{RankMethod, RankOptions};
11use polars_ops::chunked_array::UnicodeForm;
12use polars_ops::series::RoundMode;
13use polars_plan::dsl::functions::{
14 as_struct, coalesce, col, cols, concat_str, element, int_range, len, lit, max_horizontal,
15 min_horizontal, when,
16};
17use polars_plan::plans::{DynLiteralValue, LiteralValue, typed_lit};
18use polars_plan::prelude::StrptimeOptions;
19use polars_utils::pl_str::PlSmallStr;
20use sqlparser::ast::helpers::attached_token::AttachedToken;
21use sqlparser::ast::{
22 DateTimeField, DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg,
23 FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, Ident,
24 OrderByExpr, Value as SQLValue, ValueWithSpan, WindowFrame, WindowFrameBound, WindowFrameUnits,
25 WindowSpec, WindowType,
26};
27use sqlparser::tokenizer::Span;
28
29use crate::SQLContext;
30use crate::sql_expr::{adjust_one_indexed_param, parse_extract_date_part, parse_sql_expr};
31
/// Translates a single parsed SQL function call (`func`) into a polars `Expr`.
pub(crate) struct SQLFunctionVisitor<'a> {
    /// The parsed SQL function call being translated.
    pub(crate) func: &'a SQLFunction,
    /// SQL context: used to parse argument/filter expressions and to look up
    /// registered user-defined functions.
    pub(crate) ctx: &'a mut SQLContext,
    /// Schema of the frame the expression is resolved against, if known
    /// (e.g. used to expand `COLUMNS('regex')` into concrete column names).
    pub(crate) active_schema: Option<&'a Schema>,
    /// Predicate parsed from a `FILTER (WHERE ...)` clause; set by `visit_function`.
    pub(crate) filter: Option<Expr>,
}
38
/// SQL functions understood by the polars SQL interface.
///
/// SQL names are resolved (case-insensitively) to variants in
/// `PolarsSQLFunctions::try_from_sql` and translated into polars expressions in
/// `SQLFunctionVisitor::visit_function`. Names that are not built-in but are present
/// in the context's function registry resolve to `Udf`.
pub(crate) enum PolarsSQLFunctions {
    // ----
    // Bitwise functions
    // ----
    /// `BIT_AND(a, b)` (alias `BITAND`).
    BitAnd,
    /// `BIT_COUNT(x)` (alias `BITCOUNT`); translated with `bitwise_count_ones`.
    #[cfg(feature = "bitwise")]
    BitCount,
    /// `BIT_NOT(x)` (alias `BITNOT`).
    BitNot,
    /// `BIT_OR(a, b)` (alias `BITOR`).
    BitOr,
    /// `BIT_XOR(a, b)` (aliases `BITXOR`, `XOR`).
    BitXor,

    // ----
    // Math functions
    // ----
    /// `ABS(x)`.
    Abs,
    /// `CEIL(x)` (alias `CEILING`).
    Ceil,
    /// `DIV(a, b)`: floor division, result cast to Int64.
    Div,
    /// `EXP(x)`.
    Exp,
    /// `FLOOR(x)`.
    Floor,
    /// `PI()`.
    Pi,
    /// `LN(x)`: logarithm with base `e`.
    Ln,
    /// `LOG2(x)`.
    Log2,
    /// `LOG10(x)`.
    Log10,
    /// `LOG(x, base)`.
    Log,
    /// `LOG1P(x)`.
    Log1p,
    /// `POW(x, y)` (alias `POWER`).
    Pow,
    /// `MOD(a, b)`: translated with the `%` operator.
    Mod,
    /// `SQRT(x)`.
    Sqrt,
    /// `CBRT(x)`.
    Cbrt,
    /// `ROUND(x[, decimals])`; `decimals` must be a non-negative integer literal.
    Round,
    /// `TRUNC(x[, decimals])` (alias `TRUNCATE`); `decimals` must be a non-negative integer literal.
    Truncate,
    /// `SIGN(x)`.
    Sign,

    // ----
    // Trigonometry functions (suffix `D` = degrees)
    // ----
    /// `COS(x)`, `x` in radians.
    Cos,
    /// `COT(x)`, `x` in radians.
    Cot,
    /// `SIN(x)`, `x` in radians.
    Sin,
    /// `TAN(x)`, `x` in radians.
    Tan,
    /// `COSD(x)`, `x` in degrees.
    CosD,
    /// `COTD(x)`, `x` in degrees.
    CotD,
    /// `SIND(x)`, `x` in degrees.
    SinD,
    /// `TAND(x)`, `x` in degrees.
    TanD,
    /// `ACOS(x)`, result in radians.
    Acos,
    /// `ASIN(x)`, result in radians.
    Asin,
    /// `ATAN(x)`, result in radians.
    Atan,
    /// `ATAN2(y, x)`, result in radians.
    Atan2,
    /// `ACOSD(x)`, result in degrees.
    AcosD,
    /// `ASIND(x)`, result in degrees.
    AsinD,
    /// `ATAND(x)`, result in degrees.
    AtanD,
    /// `ATAN2D(y, x)`, result in degrees.
    Atan2D,
    /// `DEGREES(x)`: radians to degrees.
    Degrees,
    /// `RADIANS(x)`: degrees to radians.
    Radians,

    // ----
    // Temporal functions
    // ----
    /// `DATE_PART(part, expr)`; `part` must be a string literal.
    DatePart,
    /// `STRFTIME(expr, format)`.
    Strftime,

    // ----
    // String functions
    // ----
    /// `BIT_LENGTH(s)`: byte length times 8.
    BitLength,
    /// `CONCAT(a, ...)`: concatenation of one or more values.
    Concat,
    /// `CONCAT_WS(sep, a, ...)`; `sep` must be a string literal.
    ConcatWS,
    /// `DATE(s[, format])`: parse a string as a date.
    Date,
    /// `ENDS_WITH(s, suffix)`.
    EndsWith,
    /// `INITCAP(s)`: title-case the string.
    #[cfg(feature = "nightly")]
    InitCap,
    /// `LEFT(s, n)`: first `n` characters (negative `n` drops trailing characters).
    Left,
    /// `LPAD(s, length[, fill])`.
    LeftPad,
    /// `LTRIM(s[, chars])`.
    LeftTrim,
    /// `LENGTH(s)` (aliases `CHAR_LENGTH`, `CHARACTER_LENGTH`): length in characters.
    Length,
    /// `LOWER(s)`.
    Lower,
    /// `NORMALIZE(s[, form])`: Unicode normalization, `NFC` by default.
    Normalize,
    /// `OCTET_LENGTH(s)`: length in bytes.
    OctetLength,
    /// `REGEXP_LIKE(s, pattern[, flags])`.
    RegexpLike,
    /// `REPLACE(s, old, new)`.
    Replace,
    /// `REVERSE(s)`.
    Reverse,
    /// `RIGHT(s, n)`: last `n` characters (negative `n` drops leading characters).
    Right,
    /// `RPAD(s, length[, fill])`.
    RightPad,
    /// `RTRIM(s[, chars])`.
    RightTrim,
    /// `SPLIT_PART(s, sep, idx)`; `idx` is 1-indexed.
    SplitPart,
    /// `STARTS_WITH(s, prefix)`.
    StartsWith,
    /// `STRPOS(s, substring)`: 1-indexed position, 0 if not found.
    StrPos,
    /// `SUBSTR(s, start[, length])`; `start` is 1-indexed.
    Substring,
    /// `STRING_TO_ARRAY(s, sep)`.
    StringToArray,
    /// `STRPTIME(s, format)`: parse to a microsecond-precision Datetime.
    Strptime,
    /// `TIME(s[, format])`: parse a string as a time.
    Time,
    /// `TIMESTAMP(s[, format])` (alias `DATETIME`): parse a string as a datetime.
    Timestamp,
    /// `UPPER(s)`.
    Upper,

    // ----
    // Conditional functions
    // ----
    /// `COALESCE(a, ...)`.
    Coalesce,
    /// `GREATEST(a, ...)`: horizontal maximum.
    Greatest,
    /// `IF(cond, a, b)`.
    If,
    /// `IFNULL(a, b)`.
    IfNull,
    /// `LEAST(a, ...)`: horizontal minimum.
    Least,
    /// `NULLIF(a, b)`: null when `a = b`, otherwise `a`.
    NullIf,

    // ----
    // Aggregate functions
    // ----
    /// `AVG(x)`.
    Avg,
    /// `CORR(a, b)`: Pearson correlation.
    Corr,
    /// `COUNT(...)`.
    Count,
    /// `COVAR_POP(a, b)`.
    CovarPop,
    /// `COVAR_SAMP(a, b)` (alias `COVAR`).
    CovarSamp,
    /// `FIRST(x)`.
    First,
    /// `LAST(x)`.
    Last,
    /// `MAX(x)`.
    Max,
    /// `MEDIAN(x)`.
    Median,
    /// `QUANTILE_CONT(x, q)`: linear-interpolated quantile.
    QuantileCont,
    /// `QUANTILE_DISC(x, q)`: equiprobable quantile.
    QuantileDisc,
    /// `MIN(x)`.
    Min,
    /// `STDDEV(x)` (aliases `STDEV`, `STDEV_SAMP`, `STDDEV_SAMP`), ddof = 1.
    StdDev,
    /// `STRING_AGG(...)` (aliases `LISTAGG`, `GROUP_CONCAT`).
    StringAgg,
    /// `SUM(x)`.
    Sum,
    /// `VARIANCE(x)` (aliases `VAR`, `VAR_SAMP`), ddof = 1.
    Variance,

    // ----
    // Array functions
    // ----
    /// `ARRAY_LENGTH(arr)`.
    ArrayLength,
    /// `ARRAY_LOWER(arr)`: minimum list element.
    ArrayMin,
    /// `ARRAY_UPPER(arr)`: maximum list element.
    ArrayMax,
    /// `ARRAY_SUM(arr)`.
    ArraySum,
    /// `ARRAY_MEAN(arr)`.
    ArrayMean,
    /// `ARRAY_REVERSE(arr)`.
    ArrayReverse,
    /// `ARRAY_UNIQUE(arr)`: unique elements, order preserved.
    ArrayUnique,
    /// `ARRAY_AGG(...)`.
    ArrayAgg,
    /// `ARRAY_TO_STRING(...)`.
    ArrayToString,
    /// `ARRAY_GET(arr, idx)`; `idx` is 1-indexed.
    ArrayGet,
    /// `ARRAY_CONTAINS(arr, value)`.
    ArrayContains,
    /// `UNNEST(arr)`.
    Explode,

    // ----
    // Window functions
    // ----
    /// `FIRST_VALUE(x)`.
    FirstValue,
    /// `LAST_VALUE(x)`.
    LastValue,
    /// `LAG(x[, offset])`.
    Lag,
    /// `LEAD(x[, offset])`.
    Lead,
    /// `ROW_NUMBER()`: 1-based row number.
    RowNumber,
    /// `RANK()`: requires `OVER (... ORDER BY ...)`.
    #[cfg(feature = "rank")]
    Rank,
    /// `DENSE_RANK()`: requires `OVER (... ORDER BY ...)`.
    #[cfg(feature = "rank")]
    DenseRank,

    // ----
    // Column selection
    // ----
    /// `COLUMNS('regex')`: select columns by name pattern.
    Columns,

    // ----
    // User-defined functions
    // ----
    /// A function found in the context's function registry (lowercased name).
    Udf(String),
}
797
798impl PolarsSQLFunctions {
799 pub(crate) fn keywords() -> &'static [&'static str] {
800 &[
801 "abs",
802 "acos",
803 "acosd",
804 "array_contains",
805 "array_get",
806 "array_length",
807 "array_lower",
808 "array_mean",
809 "array_reverse",
810 "array_sum",
811 "array_to_string",
812 "array_unique",
813 "array_upper",
814 "asin",
815 "asind",
816 "atan",
817 "atan2",
818 "atan2d",
819 "atand",
820 "avg",
821 "bit_and",
822 "bit_count",
823 "bit_length",
824 "bit_or",
825 "bit_xor",
826 "cbrt",
827 "ceil",
828 "ceiling",
829 "char_length",
830 "character_length",
831 "coalesce",
832 "columns",
833 "concat",
834 "concat_ws",
835 "corr",
836 "cos",
837 "cosd",
838 "cot",
839 "cotd",
840 "count",
841 "covar",
842 "covar_pop",
843 "covar_samp",
844 "date",
845 "date_part",
846 "degrees",
847 "dense_rank",
848 "ends_with",
849 "exp",
850 "first",
851 "first_value",
852 "floor",
853 "greatest",
854 "if",
855 "ifnull",
856 "initcap",
857 "lag",
858 "last",
859 "last_value",
860 "lead",
861 "least",
862 "left",
863 "length",
864 "ln",
865 "log",
866 "log10",
867 "log1p",
868 "log2",
869 "lower",
870 "lpad",
871 "ltrim",
872 "max",
873 "median",
874 "quantile_disc",
875 "min",
876 "mod",
877 "nullif",
878 "octet_length",
879 "pi",
880 "pow",
881 "power",
882 "quantile_cont",
883 "quantile_disc",
884 "radians",
885 "rank",
886 "regexp_like",
887 "replace",
888 "reverse",
889 "right",
890 "round",
891 "row_number",
892 "rpad",
893 "rtrim",
894 "sign",
895 "sin",
896 "sind",
897 "sqrt",
898 "starts_with",
899 "stddev",
900 "stddev_samp",
901 "stdev",
902 "stdev_samp",
903 "strftime",
904 "strpos",
905 "strptime",
906 "substr",
907 "sum",
908 "tan",
909 "tand",
910 "unnest",
911 "upper",
912 "var",
913 "var_samp",
914 "variance",
915 ]
916 }
917}
918
919impl PolarsSQLFunctions {
920 fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult<Self> {
921 let function_name = function.name.0[0].as_ident().unwrap().value.to_lowercase();
922 Ok(match function_name.as_str() {
923 "bit_and" | "bitand" => Self::BitAnd,
927 #[cfg(feature = "bitwise")]
928 "bit_count" | "bitcount" => Self::BitCount,
929 "bit_not" | "bitnot" => Self::BitNot,
930 "bit_or" | "bitor" => Self::BitOr,
931 "bit_xor" | "bitxor" | "xor" => Self::BitXor,
932
933 "abs" => Self::Abs,
937 "cbrt" => Self::Cbrt,
938 "ceil" | "ceiling" => Self::Ceil,
939 "div" => Self::Div,
940 "exp" => Self::Exp,
941 "floor" => Self::Floor,
942 "ln" => Self::Ln,
943 "log" => Self::Log,
944 "log10" => Self::Log10,
945 "log1p" => Self::Log1p,
946 "log2" => Self::Log2,
947 "mod" => Self::Mod,
948 "pi" => Self::Pi,
949 "pow" | "power" => Self::Pow,
950 "round" => Self::Round,
951 "trunc" | "truncate" => Self::Truncate,
952 "sign" => Self::Sign,
953 "sqrt" => Self::Sqrt,
954
955 "cos" => Self::Cos,
959 "cot" => Self::Cot,
960 "sin" => Self::Sin,
961 "tan" => Self::Tan,
962 "cosd" => Self::CosD,
963 "cotd" => Self::CotD,
964 "sind" => Self::SinD,
965 "tand" => Self::TanD,
966 "acos" => Self::Acos,
967 "asin" => Self::Asin,
968 "atan" => Self::Atan,
969 "atan2" => Self::Atan2,
970 "acosd" => Self::AcosD,
971 "asind" => Self::AsinD,
972 "atand" => Self::AtanD,
973 "atan2d" => Self::Atan2D,
974 "degrees" => Self::Degrees,
975 "radians" => Self::Radians,
976
977 "coalesce" => Self::Coalesce,
981 "greatest" => Self::Greatest,
982 "if" => Self::If,
983 "ifnull" => Self::IfNull,
984 "least" => Self::Least,
985 "nullif" => Self::NullIf,
986
987 "date" => Self::Date,
991 "date_part" => Self::DatePart,
992 "strftime" => Self::Strftime,
993 "timestamp" | "datetime" => Self::Timestamp,
994
995 "bit_length" => Self::BitLength,
999 "concat" => Self::Concat,
1000 "concat_ws" => Self::ConcatWS,
1001 "ends_with" => Self::EndsWith,
1002 #[cfg(feature = "nightly")]
1003 "initcap" => Self::InitCap,
1004 "left" => Self::Left,
1005 "length" | "char_length" | "character_length" => Self::Length,
1006 "lower" => Self::Lower,
1007 "lpad" => Self::LeftPad,
1008 "ltrim" => Self::LeftTrim,
1009 "normalize" => Self::Normalize,
1010 "octet_length" => Self::OctetLength,
1011 "regexp_like" => Self::RegexpLike,
1012 "replace" => Self::Replace,
1013 "reverse" => Self::Reverse,
1014 "right" => Self::Right,
1015 "rpad" => Self::RightPad,
1016 "rtrim" => Self::RightTrim,
1017 "split_part" => Self::SplitPart,
1018 "starts_with" => Self::StartsWith,
1019 "string_to_array" => Self::StringToArray,
1020 "strpos" => Self::StrPos,
1021 "strptime" => Self::Strptime,
1022 "substr" => Self::Substring,
1023 "time" => Self::Time,
1024 "upper" => Self::Upper,
1025
1026 "avg" => Self::Avg,
1030 "corr" => Self::Corr,
1031 "count" => Self::Count,
1032 "covar_pop" => Self::CovarPop,
1033 "covar_samp" | "covar" => Self::CovarSamp,
1034 "first" => Self::First,
1035 "last" => Self::Last,
1036 "max" => Self::Max,
1037 "median" => Self::Median,
1038 "min" => Self::Min,
1039 "quantile_cont" => Self::QuantileCont,
1040 "quantile_disc" => Self::QuantileDisc,
1041 "stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev,
1042 "string_agg" | "listagg" | "group_concat" => Self::StringAgg,
1043 "sum" => Self::Sum,
1044 "var" | "variance" | "var_samp" => Self::Variance,
1045
1046 "array_agg" => Self::ArrayAgg,
1050 "array_contains" => Self::ArrayContains,
1051 "array_get" => Self::ArrayGet,
1052 "array_length" => Self::ArrayLength,
1053 "array_lower" => Self::ArrayMin,
1054 "array_mean" => Self::ArrayMean,
1055 "array_reverse" => Self::ArrayReverse,
1056 "array_sum" => Self::ArraySum,
1057 "array_to_string" => Self::ArrayToString,
1058 "array_unique" => Self::ArrayUnique,
1059 "array_upper" => Self::ArrayMax,
1060 "unnest" => Self::Explode,
1061
1062 #[cfg(feature = "rank")]
1066 "dense_rank" => Self::DenseRank,
1067 "first_value" => Self::FirstValue,
1068 "last_value" => Self::LastValue,
1069 "lag" => Self::Lag,
1070 "lead" => Self::Lead,
1071 #[cfg(feature = "rank")]
1072 "rank" => Self::Rank,
1073 "row_number" => Self::RowNumber,
1074
1075 "columns" => Self::Columns,
1079
1080 other => {
1081 if ctx.function_registry.contains(other) {
1082 Self::Udf(other.to_string())
1083 } else {
1084 polars_bail!(SQLInterface: "unsupported function '{}'", other);
1085 }
1086 },
1087 })
1088 }
1089}
1090
1091impl SQLFunctionVisitor<'_> {
1092 pub(crate) fn visit_function(&mut self) -> PolarsResult<Expr> {
1093 use PolarsSQLFunctions::*;
1094 use polars_lazy::prelude::Literal;
1095
1096 let function_name = PolarsSQLFunctions::try_from_sql(self.func, self.ctx)?;
1097 let function = self.func;
1098
1099 if !function.within_group.is_empty() {
1101 polars_bail!(SQLInterface: "'WITHIN GROUP' is not currently supported")
1102 }
1103 if function.null_treatment.is_some() {
1104 polars_bail!(SQLInterface: "'IGNORE|RESPECT NULLS' is not currently supported")
1105 }
1106 if let Some(filter_expr) = &function.filter {
1107 if function.over.is_some() {
1108 polars_bail!(SQLInterface: "'FILTER' combined with 'OVER' is not supported")
1109 }
1110 self.filter = Some(parse_sql_expr(filter_expr, self.ctx, self.active_schema)?);
1111 }
1112
1113 let log_with_base =
1114 |e: Expr, base: f64| e.log(LiteralValue::Dyn(DynLiteralValue::Float(base)).lit());
1115
1116 match function_name {
1117 BitAnd => self.visit_binary::<Expr>(Expr::and),
1121 #[cfg(feature = "bitwise")]
1122 BitCount => self.visit_unary(Expr::bitwise_count_ones),
1123 BitNot => self.visit_unary(Expr::not),
1124 BitOr => self.visit_binary::<Expr>(Expr::or),
1125 BitXor => self.visit_binary::<Expr>(Expr::xor),
1126
1127 Abs => self.visit_unary(Expr::abs),
1131 Cbrt => self.visit_unary(Expr::cbrt),
1132 Ceil => self.visit_unary(Expr::ceil),
1133 Div => self.visit_binary(|e, d| e.floor_div(d).cast(DataType::Int64)),
1134 Exp => self.visit_unary(Expr::exp),
1135 Floor => self.visit_unary(Expr::floor),
1136 Ln => self.visit_unary(|e| log_with_base(e, std::f64::consts::E)),
1137 Log => self.visit_binary(Expr::log),
1138 Log10 => self.visit_unary(|e| log_with_base(e, 10.0)),
1139 Log1p => self.visit_unary(Expr::log1p),
1140 Log2 => self.visit_unary(|e| log_with_base(e, 2.0)),
1141 Pi => self.visit_nullary(Expr::pi),
1142 Mod => self.visit_binary(|e1, e2| e1 % e2),
1143 Pow => self.visit_binary::<Expr>(Expr::pow),
1144 Round => {
1145 let args = extract_args(function)?;
1146 match args.len() {
1147 1 => self.visit_unary(|e| e.round(0, RoundMode::default())),
1148 2 => self.try_visit_binary(|e, decimals| {
1149 Ok(e.round(match decimals {
1150 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1151 if n >= 0 { n as u32 } else {
1152 polars_bail!(SQLInterface: "ROUND does not support negative decimals value ({})", args[1])
1153 }
1154 },
1155 _ => polars_bail!(SQLSyntax: "invalid value for ROUND decimals ({})", args[1]),
1156 }, RoundMode::default()))
1157 }),
1158 _ => polars_bail!(SQLSyntax: "ROUND expects 1-2 arguments (found {})", args.len()),
1159 }
1160 },
1161 Truncate => {
1162 let args = extract_args(function)?;
1163 match args.len() {
1164 1 => self.visit_unary(|e| e.truncate(0)),
1165 2 => self.try_visit_binary(|e, decimals| {
1166 Ok(e.truncate(match decimals {
1167 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1168 if n >= 0 { n as u32 } else {
1169 polars_bail!(SQLInterface: "TRUNCATE does not support negative decimals value ({})", args[1])
1170 }
1171 },
1172 _ => polars_bail!(SQLSyntax: "invalid value for TRUNCATE decimals ({})", args[1]),
1173 }))
1174 }),
1175 _ => polars_bail!(SQLSyntax: "TRUNCATE expects 1-2 arguments (found {})", args.len()),
1176 }
1177 },
1178 Sign => self.visit_unary(Expr::sign),
1179 Sqrt => self.visit_unary(Expr::sqrt),
1180
1181 Acos => self.visit_unary(Expr::arccos),
1185 AcosD => self.visit_unary(|e| e.arccos().degrees()),
1186 Asin => self.visit_unary(Expr::arcsin),
1187 AsinD => self.visit_unary(|e| e.arcsin().degrees()),
1188 Atan => self.visit_unary(Expr::arctan),
1189 Atan2 => self.visit_binary(Expr::arctan2),
1190 Atan2D => self.visit_binary(|e, s| e.arctan2(s).degrees()),
1191 AtanD => self.visit_unary(|e| e.arctan().degrees()),
1192 Cos => self.visit_unary(Expr::cos),
1193 CosD => self.visit_unary(|e| e.radians().cos()),
1194 Cot => self.visit_unary(Expr::cot),
1195 CotD => self.visit_unary(|e| e.radians().cot()),
1196 Degrees => self.visit_unary(Expr::degrees),
1197 Radians => self.visit_unary(Expr::radians),
1198 Sin => self.visit_unary(Expr::sin),
1199 SinD => self.visit_unary(|e| e.radians().sin()),
1200 Tan => self.visit_unary(Expr::tan),
1201 TanD => self.visit_unary(|e| e.radians().tan()),
1202
1203 Coalesce => self.visit_variadic(coalesce),
1207 Greatest => self.visit_variadic(|exprs: &[Expr]| max_horizontal(exprs).unwrap()),
1208 If => {
1209 let args = extract_args(function)?;
1210 match args.len() {
1211 3 => self.try_visit_ternary(|cond: Expr, expr1: Expr, expr2: Expr| {
1212 Ok(when(cond).then(expr1).otherwise(expr2))
1213 }),
1214 _ => {
1215 polars_bail!(SQLSyntax: "IF expects 3 arguments (found {})", args.len()
1216 )
1217 },
1218 }
1219 },
1220 IfNull => {
1221 let args = extract_args(function)?;
1222 match args.len() {
1223 2 => self.visit_variadic(coalesce),
1224 _ => {
1225 polars_bail!(SQLSyntax: "IFNULL expects 2 arguments (found {})", args.len())
1226 },
1227 }
1228 },
1229 Least => self.visit_variadic(|exprs: &[Expr]| min_horizontal(exprs).unwrap()),
1230 NullIf => {
1231 let args = extract_args(function)?;
1232 match args.len() {
1233 2 => self.visit_binary(|l: Expr, r: Expr| {
1234 when(l.clone().eq(r))
1235 .then(lit(LiteralValue::untyped_null()))
1236 .otherwise(l)
1237 }),
1238 _ => {
1239 polars_bail!(SQLSyntax: "NULLIF expects 2 arguments (found {})", args.len())
1240 },
1241 }
1242 },
1243
1244 DatePart => self.try_visit_binary(|part, e| {
1248 match part {
1249 Expr::Literal(p) if p.extract_str().is_some() => {
1250 let p = p.extract_str().unwrap();
1251 parse_extract_date_part(
1254 e,
1255 &DateTimeField::Custom(Ident {
1256 value: p.to_string(),
1257 quote_style: None,
1258 span: Span::empty(),
1259 }),
1260 )
1261 },
1262 _ => {
1263 polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART ({})", part);
1264 },
1265 }
1266 }),
1267 Strftime => {
1268 let args = extract_args(function)?;
1269 match args.len() {
1270 2 => self.visit_binary(|e, fmt: String| e.dt().strftime(fmt.as_str())),
1271 _ => {
1272 polars_bail!(SQLSyntax: "STRFTIME expects 2 arguments (found {})", args.len())
1273 },
1274 }
1275 },
1276
1277 BitLength => self.visit_unary(|e| e.str().len_bytes() * lit(8)),
1281 Concat => {
1282 let args = extract_args(function)?;
1283 if args.is_empty() {
1284 polars_bail!(SQLSyntax: "CONCAT expects at least 1 argument (found 0)");
1285 } else {
1286 self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, "", true))
1287 }
1288 },
1289 ConcatWS => {
1290 let args = extract_args(function)?;
1291 if args.len() < 2 {
1292 polars_bail!(SQLSyntax: "CONCAT_WS expects at least 2 arguments (found {})", args.len());
1293 } else {
1294 self.try_visit_variadic(|exprs: &[Expr]| {
1295 match &exprs[0] {
1296 Expr::Literal(lv) if lv.extract_str().is_some() => Ok(concat_str(&exprs[1..], lv.extract_str().unwrap(), true)),
1297 _ => polars_bail!(SQLSyntax: "CONCAT_WS 'separator' must be a literal string (found {:?})", exprs[0]),
1298 }
1299 })
1300 }
1301 },
1302 Date => {
1303 let args = extract_args(function)?;
1304 match args.len() {
1305 1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
1306 2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
1307 _ => {
1308 polars_bail!(SQLSyntax: "DATE expects 1-2 arguments (found {})", args.len())
1309 },
1310 }
1311 },
1312 EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
1313 #[cfg(feature = "nightly")]
1314 InitCap => self.visit_unary(|e| e.str().to_titlecase()),
1315 Left => self.try_visit_binary(|e, length| {
1316 Ok(match length {
1317 Expr::Literal(lv) if lv.is_null() => lit(lv),
1318 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => lit(""),
1319 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1320 let len = if n > 0 {
1321 lit(n)
1322 } else {
1323 (e.clone().str().len_chars() + lit(n)).clip_min(lit(0))
1324 };
1325 e.str().slice(lit(0), len)
1326 },
1327 Expr::Literal(v) => {
1328 polars_bail!(SQLSyntax: "invalid 'n_chars' for LEFT ({:?})", v)
1329 },
1330 _ => when(length.clone().gt_eq(lit(0)))
1331 .then(e.clone().str().slice(lit(0), length.clone().abs()))
1332 .otherwise(e.clone().str().slice(
1333 lit(0),
1334 (e.str().len_chars() + length.clone()).clip_min(lit(0)),
1335 )),
1336 })
1337 }),
1338 LeftPad | RightPad => {
1339 let is_lpad = matches!(function_name, LeftPad);
1340 let fname = if is_lpad { "LPAD" } else { "RPAD" };
1341 let args = extract_args(function)?;
1342 let pad = |e: Expr, length: Expr, fill_char: char| {
1343 let padded = if is_lpad {
1344 e.str().pad_start(length.clone(), fill_char)
1345 } else {
1346 e.str().pad_end(length.clone(), fill_char)
1347 };
1348 Ok(padded.str().slice(lit(0), length))
1349 };
1350 match args.len() {
1351 2 => self.try_visit_binary(|e, length| pad(e, length, ' ')),
1352 3 => self.try_visit_ternary(|e: Expr, length: Expr, fill: Expr| match fill {
1353 Expr::Literal(lv) if lv.extract_str().is_some() => {
1354 let s = lv.extract_str().unwrap();
1355 let mut chars = s.chars();
1356 match (chars.next(), chars.next()) {
1357 (Some(c), None) => pad(e, length, c),
1358 _ => polars_bail!(SQLSyntax: "{} fill value must be a single character (found '{}')", fname, s),
1359 }
1360 },
1361 _ => polars_bail!(SQLSyntax: "{} fill value must be a string literal", fname),
1362 }),
1363 _ => polars_bail!(SQLSyntax: "{} expects 2-3 arguments (found {})", fname, args.len()),
1364 }
1365 },
1366 LeftTrim | RightTrim => {
1367 let is_ltrim = matches!(function_name, LeftTrim);
1368 let fname = if is_ltrim { "LTRIM" } else { "RTRIM" };
1369 let strip: fn(Expr, Expr) -> Expr = if is_ltrim {
1370 |e, s| e.str().strip_chars_start(s)
1371 } else {
1372 |e, s| e.str().strip_chars_end(s)
1373 };
1374 let args = extract_args(function)?;
1375 match args.len() {
1376 1 => self.visit_unary(|e| strip(e, lit(LiteralValue::untyped_null()))),
1377 2 => self.visit_binary(strip),
1378 _ => {
1379 polars_bail!(SQLSyntax: "{} expects 1-2 arguments (found {})", fname, args.len())
1380 },
1381 }
1382 },
1383 Length => self.visit_unary(|e| e.str().len_chars()),
1384 Lower => self.visit_unary(|e| e.str().to_lowercase()),
1385 Normalize => {
1386 let args = extract_args(function)?;
1387 match args.len() {
1388 1 => self.visit_unary(|e| e.str().normalize(UnicodeForm::NFC)),
1389 2 => {
1390 let form = if let FunctionArgExpr::Expr(SQLExpr::Identifier(Ident {
1391 value: s,
1392 quote_style: None,
1393 span: _,
1394 })) = args[1]
1395 {
1396 match s.to_uppercase().as_str() {
1397 "NFC" => UnicodeForm::NFC,
1398 "NFD" => UnicodeForm::NFD,
1399 "NFKC" => UnicodeForm::NFKC,
1400 "NFKD" => UnicodeForm::NFKD,
1401 _ => {
1402 polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", s)
1403 },
1404 }
1405 } else {
1406 polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", args[1])
1407 };
1408 self.try_visit_binary(|e, _form: Expr| Ok(e.str().normalize(form.clone())))
1409 },
1410 _ => {
1411 polars_bail!(SQLSyntax: "NORMALIZE expects 1-2 arguments (found {})", args.len())
1412 },
1413 }
1414 },
1415 OctetLength => self.visit_unary(|e| e.str().len_bytes()),
1416 StrPos => {
1417 self.visit_binary(|expr, substring| {
1419 (expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32))
1420 })
1421 },
1422 RegexpLike => {
1423 let args = extract_args(function)?;
1424 match args.len() {
1425 2 => self.visit_binary(|e, s| e.str().contains(s, true)),
1426 3 => self.try_visit_ternary(|e, pat, flags| {
1427 Ok(e.str().contains(
1428 match (pat, flags) {
1429 (Expr::Literal(s_lv), Expr::Literal(f_lv)) if s_lv.extract_str().is_some() && f_lv.extract_str().is_some() => {
1430 let s = s_lv.extract_str().unwrap();
1431 let f = f_lv.extract_str().unwrap();
1432 if f.is_empty() {
1433 polars_bail!(SQLSyntax: "invalid/empty 'flags' for REGEXP_LIKE ({})", args[2]);
1434 };
1435 lit(format!("(?{f}){s}"))
1436 },
1437 _ => {
1438 polars_bail!(SQLSyntax: "invalid arguments for REGEXP_LIKE ({}, {})", args[1], args[2]);
1439 },
1440 },
1441 true))
1442 }),
1443 _ => polars_bail!(SQLSyntax: "REGEXP_LIKE expects 2-3 arguments (found {})",args.len()),
1444 }
1445 },
1446 Replace => {
1447 let args = extract_args(function)?;
1448 match args.len() {
1449 3 => self
1450 .try_visit_ternary(|e, old, new| Ok(e.str().replace_all(old, new, true))),
1451 _ => {
1452 polars_bail!(SQLSyntax: "REPLACE expects 3 arguments (found {})", args.len())
1453 },
1454 }
1455 },
1456 Reverse => self.visit_unary(|e| e.str().reverse()),
1457 Right => self.try_visit_binary(|e, length| {
1458 Ok(match length {
1459 Expr::Literal(lv) if lv.is_null() => lit(lv),
1460 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => typed_lit(""),
1461 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1462 let n: i64 = n.try_into().unwrap();
1463 let offset = if n < 0 {
1464 lit(n.abs())
1465 } else {
1466 e.clone().str().len_chars().cast(DataType::Int32) - lit(n)
1467 };
1468 e.str().slice(offset, lit(LiteralValue::untyped_null()))
1469 },
1470 Expr::Literal(v) => {
1471 polars_bail!(SQLSyntax: "invalid 'n_chars' for RIGHT ({:?})", v)
1472 },
1473 _ => when(length.clone().lt(lit(0)))
1474 .then(
1475 e.clone()
1476 .str()
1477 .slice(length.clone().abs(), lit(LiteralValue::untyped_null())),
1478 )
1479 .otherwise(e.clone().str().slice(
1480 e.str().len_chars().cast(DataType::Int32) - length.clone(),
1481 lit(LiteralValue::untyped_null()),
1482 )),
1483 })
1484 }),
1485 SplitPart => {
1486 let args = extract_args(function)?;
1487 match args.len() {
1488 3 => self.try_visit_ternary(|e, sep, idx| {
1489 let idx = adjust_one_indexed_param(idx, true);
1490 Ok(when(e.clone().is_not_null())
1491 .then(
1492 e.clone()
1493 .str()
1494 .split(sep)
1495 .list()
1496 .get(idx, true)
1497 .fill_null(lit("")),
1498 )
1499 .otherwise(e))
1500 }),
1501 _ => {
1502 polars_bail!(SQLSyntax: "SPLIT_PART expects 3 arguments (found {})", args.len())
1503 },
1504 }
1505 },
1506 StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
1507 StringToArray => {
1508 let args = extract_args(function)?;
1509 match args.len() {
1510 2 => self.visit_binary(|e, sep| e.str().split(sep)),
1511 _ => {
1512 polars_bail!(SQLSyntax: "STRING_TO_ARRAY expects 2 arguments (found {})", args.len())
1513 },
1514 }
1515 },
1516 Strptime => {
1517 let args = extract_args(function)?;
1518 match args.len() {
1519 2 => self.visit_binary(|e, fmt: String| {
1520 e.str().strptime(
1521 DataType::Datetime(TimeUnit::Microseconds, None),
1522 StrptimeOptions {
1523 format: Some(fmt.into()),
1524 ..Default::default()
1525 },
1526 lit("latest"),
1527 )
1528 }),
1529 _ => {
1530 polars_bail!(SQLSyntax: "STRPTIME expects 2 arguments (found {})", args.len())
1531 },
1532 }
1533 },
1534 Time => {
1535 let args = extract_args(function)?;
1536 match args.len() {
1537 1 => self.visit_unary(|e| e.str().to_time(StrptimeOptions::default())),
1538 2 => self.visit_binary(|e, fmt| e.str().to_time(fmt)),
1539 _ => {
1540 polars_bail!(SQLSyntax: "TIME expects 1-2 arguments (found {})", args.len())
1541 },
1542 }
1543 },
1544 Timestamp => {
1545 let args = extract_args(function)?;
1546 match args.len() {
1547 1 => self.visit_unary(|e| {
1548 e.str()
1549 .to_datetime(None, None, StrptimeOptions::default(), lit("latest"))
1550 }),
1551 2 => self
1552 .visit_binary(|e, fmt| e.str().to_datetime(None, None, fmt, lit("latest"))),
1553 _ => {
1554 polars_bail!(SQLSyntax: "DATETIME expects 1-2 arguments (found {})", args.len())
1555 },
1556 }
1557 },
1558 Substring => {
1559 let args = extract_args(function)?;
1560 match args.len() {
1561 2 => self.try_visit_binary(|e, start| {
1563 Ok(match start {
1564 Expr::Literal(lv) if lv.is_null() => lit(lv),
1565 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1566 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null())),
1567 Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1568 _ => start.clone() + lit(1),
1569 })
1570 }),
1571 3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| {
1572 Ok(match (start.clone(), length.clone()) {
1573 (Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1574 (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1575 polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", args[2])
1576 },
1577 (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => e.str().slice(lit(n - 1), length),
1578 (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => {
1579 e.str().slice(lit(0), (length + lit(n - 1)).clip_min(lit(0)))
1580 },
1581 (Expr::Literal(_), _) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1582 (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1583 polars_bail!(SQLSyntax: "invalid 'length' for SUBSTR ({})", args[1])
1584 },
1585 _ => {
1586 let adjusted_start = start - lit(1);
1587 when(adjusted_start.clone().lt(lit(0)))
1588 .then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0))))
1589 .otherwise(e.str().slice(adjusted_start, length))
1590 }
1591 })
1592 }),
1593 _ => polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found {})", args.len()),
1594 }
1595 },
1596 Upper => self.visit_unary(|e| e.str().to_uppercase()),
1597
1598 Avg => self.visit_unary(Expr::mean),
1602 Corr => self.visit_binary(polars_lazy::dsl::pearson_corr),
1603 Count => self.visit_count(),
1604 CovarPop => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 0)),
1605 CovarSamp => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 1)),
1606 First => self.visit_unary(Expr::first),
1607 Last => self.visit_unary(Expr::last),
1608 Max => self.visit_unary_with_opt_cumulative(Expr::max, Expr::cum_max),
1609 Median => self.visit_unary(Expr::median),
1610 QuantileCont | QuantileDisc => {
1611 let (fname, method) = if matches!(function_name, QuantileCont) {
1612 ("QUANTILE_CONT", QuantileMethod::Linear)
1613 } else {
1614 ("QUANTILE_DISC", QuantileMethod::Equiprobable)
1615 };
1616 let args = extract_args(function)?;
1617 match args.len() {
1618 2 => self.try_visit_binary(|e, q| {
1619 let value = match q {
1620 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1621 if (0.0..=1.0).contains(&f) {
1622 Expr::from(f)
1623 } else {
1624 polars_bail!(SQLSyntax: "{} value must be between 0 and 1 ({})", fname, args[1])
1625 }
1626 },
1627 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1628 if (0..=1).contains(&n) {
1629 Expr::from(n as f64)
1630 } else {
1631 polars_bail!(SQLSyntax: "{} value must be between 0 and 1 ({})", fname, args[1])
1632 }
1633 },
1634 _ => polars_bail!(SQLSyntax: "invalid value for {} ({})", fname, args[1])
1635 };
1636 Ok(e.quantile(value, method))
1637 }),
1638 _ => polars_bail!(SQLSyntax: "{} expects 2 arguments (found {})", fname, args.len()),
1639 }
1640 },
1641 Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min),
1642 StdDev => self.visit_unary(|e| e.std(1)),
1643 StringAgg => self.visit_string_agg(),
1644 Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum),
1645 Variance => self.visit_unary(|e| e.var(1)),
1646
1647 ArrayAgg => self.visit_arr_agg(),
1651 ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s, true)),
1652 ArrayGet => {
1653 self.visit_binary(|e, idx: Expr| {
1655 let idx = adjust_one_indexed_param(idx, true);
1656 e.list().get(idx, true)
1657 })
1658 },
1659 ArrayLength => self.visit_unary(|e| e.list().len()),
1660 ArrayMax => self.visit_unary(|e| e.list().max()),
1661 ArrayMean => self.visit_unary(|e| e.list().mean()),
1662 ArrayMin => self.visit_unary(|e| e.list().min()),
1663 ArrayReverse => self.visit_unary(|e| e.list().eval(element().reverse())),
1664 ArraySum => self.visit_unary(|e| e.list().sum()),
1665 ArrayToString => self.visit_arr_to_string(),
1666 ArrayUnique => self.visit_unary(|e| e.list().eval(element().unique_stable())),
1667 Explode => self.visit_unary(|e| {
1668 e.explode(ExplodeOptions {
1669 empty_as_null: true,
1670 keep_nulls: true,
1671 })
1672 }),
1673
1674 Columns => {
1678 let active_schema = self.active_schema;
1679 self.try_visit_unary(|e: Expr| match e {
1680 Expr::Literal(lv) if lv.extract_str().is_some() => {
1681 let pat = lv.extract_str().unwrap();
1682 if pat == "*" {
1683 polars_bail!(
1684 SQLSyntax: "COLUMNS('*') is not a valid regex; \
1685 did you mean COLUMNS(*)?"
1686 )
1687 };
1688 let pat = match pat {
1689 _ if pat.starts_with('^') && pat.ends_with('$') => pat.to_string(),
1690 _ if pat.starts_with('^') => format!("{pat}.*$"),
1691 _ if pat.ends_with('$') => format!("^.*{pat}"),
1692 _ => format!("^.*{pat}.*$"),
1693 };
1694 if let Some(active_schema) = &active_schema {
1695 let rx = polars_utils::regex_cache::compile_regex(&pat).unwrap();
1696 let col_names = active_schema
1697 .iter_names()
1698 .filter(|name| rx.is_match(name))
1699 .cloned()
1700 .collect::<Vec<_>>();
1701
1702 Ok(if col_names.len() == 1 {
1703 col(col_names.into_iter().next().unwrap())
1704 } else {
1705 cols(col_names).as_expr()
1706 })
1707 } else {
1708 Ok(col(pat.as_str()))
1709 }
1710 },
1711 Expr::Selector(s) => Ok(s.as_expr()),
1712 _ => polars_bail!(SQLSyntax: "COLUMNS expects a regex; found {:?}", e),
1713 })
1714 },
1715
1716 FirstValue => self.visit_unary(Expr::first),
1720 LastValue => {
1721 let args = extract_args(function)?;
1725 match args.as_slice() {
1726 [FunctionArgExpr::Expr(sql_expr)] => {
1727 parse_sql_expr(sql_expr, self.ctx, self.active_schema)
1728 },
1729 _ => polars_bail!(
1730 SQLSyntax: "LAST_VALUE expects exactly 1 argument (found {})",
1731 args.len()
1732 ),
1733 }
1734 },
1735 Lag => self.visit_window_offset_function(1),
1736 Lead => self.visit_window_offset_function(-1),
1737 #[cfg(feature = "rank")]
1738 Rank | DenseRank => {
1739 let (func_name, rank_method) = match function_name {
1740 Rank => ("RANK", RankMethod::Min),
1741 DenseRank => ("DENSE_RANK", RankMethod::Dense),
1742 _ => unreachable!(),
1743 };
1744 let args = extract_args(function)?;
1745 if !args.is_empty() {
1746 polars_bail!(SQLSyntax: "{} expects 0 arguments (found {})", func_name, args.len());
1747 }
1748 let window_spec = match &self.func.over {
1749 Some(WindowType::WindowSpec(spec)) if !spec.order_by.is_empty() => spec,
1750 _ => {
1751 polars_bail!(SQLSyntax: "{} requires an OVER clause with ORDER BY", func_name)
1752 },
1753 };
1754 let (order_exprs, all_desc) =
1755 self.parse_order_by_in_window(&window_spec.order_by)?;
1756 let rank_expr = if order_exprs.len() == 1 {
1757 order_exprs[0].clone().rank(
1758 RankOptions {
1759 method: rank_method,
1760 descending: all_desc,
1761 },
1762 None,
1763 )
1764 } else {
1765 as_struct(order_exprs).rank(
1766 RankOptions {
1767 method: rank_method,
1768 descending: all_desc,
1769 },
1770 None,
1771 )
1772 };
1773 self.apply_window_spec(rank_expr, &self.func.over)
1774 },
1775 RowNumber => {
1776 let args = extract_args(function)?;
1777 if !args.is_empty() {
1778 polars_bail!(SQLSyntax: "ROW_NUMBER expects 0 arguments (found {})", args.len());
1779 }
1780 let row_num_expr = int_range(lit(0i64), len(), 1, DataType::UInt32) + lit(1u32);
1782 self.apply_window_spec(row_num_expr, &self.func.over)
1783 },
1784
1785 Udf(func_name) => self.visit_udf(&func_name),
1789 }
1790 }
1791
1792 fn visit_window_offset_function(&mut self, offset_multiplier: i64) -> PolarsResult<Expr> {
1793 if self.func.over.is_none() {
1795 polars_bail!(SQLSyntax: "{} requires an OVER clause", self.func.name);
1796 }
1797
1798 let window_type = self.func.over.as_ref().unwrap();
1800 let window_spec = self.resolve_window_spec(window_type)?;
1801 if window_spec.order_by.is_empty() {
1802 polars_bail!(SQLSyntax: "{} requires an ORDER BY in the OVER clause", self.func.name);
1803 }
1804
1805 let args = extract_args(self.func)?;
1806
1807 match args.as_slice() {
1808 [FunctionArgExpr::Expr(sql_expr)] => {
1809 let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1810 Ok(expr.shift(offset_multiplier.into()))
1811 },
1812 [FunctionArgExpr::Expr(sql_expr), FunctionArgExpr::Expr(offset_expr)] => {
1813 let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1814 let offset = parse_sql_expr(offset_expr, self.ctx, self.active_schema)?;
1815 if let Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) = offset {
1816 if n <= 0 {
1817 polars_bail!(SQLSyntax: "offset must be positive (found {})", n)
1818 }
1819 Ok(expr.shift((offset_multiplier * n as i64).into()))
1820 } else {
1821 polars_bail!(SQLSyntax: "offset must be an integer (found {:?})", offset)
1822 }
1823 },
1824 _ => polars_bail!(SQLSyntax: "{} expects 1 or 2 arguments (found {})", self.func.name, args.len()),
1825 }.and_then(|e| self.apply_window_spec(e, &self.func.over))
1826 }
1827
1828 fn visit_udf(&mut self, func_name: &str) -> PolarsResult<Expr> {
1829 let args = extract_args(self.func)?
1830 .into_iter()
1831 .map(|arg| {
1832 if let FunctionArgExpr::Expr(e) = arg {
1833 parse_sql_expr(e, self.ctx, self.active_schema)
1834 } else {
1835 polars_bail!(SQLInterface: "only expressions are supported in UDFs")
1836 }
1837 })
1838 .collect::<PolarsResult<Vec<_>>>()?;
1839
1840 Ok(self
1841 .ctx
1842 .function_registry
1843 .get_udf(func_name)?
1844 .ok_or_else(|| polars_err!(SQLInterface: "UDF {} not found", func_name))?
1845 .call(args))
1846 }
1847
1848 fn validate_window_frame(&self, window_frame: &Option<WindowFrame>) -> PolarsResult<()> {
1861 if let Some(frame) = window_frame {
1862 match frame.units {
1863 WindowFrameUnits::Range => {
1864 polars_bail!(
1865 SQLInterface:
1866 "RANGE-based window frames are not supported"
1867 );
1868 },
1869 WindowFrameUnits::Groups => {
1870 polars_bail!(
1871 SQLInterface:
1872 "GROUPS-based window frames are not supported"
1873 );
1874 },
1875 WindowFrameUnits::Rows => {
1876 if !matches!(
1877 (&frame.start_bound, &frame.end_bound),
1878 (
1879 WindowFrameBound::Preceding(None), None | Some(WindowFrameBound::CurrentRow) )
1882 ) {
1883 polars_bail!(
1884 SQLInterface:
1885 "only 'ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW' is currently supported; found 'ROWS BETWEEN {} AND {}'",
1886 frame.start_bound,
1887 frame.end_bound.as_ref().map_or("CURRENT ROW", |b| {
1888 match b {
1889 WindowFrameBound::CurrentRow => "CURRENT ROW",
1890 WindowFrameBound::Preceding(_) => "N PRECEDING",
1891 WindowFrameBound::Following(_) => "N FOLLOWING",
1892 }
1893 })
1894 );
1895 }
1896 },
1897 }
1898 }
1899 Ok(())
1900 }
1901
1902 fn apply_cumulative_window(
1926 &mut self,
1927 f: impl Fn(Expr) -> Expr,
1928 cumulative_fn: impl Fn(Expr, bool) -> Expr,
1929 WindowSpec {
1930 partition_by,
1931 order_by,
1932 window_frame,
1933 ..
1934 }: &WindowSpec,
1935 ) -> PolarsResult<Expr> {
1936 self.validate_window_frame(window_frame)?;
1937
1938 if !order_by.is_empty() {
1939 let (order_by_exprs, all_desc) = self.parse_order_by_in_window(order_by)?;
1941
1942 let args = extract_args(self.func)?;
1944 let base_expr = match args.as_slice() {
1945 [FunctionArgExpr::Expr(sql_expr)] => {
1946 parse_sql_expr(sql_expr, self.ctx, self.active_schema)?
1947 },
1948 _ => return self.not_supported_error(),
1949 };
1950 let partition_by_exprs = if partition_by.is_empty() {
1951 None
1952 } else {
1953 Some(
1954 partition_by
1955 .iter()
1956 .map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
1957 .collect::<PolarsResult<Vec<_>>>()?,
1958 )
1959 };
1960
1961 let cumulative_expr = cumulative_fn(base_expr, false)
1963 .fill_null_with_strategy(FillNullStrategy::Forward(None));
1964 let sort_opts = SortOptions::default().with_order_descending(all_desc);
1965 cumulative_expr.over_with_options(
1966 partition_by_exprs,
1967 Some((order_by_exprs, sort_opts)),
1968 Default::default(),
1969 )
1970 } else {
1971 self.visit_unary(f)
1972 }
1973 }
1974
1975 fn parse_sql_arg(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
1980 let parsed = parse_sql_expr(expr, self.ctx, self.active_schema)?;
1981 Ok(match &self.filter {
1982 Some(pred) => parsed.filter(pred.clone()),
1983 None => parsed,
1984 })
1985 }
1986
1987 fn visit_unary(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
1988 self.try_visit_unary(|e| Ok(f(e)))
1989 }
1990
1991 fn try_visit_unary(&mut self, f: impl Fn(Expr) -> PolarsResult<Expr>) -> PolarsResult<Expr> {
1992 let args = extract_args(self.func)?;
1993 match args.as_slice() {
1994 [FunctionArgExpr::Expr(sql_expr)] => f(self.parse_sql_arg(sql_expr)?),
1995 [FunctionArgExpr::Wildcard] => {
1996 f(self.parse_sql_arg(&SQLExpr::Wildcard(AttachedToken::empty()))?)
1997 },
1998 _ => self.not_supported_error(),
1999 }
2000 .and_then(|e| self.apply_window_spec(e, &self.func.over))
2001 }
2002
2003 fn resolve_window_spec(&self, window_type: &WindowType) -> PolarsResult<WindowSpec> {
2005 match window_type {
2006 WindowType::WindowSpec(spec) => Ok(spec.clone()),
2007 WindowType::NamedWindow(name) => self
2008 .ctx
2009 .named_windows
2010 .get(&name.value)
2011 .cloned()
2012 .ok_or_else(|| {
2013 polars_err!(
2014 SQLInterface:
2015 "named window '{}' was not found",
2016 name.value
2017 )
2018 }),
2019 }
2020 }
2021
2022 fn visit_unary_with_opt_cumulative(
2025 &mut self,
2026 f: impl Fn(Expr) -> Expr,
2027 cumulative_fn: impl Fn(Expr, bool) -> Expr,
2028 ) -> PolarsResult<Expr> {
2029 match self.func.over.as_ref() {
2030 Some(window_type) => {
2031 let spec = self.resolve_window_spec(window_type)?;
2032 self.apply_cumulative_window(f, cumulative_fn, &spec)
2033 },
2034 None => self.visit_unary(f),
2035 }
2036 }
2037
2038 fn visit_binary<Arg: FromSQLExpr>(
2039 &mut self,
2040 f: impl Fn(Expr, Arg) -> Expr,
2041 ) -> PolarsResult<Expr> {
2042 self.try_visit_binary(|e, a| Ok(f(e, a)))
2043 }
2044
2045 fn try_visit_binary<Arg: FromSQLExpr>(
2046 &mut self,
2047 f: impl Fn(Expr, Arg) -> PolarsResult<Expr>,
2048 ) -> PolarsResult<Expr> {
2049 let args = extract_args(self.func)?;
2050 match args.as_slice() {
2051 [
2052 FunctionArgExpr::Expr(sql_expr1),
2053 FunctionArgExpr::Expr(sql_expr2),
2054 ] => {
2055 let expr1 = self.parse_sql_arg(sql_expr1)?;
2056 let expr2 = Arg::from_sql_arg(sql_expr2, self)?;
2057 f(expr1, expr2)
2058 },
2059 _ => self.not_supported_error(),
2060 }
2061 }
2062
2063 fn visit_variadic(&mut self, f: impl Fn(&[Expr]) -> Expr) -> PolarsResult<Expr> {
2064 self.try_visit_variadic(|e| Ok(f(e)))
2065 }
2066
2067 fn try_visit_variadic(
2068 &mut self,
2069 f: impl Fn(&[Expr]) -> PolarsResult<Expr>,
2070 ) -> PolarsResult<Expr> {
2071 let args = extract_args(self.func)?;
2072 let mut expr_args = vec![];
2073 for arg in args {
2074 if let FunctionArgExpr::Expr(sql_expr) = arg {
2075 expr_args.push(self.parse_sql_arg(sql_expr)?);
2076 } else {
2077 return self.not_supported_error();
2078 };
2079 }
2080 f(&expr_args)
2081 }
2082
2083 fn try_visit_ternary<Arg: FromSQLExpr>(
2084 &mut self,
2085 f: impl Fn(Expr, Arg, Arg) -> PolarsResult<Expr>,
2086 ) -> PolarsResult<Expr> {
2087 let args = extract_args(self.func)?;
2088 match args.as_slice() {
2089 [
2090 FunctionArgExpr::Expr(sql_expr1),
2091 FunctionArgExpr::Expr(sql_expr2),
2092 FunctionArgExpr::Expr(sql_expr3),
2093 ] => {
2094 let expr1 = self.parse_sql_arg(sql_expr1)?;
2095 let expr2 = Arg::from_sql_arg(sql_expr2, self)?;
2096 let expr3 = Arg::from_sql_arg(sql_expr3, self)?;
2097 f(expr1, expr2, expr3)
2098 },
2099 _ => self.not_supported_error(),
2100 }
2101 }
2102
2103 fn visit_nullary(&self, f: impl Fn() -> Expr) -> PolarsResult<Expr> {
2104 let args = extract_args(self.func)?;
2105 if !args.is_empty() {
2106 return self.not_supported_error();
2107 }
2108 Ok(f())
2109 }
2110
2111 fn apply_aggregate_clauses(
2115 &mut self,
2116 mut base: Expr,
2117 is_distinct: bool,
2118 clauses: &[FunctionArgumentClause],
2119 base_sql_expr: &SQLExpr,
2120 func_name: &str,
2121 ) -> PolarsResult<Expr> {
2122 let mut order_by_clause = None;
2123 let mut limit_clause = None;
2124 for clause in clauses {
2125 match clause {
2126 FunctionArgumentClause::OrderBy(order_exprs) => {
2127 order_by_clause = Some(order_exprs.as_slice());
2128 },
2129 FunctionArgumentClause::Limit(limit_expr) => {
2130 limit_clause = Some(limit_expr);
2131 },
2132 _ => {},
2133 }
2134 }
2135 if is_distinct {
2136 base = base.unique_stable();
2138 if let Some(order_by) = order_by_clause {
2139 base = self.apply_order_by_to_distinct_array(base, order_by, base_sql_expr)?;
2140 }
2141 } else if let Some(order_by) = order_by_clause {
2142 base = self.apply_order_by(base, order_by)?;
2143 }
2144 if let Some(limit_expr) = limit_clause {
2145 let limit = parse_sql_expr(limit_expr, self.ctx, self.active_schema)?;
2146 match limit {
2147 Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n >= 0 => {
2148 base = base.head(Some(n as usize))
2149 },
2150 _ => {
2151 polars_bail!(SQLSyntax: "LIMIT in {} must be a positive integer", func_name)
2152 },
2153 };
2154 }
2155 Ok(base)
2156 }
2157
2158 fn visit_arr_agg(&mut self) -> PolarsResult<Expr> {
2159 let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?;
2160 match args.as_slice() {
2161 [FunctionArgExpr::Expr(sql_expr)] => {
2162 let base = self.parse_sql_arg(sql_expr)?;
2163 let base = self.apply_aggregate_clauses(
2164 base,
2165 is_distinct,
2166 &clauses,
2167 sql_expr,
2168 "ARRAY_AGG",
2169 )?;
2170 Ok(base.implode(true))
2171 },
2172 _ => {
2173 polars_bail!(SQLSyntax: "ARRAY_AGG must have exactly one argument; found {}", args.len())
2174 },
2175 }
2176 }
2177
2178 fn visit_string_agg(&mut self) -> PolarsResult<Expr> {
2179 let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?;
2180 let (sql_expr, separator) = match args.as_slice() {
2181 [FunctionArgExpr::Expr(sql_expr)] => (sql_expr, lit(",")),
2182 [
2183 FunctionArgExpr::Expr(sql_expr),
2184 FunctionArgExpr::Expr(sep_sql_expr),
2185 ] => (
2186 sql_expr,
2187 parse_sql_expr(sep_sql_expr, self.ctx, self.active_schema)?,
2188 ),
2189 _ => polars_bail!(
2190 SQLSyntax: "STRING_AGG expects 1-2 arguments (found {})",
2191 args.len()
2192 ),
2193 };
2194 let base = self.parse_sql_arg(sql_expr)?;
2195 let base =
2196 self.apply_aggregate_clauses(base, is_distinct, &clauses, sql_expr, "STRING_AGG")?;
2197 Ok(base
2198 .cast(DataType::String)
2199 .implode(true)
2200 .list()
2201 .join(separator, true))
2202 }
2203
2204 fn visit_arr_to_string(&mut self) -> PolarsResult<Expr> {
2205 let args = extract_args(self.func)?;
2206 match args.len() {
2207 2 => self.try_visit_binary(|e, sep| {
2208 Ok(e.cast(DataType::List(Box::from(DataType::String)))
2209 .list()
2210 .join(sep, true))
2211 }),
2212 #[cfg(feature = "list_eval")]
2213 3 => self.try_visit_ternary(|e, sep, null_value| match null_value {
2214 Expr::Literal(lv) if lv.extract_str().is_some() => {
2215 Ok(if lv.extract_str().unwrap().is_empty() {
2216 e.cast(DataType::List(Box::from(DataType::String)))
2217 .list()
2218 .join(sep, true)
2219 } else {
2220 e.cast(DataType::List(Box::from(DataType::String)))
2221 .list()
2222 .eval(element().fill_null(lit(lv.extract_str().unwrap())))
2223 .list()
2224 .join(sep, false)
2225 })
2226 },
2227 _ => {
2228 polars_bail!(SQLSyntax: "invalid null value for ARRAY_TO_STRING ({})", args[2])
2229 },
2230 }),
2231 _ => {
2232 polars_bail!(SQLSyntax: "ARRAY_TO_STRING expects 2-3 arguments (found {})", args.len())
2233 },
2234 }
2235 }
2236
2237 fn visit_count(&mut self) -> PolarsResult<Expr> {
2238 let (args, is_distinct) = extract_args_distinct(self.func)?;
2239
2240 let has_order_by = match &self.func.over {
2242 Some(WindowType::WindowSpec(spec)) => !spec.order_by.is_empty(),
2243 _ => false,
2244 };
2245 if has_order_by && !is_distinct {
2246 if let Some(WindowType::WindowSpec(spec)) = &self.func.over {
2247 self.validate_window_frame(&spec.window_frame)?;
2248
2249 let is_count_star = match args.as_slice() {
2250 [FunctionArgExpr::Wildcard] | [] => true,
2251 [FunctionArgExpr::Expr(e)] => is_non_null_literal(e),
2252 _ => false,
2253 };
2254 match args.as_slice() {
2255 _ if is_count_star => {
2256 let (order_by_exprs, all_desc) =
2258 self.parse_order_by_in_window(&spec.order_by)?;
2259 let partition_by_exprs = if spec.partition_by.is_empty() {
2260 None
2261 } else {
2262 Some(
2263 spec.partition_by
2264 .iter()
2265 .map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
2266 .collect::<PolarsResult<Vec<_>>>()?,
2267 )
2268 };
2269 let sort_opts = SortOptions::default().with_order_descending(all_desc);
2270 let row_number = int_range(lit(0), len(), 1, DataType::Int64).add(lit(1)); return row_number.over_with_options(
2273 partition_by_exprs,
2274 Some((order_by_exprs, sort_opts)),
2275 Default::default(),
2276 );
2277 },
2278 [FunctionArgExpr::Expr(_)] => {
2279 return self.visit_unary_with_opt_cumulative(
2281 |e| e.count(),
2282 |e, reverse| e.cum_count(reverse),
2283 );
2284 },
2285 _ => {},
2286 }
2287 }
2288 }
2289 let count_star = || match &self.filter {
2291 Some(pred) => pred.clone().sum(),
2292 None => len(),
2293 };
2294 let count_expr = match (is_distinct, args.as_slice()) {
2295 (false, [FunctionArgExpr::Wildcard] | []) => count_star(),
2297 (false, [FunctionArgExpr::Expr(sql_expr)]) if is_non_null_literal(sql_expr) => {
2299 count_star()
2300 },
2301 (false, [FunctionArgExpr::Expr(sql_expr)]) => {
2303 let expr = self.parse_sql_arg(sql_expr)?;
2304 expr.count()
2305 },
2306 (true, [FunctionArgExpr::Expr(sql_expr)]) => {
2308 let expr = self.parse_sql_arg(sql_expr)?;
2309 expr.clone().n_unique().sub(expr.null_count().gt(lit(0)))
2310 },
2311 _ => self.not_supported_error()?,
2312 };
2313 self.apply_window_spec(count_expr, &self.func.over)
2314 }
2315
2316 fn apply_order_by(&mut self, expr: Expr, order_by: &[OrderByExpr]) -> PolarsResult<Expr> {
2317 let mut by = Vec::with_capacity(order_by.len());
2318 let mut descending = Vec::with_capacity(order_by.len());
2319 let mut nulls_last = Vec::with_capacity(order_by.len());
2320
2321 for ob in order_by {
2322 let desc_order = !ob.options.asc.unwrap_or(true);
2327 by.push(self.parse_sql_arg(&ob.expr)?);
2328 nulls_last.push(!ob.options.nulls_first.unwrap_or(desc_order));
2329 descending.push(desc_order);
2330 }
2331 Ok(expr.sort_by(
2332 by,
2333 SortMultipleOptions::default()
2334 .with_order_descending_multi(descending)
2335 .with_nulls_last_multi(nulls_last),
2336 ))
2337 }
2338
2339 fn apply_order_by_to_distinct_array(
2340 &mut self,
2341 expr: Expr,
2342 order_by: &[OrderByExpr],
2343 base_sql_expr: &SQLExpr,
2344 ) -> PolarsResult<Expr> {
2345 if order_by.len() == 1 && order_by[0].expr == *base_sql_expr {
2347 let desc_order = !order_by[0].options.asc.unwrap_or(true);
2348 let nulls_last = !order_by[0].options.nulls_first.unwrap_or(desc_order);
2349 return Ok(expr.sort(
2350 SortOptions::default()
2351 .with_order_descending(desc_order)
2352 .with_nulls_last(nulls_last),
2353 ));
2354 }
2355 self.apply_order_by(expr, order_by)
2357 }
2358
2359 fn parse_order_by_in_window(
2361 &mut self,
2362 order_by: &[OrderByExpr],
2363 ) -> PolarsResult<(Vec<Expr>, bool)> {
2364 if order_by.is_empty() {
2365 return Ok((Vec::new(), false));
2366 }
2367 let all_ascending = order_by[0].options.asc.unwrap_or(true);
2369 let mut exprs = Vec::with_capacity(order_by.len());
2370 for o in order_by {
2371 if all_ascending != o.options.asc.unwrap_or(true) {
2372 polars_bail!(
2375 SQLSyntax:
2376 "OVER does not (yet) support mixed asc/desc directions for ORDER BY"
2377 )
2378 }
2379 let expr = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?;
2380 exprs.push(expr);
2381 }
2382 Ok((exprs, !all_ascending))
2383 }
2384
2385 fn apply_window_spec(
2386 &mut self,
2387 expr: Expr,
2388 window_type: &Option<WindowType>,
2389 ) -> PolarsResult<Expr> {
2390 let Some(window_type) = window_type else {
2391 return Ok(expr);
2392 };
2393 let window_spec = self.resolve_window_spec(window_type)?;
2394 self.validate_window_frame(&window_spec.window_frame)?;
2395
2396 let partition_by = if window_spec.partition_by.is_empty() {
2397 None
2398 } else {
2399 Some(
2400 window_spec
2401 .partition_by
2402 .iter()
2403 .map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
2404 .collect::<PolarsResult<Vec<_>>>()?,
2405 )
2406 };
2407 let order_by = if window_spec.order_by.is_empty() {
2408 None
2409 } else {
2410 let (order_exprs, all_desc) = self.parse_order_by_in_window(&window_spec.order_by)?;
2411 let sort_opts = SortOptions::default().with_order_descending(all_desc);
2412 Some((order_exprs, sort_opts))
2413 };
2414
2415 Ok(match (partition_by, order_by) {
2417 (None, None) => expr,
2418 (Some(part), None) => expr.over(part)?,
2419 (part, Some(order)) => expr.over_with_options(part, Some(order), Default::default())?,
2420 })
2421 }
2422
2423 fn not_supported_error(&self) -> PolarsResult<Expr> {
2424 polars_bail!(
2425 SQLInterface:
2426 "no function matches the given name and arguments: `{}`",
2427 self.func.to_string()
2428 );
2429 }
2430}
2431
2432fn is_non_null_literal(expr: &SQLExpr) -> bool {
2434 matches!(
2435 expr,
2436 SQLExpr::Value(ValueWithSpan {
2437 value: v,
2438 ..
2439 }) if !matches!(v, SQLValue::Null)
2440 )
2441}
2442
2443fn extract_args(func: &SQLFunction) -> PolarsResult<Vec<&FunctionArgExpr>> {
2444 let (args, _, _) = _extract_func_args(func, false, false)?;
2445 Ok(args)
2446}
2447
2448fn extract_args_distinct(func: &SQLFunction) -> PolarsResult<(Vec<&FunctionArgExpr>, bool)> {
2449 let (args, is_distinct, _) = _extract_func_args(func, true, false)?;
2450 Ok((args, is_distinct))
2451}
2452
2453fn extract_args_and_clauses(
2454 func: &SQLFunction,
2455) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
2456 _extract_func_args(func, true, true)
2457}
2458
2459fn _extract_func_args(
2460 func: &SQLFunction,
2461 get_distinct: bool,
2462 get_clauses: bool,
2463) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
2464 match &func.args {
2465 FunctionArguments::List(FunctionArgumentList {
2466 args,
2467 duplicate_treatment,
2468 clauses,
2469 }) => {
2470 let is_distinct = matches!(duplicate_treatment, Some(DuplicateTreatment::Distinct));
2471 if !(get_clauses || get_distinct) && is_distinct {
2472 polars_bail!(SQLSyntax: "unexpected use of DISTINCT found in '{}'", func.name)
2473 } else if !get_clauses && !clauses.is_empty() {
2474 polars_bail!(SQLSyntax: "unexpected clause found in '{}' ({})", func.name, clauses[0])
2475 } else {
2476 let unpacked_args = args
2477 .iter()
2478 .map(|arg| match arg {
2479 FunctionArg::Named { arg, .. } => arg,
2480 FunctionArg::ExprNamed { arg, .. } => arg,
2481 FunctionArg::Unnamed(arg) => arg,
2482 })
2483 .collect();
2484 Ok((unpacked_args, is_distinct, clauses.clone()))
2485 }
2486 },
2487 FunctionArguments::Subquery { .. } => {
2488 Err(polars_err!(SQLInterface: "subquery not expected in {}", func.name))
2489 },
2490 FunctionArguments::None => Ok((vec![], false, vec![])),
2491 }
2492}
2493
2494pub(crate) trait FromSQLExpr {
2495 fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
2496 where
2497 Self: Sized;
2498
2499 fn from_sql_arg(expr: &SQLExpr, visitor: &mut SQLFunctionVisitor<'_>) -> PolarsResult<Self>
2503 where
2504 Self: Sized,
2505 {
2506 Self::from_sql_expr(expr, visitor.ctx)
2507 }
2508}
2509
2510impl FromSQLExpr for f64 {
2511 fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
2512 where
2513 Self: Sized,
2514 {
2515 match expr {
2516 SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2517 SQLValue::Number(s, _) => s
2518 .parse()
2519 .map_err(|_| polars_err!(SQLInterface: "cannot parse literal {:?}", s)),
2520 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2521 },
2522 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2523 }
2524 }
2525}
2526
2527impl FromSQLExpr for bool {
2528 fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
2529 where
2530 Self: Sized,
2531 {
2532 match expr {
2533 SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2534 SQLValue::Boolean(v) => Ok(*v),
2535 _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", v),
2536 },
2537 _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", expr),
2538 }
2539 }
2540}
2541
2542impl FromSQLExpr for String {
2543 fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2544 where
2545 Self: Sized,
2546 {
2547 match expr {
2548 SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2549 SQLValue::SingleQuotedString(s) => Ok(s.clone()),
2550 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2551 },
2552 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2553 }
2554 }
2555}
2556
2557impl FromSQLExpr for StrptimeOptions {
2558 fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2559 where
2560 Self: Sized,
2561 {
2562 match expr {
2563 SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2564 SQLValue::SingleQuotedString(s) => Ok(StrptimeOptions {
2565 format: Some(PlSmallStr::from_str(s)),
2566 ..StrptimeOptions::default()
2567 }),
2568 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2569 },
2570 _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2571 }
2572 }
2573}
2574
2575impl FromSQLExpr for Expr {
2576 fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
2577 where
2578 Self: Sized,
2579 {
2580 parse_sql_expr(expr, ctx, None)
2581 }
2582
2583 fn from_sql_arg(expr: &SQLExpr, visitor: &mut SQLFunctionVisitor<'_>) -> PolarsResult<Self> {
2584 visitor.parse_sql_arg(expr)
2585 }
2586}