// polars_sql/functions.rs

1use std::ops::{Add, Sub};
2
3use polars_core::chunked_array::ops::{FillNullStrategy, SortMultipleOptions, SortOptions};
4use polars_core::prelude::{
5    DataType, ExplodeOptions, PolarsResult, QuantileMethod, Schema, TimeUnit, polars_bail,
6    polars_err,
7};
8use polars_lazy::dsl::Expr;
9#[cfg(feature = "rank")]
10use polars_lazy::prelude::{RankMethod, RankOptions};
11use polars_ops::chunked_array::UnicodeForm;
12use polars_ops::series::RoundMode;
13use polars_plan::dsl::functions::{
14    as_struct, coalesce, col, cols, concat_str, element, int_range, len, lit, max_horizontal,
15    min_horizontal, when,
16};
17use polars_plan::plans::{DynLiteralValue, LiteralValue, typed_lit};
18use polars_plan::prelude::StrptimeOptions;
19use polars_utils::pl_str::PlSmallStr;
20use sqlparser::ast::helpers::attached_token::AttachedToken;
21use sqlparser::ast::{
22    DateTimeField, DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg,
23    FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, Ident,
24    OrderByExpr, Value as SQLValue, ValueWithSpan, WindowFrame, WindowFrameBound, WindowFrameUnits,
25    WindowSpec, WindowType,
26};
27use sqlparser::tokenizer::Span;
28
29use crate::SQLContext;
30use crate::sql_expr::{adjust_one_indexed_param, parse_extract_date_part, parse_sql_expr};
31
32pub(crate) struct SQLFunctionVisitor<'a> {
33    pub(crate) func: &'a SQLFunction,
34    pub(crate) ctx: &'a mut SQLContext,
35    pub(crate) active_schema: Option<&'a Schema>,
36    pub(crate) filter: Option<Expr>,
37}
38
/// SQL functions that are supported by Polars.
///
/// Each variant identifies one SQL function (or family of aliases); the SQL
/// snippet in each variant's documentation shows a typical call.
pub(crate) enum PolarsSQLFunctions {
    // ----
    // Bitwise functions
    // ----
    /// SQL 'bit_and' function.
    /// Returns the bitwise AND of the input expressions.
    /// ```sql
    /// SELECT BIT_AND(col1, col2) FROM df;
    /// ```
    BitAnd,
    /// SQL 'bit_count' function.
    /// Returns the number of set bits in the input expression.
    /// ```sql
    /// SELECT BIT_COUNT(col1) FROM df;
    /// ```
    #[cfg(feature = "bitwise")]
    BitCount,
    /// SQL 'bit_not' function.
    /// Returns the bitwise NOT of the input expression.
    /// ```sql
    /// SELECT BIT_NOT(col1) FROM df;
    /// ```
    BitNot,
    /// SQL 'bit_or' function.
    /// Returns the bitwise OR of the input expressions.
    /// ```sql
    /// SELECT BIT_OR(col1, col2) FROM df;
    /// ```
    BitOr,
    /// SQL 'bit_xor' function.
    /// Returns the bitwise XOR of the input expressions.
    /// ```sql
    /// SELECT BIT_XOR(col1, col2) FROM df;
    /// ```
    BitXor,

    // ----
    // Math functions
    // ----
    /// SQL 'abs' function.
    /// Returns the absolute value of the input expression.
    /// ```sql
    /// SELECT ABS(col1) FROM df;
    /// ```
    Abs,
    /// SQL 'ceil' function.
    /// Returns the smallest integer greater than or equal to the input
    /// (i.e. rounds toward positive infinity).
    /// ```sql
    /// SELECT CEIL(col1) FROM df;
    /// ```
    Ceil,
    /// SQL 'div' function.
    /// Returns the integer quotient of the division.
    /// ```sql
    /// SELECT DIV(col1, 2) FROM df;
    /// ```
    Div,
    /// SQL 'exp' function.
    /// Computes the exponential of the given value.
    /// ```sql
    /// SELECT EXP(col1) FROM df;
    /// ```
    Exp,
    /// SQL 'floor' function.
    /// Returns the largest integer less than or equal to the input
    /// (i.e. rounds toward negative infinity).
    /// ```sql
    /// SELECT FLOOR(col1) FROM df;
    /// ```
    Floor,
    /// SQL 'pi' function.
    /// Returns a (very good) approximation of 𝜋.
    /// ```sql
    /// SELECT PI() FROM df;
    /// ```
    Pi,
    /// SQL 'ln' function.
    /// Computes the natural logarithm of the given value.
    /// ```sql
    /// SELECT LN(col1) FROM df;
    /// ```
    Ln,
    /// SQL 'log2' function.
    /// Computes the logarithm of the given value in base 2.
    /// ```sql
    /// SELECT LOG2(col1) FROM df;
    /// ```
    Log2,
    /// SQL 'log10' function.
    /// Computes the logarithm of the given value in base 10.
    /// ```sql
    /// SELECT LOG10(col1) FROM df;
    /// ```
    Log10,
    /// SQL 'log' function.
    /// Computes the `base` logarithm of the given value.
    /// ```sql
    /// SELECT LOG(col1, 10) FROM df;
    /// ```
    Log,
    /// SQL 'log1p' function.
    /// Computes the natural logarithm of "given value plus one".
    /// ```sql
    /// SELECT LOG1P(col1) FROM df;
    /// ```
    Log1p,
    /// SQL 'pow' function.
    /// Returns the value to the power of the given exponent.
    /// ```sql
    /// SELECT POW(col1, 2) FROM df;
    /// ```
    Pow,
    /// SQL 'mod' function.
    /// Returns the remainder of a numeric expression divided by another numeric expression.
    /// ```sql
    /// SELECT MOD(col1, 2) FROM df;
    /// ```
    Mod,
    /// SQL 'sqrt' function.
    /// Returns the square root (√) of a number.
    /// ```sql
    /// SELECT SQRT(col1) FROM df;
    /// ```
    Sqrt,
    /// SQL 'cbrt' function.
    /// Returns the cube root (∛) of a number.
    /// ```sql
    /// SELECT CBRT(col1) FROM df;
    /// ```
    Cbrt,
    /// SQL 'round' function.
    /// Round a number to `n` decimals (default: 0) away from zero.
    ///   .5 is rounded away from zero.
    /// ```sql
    /// SELECT ROUND(col1, 3) FROM df;
    /// ```
    Round,
    /// SQL 'truncate' function.
    /// Truncate a number toward zero to `n` decimals (default: 0).
    /// ```sql
    /// SELECT TRUNCATE(col1, 2) FROM df;
    /// ```
    Truncate,
    /// SQL 'sign' function.
    /// Returns the sign of the argument as -1, 0, or +1.
    /// ```sql
    /// SELECT SIGN(col1) FROM df;
    /// ```
    Sign,

    // ----
    // Trig functions
    // ----
    /// SQL 'cos' function.
    /// Compute the cosine of the input expression (in radians).
    /// ```sql
    /// SELECT COS(col1) FROM df;
    /// ```
    Cos,
    /// SQL 'cot' function.
    /// Compute the cotangent of the input expression (in radians).
    /// ```sql
    /// SELECT COT(col1) FROM df;
    /// ```
    Cot,
    /// SQL 'sin' function.
    /// Compute the sine of the input expression (in radians).
    /// ```sql
    /// SELECT SIN(col1) FROM df;
    /// ```
    Sin,
    /// SQL 'tan' function.
    /// Compute the tangent of the input expression (in radians).
    /// ```sql
    /// SELECT TAN(col1) FROM df;
    /// ```
    Tan,
    /// SQL 'cosd' function.
    /// Compute the cosine of the input expression (in degrees).
    /// ```sql
    /// SELECT COSD(col1) FROM df;
    /// ```
    CosD,
    /// SQL 'cotd' function.
    /// Compute the cotangent of the input expression (in degrees).
    /// ```sql
    /// SELECT COTD(col1) FROM df;
    /// ```
    CotD,
    /// SQL 'sind' function.
    /// Compute the sine of the input expression (in degrees).
    /// ```sql
    /// SELECT SIND(col1) FROM df;
    /// ```
    SinD,
    /// SQL 'tand' function.
    /// Compute the tangent of the input expression (in degrees).
    /// ```sql
    /// SELECT TAND(col1) FROM df;
    /// ```
    TanD,
    /// SQL 'acos' function.
    /// Compute inverse cosine of the input expression (in radians).
    /// ```sql
    /// SELECT ACOS(col1) FROM df;
    /// ```
    Acos,
    /// SQL 'asin' function.
    /// Compute inverse sine of the input expression (in radians).
    /// ```sql
    /// SELECT ASIN(col1) FROM df;
    /// ```
    Asin,
    /// SQL 'atan' function.
    /// Compute inverse tangent of the input expression (in radians).
    /// ```sql
    /// SELECT ATAN(col1) FROM df;
    /// ```
    Atan,
    /// SQL 'atan2' function.
    /// Compute the inverse tangent of col1/col2 (in radians).
    /// ```sql
    /// SELECT ATAN2(col1, col2) FROM df;
    /// ```
    Atan2,
    /// SQL 'acosd' function.
    /// Compute inverse cosine of the input expression (in degrees).
    /// ```sql
    /// SELECT ACOSD(col1) FROM df;
    /// ```
    AcosD,
    /// SQL 'asind' function.
    /// Compute inverse sine of the input expression (in degrees).
    /// ```sql
    /// SELECT ASIND(col1) FROM df;
    /// ```
    AsinD,
    /// SQL 'atand' function.
    /// Compute inverse tangent of the input expression (in degrees).
    /// ```sql
    /// SELECT ATAND(col1) FROM df;
    /// ```
    AtanD,
    /// SQL 'atan2d' function.
    /// Compute the inverse tangent of col1/col2 (in degrees).
    /// ```sql
    /// SELECT ATAN2D(col1, col2) FROM df;
    /// ```
    Atan2D,
    /// SQL 'degrees' function.
    /// Convert an angle from radians to degrees.
    /// ```sql
    /// SELECT DEGREES(col1) FROM df;
    /// ```
    Degrees,
    /// SQL 'radians' function.
    /// Convert an angle from degrees to radians.
    /// ```sql
    /// SELECT RADIANS(col1) FROM df;
    /// ```
    Radians,

    // ----
    // Temporal functions
    // ----
    /// SQL 'date_part' function.
    /// Extracts a part of a date (or datetime) such as 'year', 'month', etc.
    /// ```sql
    /// SELECT DATE_PART('year', col1) FROM df;
    /// SELECT DATE_PART('day', col1) FROM df;
    /// ```
    DatePart,
    /// SQL 'strftime' function.
    /// Converts a datetime to a string using a format string.
    /// ```sql
    /// SELECT STRFTIME(col1, '%d-%m-%Y %H:%M') FROM df;
    /// ```
    Strftime,

    // ----
    // String functions
    // ----
    /// SQL 'bit_length' function.
    /// Returns the length of the input string in bits.
    /// ```sql
    /// SELECT BIT_LENGTH(col1) FROM df;
    /// ```
    BitLength,
    /// SQL 'concat' function.
    /// Returns all input expressions concatenated together as a string.
    /// ```sql
    /// SELECT CONCAT(col1, col2) FROM df;
    /// ```
    Concat,
    /// SQL 'concat_ws' function.
    /// Returns all input expressions concatenated together
    /// (and interleaved with a separator) as a string.
    /// ```sql
    /// SELECT CONCAT_WS(':', col1, col2, col3) FROM df;
    /// ```
    ConcatWS,
    /// SQL 'date' function.
    /// Converts a formatted string date to an actual Date type; ISO-8601 format is assumed
    /// unless a strftime-compatible formatting string is provided as the second parameter.
    /// ```sql
    /// SELECT DATE('2021-03-15') FROM df;
    /// SELECT DATE('2021-15-03', '%Y-%d-%m') FROM df;
    /// SELECT DATE('2021-03', '%Y-%m') FROM df;
    /// ```
    Date,
    /// SQL 'ends_with' function.
    /// Returns True if the value ends with the second argument.
    /// ```sql
    /// SELECT ENDS_WITH(col1, 'a') FROM df;
    /// SELECT col2 from df WHERE ENDS_WITH(col1, 'a');
    /// ```
    EndsWith,
    /// SQL 'initcap' function.
    /// Returns the value with the first letter capitalized.
    /// ```sql
    /// SELECT INITCAP(col1) FROM df;
    /// ```
    #[cfg(feature = "nightly")]
    InitCap,
    /// SQL 'left' function.
    /// Returns the first (leftmost) `n` characters.
    /// ```sql
    /// SELECT LEFT(col1, 3) FROM df;
    /// ```
    Left,
    /// SQL 'lpad' function.
    /// Pads a string on the left to a specified length, using an optional fill character.
    /// ```sql
    /// SELECT LPAD(col1, 10, 'x') FROM df;
    /// ```
    LeftPad,
    /// SQL 'ltrim' function.
    /// Strip whitespace from the left.
    /// ```sql
    /// SELECT LTRIM(col1) FROM df;
    /// ```
    LeftTrim,
    /// SQL 'length' function (characters).
    /// Returns the character length of the string.
    /// ```sql
    /// SELECT LENGTH(col1) FROM df;
    /// ```
    Length,
    /// SQL 'lower' function.
    /// Returns a lowercased column.
    /// ```sql
    /// SELECT LOWER(col1) FROM df;
    /// ```
    Lower,
    /// SQL 'normalize' function.
    /// Convert string to Unicode normalization form
    /// (one of NFC, NFKC, NFD, or NFKD - unquoted).
    /// ```sql
    /// SELECT NORMALIZE(col1, NFC) FROM df;
    /// ```
    Normalize,
    /// SQL 'octet_length' function.
    /// Returns the length of a given string in bytes.
    /// ```sql
    /// SELECT OCTET_LENGTH(col1) FROM df;
    /// ```
    OctetLength,
    /// SQL 'regexp_like' function.
    /// True if `pattern` matches the value (optional: `flags`).
    /// ```sql
    /// SELECT REGEXP_LIKE(col1, 'xyz', 'i') FROM df;
    /// ```
    RegexpLike,
    /// SQL 'replace' function.
    /// Replace a given substring with another string.
    /// ```sql
    /// SELECT REPLACE(col1, 'old', 'new') FROM df;
    /// ```
    Replace,
    /// SQL 'reverse' function.
    /// Return the reversed string.
    /// ```sql
    /// SELECT REVERSE(col1) FROM df;
    /// ```
    Reverse,
    /// SQL 'right' function.
    /// Returns the last (rightmost) `n` characters.
    /// ```sql
    /// SELECT RIGHT(col1, 3) FROM df;
    /// ```
    Right,
    /// SQL 'rpad' function.
    /// Pads a string on the right to a specified length, using an optional fill character.
    /// ```sql
    /// SELECT RPAD(col1, 10, 'x') FROM df;
    /// ```
    RightPad,
    /// SQL 'rtrim' function.
    /// Strip whitespace from the right.
    /// ```sql
    /// SELECT RTRIM(col1) FROM df;
    /// ```
    RightTrim,
    /// SQL 'split_part' function.
    /// Splits a string into an array of strings using the given delimiter
    /// and returns the `n`-th part (1-indexed).
    /// ```sql
    /// SELECT SPLIT_PART(col1, ',', 2) FROM df;
    /// ```
    SplitPart,
    /// SQL 'starts_with' function.
    /// Returns True if the value starts with the second argument.
    /// ```sql
    /// SELECT STARTS_WITH(col1, 'a') FROM df;
    /// SELECT col2 from df WHERE STARTS_WITH(col1, 'a');
    /// ```
    StartsWith,
    /// SQL 'strpos' function.
    /// Returns the index of the given substring in the target string.
    /// ```sql
    /// SELECT STRPOS(col1,'xyz') FROM df;
    /// ```
    StrPos,
    /// SQL 'substr' function.
    /// Returns a portion of the data (first character = 1) in the range.
    ///   \[start, start + length]
    /// ```sql
    /// SELECT SUBSTR(col1, 3, 5) FROM df;
    /// ```
    Substring,
    /// SQL 'string_to_array' function.
    /// Splits a string into an array of strings using the given delimiter.
    /// ```sql
    /// SELECT STRING_TO_ARRAY(col1, ',') FROM df;
    /// ```
    StringToArray,
    /// SQL 'strptime' function.
    /// Converts a string to a datetime using a format string.
    /// ```sql
    /// SELECT STRPTIME(col1, '%d-%m-%Y %H:%M') FROM df;
    /// ```
    Strptime,
    /// SQL 'time' function.
    /// Converts a formatted string time to an actual Time type; ISO-8601 format is
    /// assumed unless a strftime-compatible formatting string is provided as the second
    /// parameter.
    /// ```sql
    /// SELECT TIME('10:30:45') FROM df;
    /// SELECT TIME('20.30', '%H.%M') FROM df;
    /// ```
    Time,
    /// SQL 'timestamp' function.
    /// Converts a formatted string datetime to an actual Datetime type; ISO-8601 format is
    /// assumed unless a strftime-compatible formatting string is provided as the second
    /// parameter.
    /// ```sql
    /// SELECT TIMESTAMP('2021-03-15 10:30:45') FROM df;
    /// SELECT TIMESTAMP('2021-15-03 00:01:02', '%Y-%d-%m %H:%M:%S') FROM df;
    /// ```
    Timestamp,
    /// SQL 'upper' function.
    /// Returns an uppercased column.
    /// ```sql
    /// SELECT UPPER(col1) FROM df;
    /// ```
    Upper,

    // ----
    // Conditional functions
    // ----
    /// SQL 'coalesce' function.
    /// Returns the first non-null value in the provided values/columns.
    /// ```sql
    /// SELECT COALESCE(col1, ...) FROM df;
    /// ```
    Coalesce,
    /// SQL 'greatest' function.
    /// Returns the greatest value in the list of expressions.
    /// ```sql
    /// SELECT GREATEST(col1, col2, ...) FROM df;
    /// ```
    Greatest,
    /// SQL 'if' function.
    /// Returns expr1 if the boolean condition provided as the first
    /// parameter evaluates to true, and expr2 otherwise.
    /// ```sql
    /// SELECT IF(column < 0, expr1, expr2) FROM df;
    /// ```
    If,
    /// SQL 'ifnull' function.
    /// If an expression value is NULL, return an alternative value.
    /// ```sql
    /// SELECT IFNULL(string_col, 'n/a') FROM df;
    /// ```
    IfNull,
    /// SQL 'least' function.
    /// Returns the smallest value in the list of expressions.
    /// ```sql
    /// SELECT LEAST(col1, col2, ...) FROM df;
    /// ```
    Least,
    /// SQL 'nullif' function.
    /// Returns NULL if two expressions are equal, otherwise returns the first.
    /// ```sql
    /// SELECT NULLIF(col1, col2) FROM df;
    /// ```
    NullIf,

    // ----
    // Aggregate functions
    // ----
    /// SQL 'avg' function.
    /// Returns the average (mean) of all the elements in the grouping.
    /// ```sql
    /// SELECT AVG(col1) FROM df;
    /// ```
    Avg,
    /// SQL 'corr' function.
    /// Returns the Pearson correlation coefficient between two columns.
    /// ```sql
    /// SELECT CORR(col1, col2) FROM df;
    /// ```
    Corr,
    /// SQL 'count' function.
    /// Returns the number of elements in the grouping.
    /// ```sql
    /// SELECT COUNT(col1) FROM df;
    /// SELECT COUNT(*) FROM df;
    /// SELECT COUNT(DISTINCT col1) FROM df;
    /// SELECT COUNT(DISTINCT *) FROM df;
    /// ```
    Count,
    /// SQL 'covar_pop' function.
    /// Returns the population covariance between two columns.
    /// ```sql
    /// SELECT COVAR_POP(col1, col2) FROM df;
    /// ```
    CovarPop,
    /// SQL 'covar_samp' function.
    /// Returns the sample covariance between two columns.
    /// ```sql
    /// SELECT COVAR_SAMP(col1, col2) FROM df;
    /// ```
    CovarSamp,
    /// SQL 'first' function.
    /// Returns the first element of the grouping.
    /// ```sql
    /// SELECT FIRST(col1) FROM df;
    /// ```
    First,
    /// SQL 'last' function.
    /// Returns the last element of the grouping.
    /// ```sql
    /// SELECT LAST(col1) FROM df;
    /// ```
    Last,
    /// SQL 'max' function.
    /// Returns the greatest (maximum) of all the elements in the grouping.
    /// ```sql
    /// SELECT MAX(col1) FROM df;
    /// ```
    Max,
    /// SQL 'median' function.
    /// Returns the median element from the grouping.
    /// ```sql
    /// SELECT MEDIAN(col1) FROM df;
    /// ```
    Median,
    /// SQL 'quantile_cont' function.
    /// Returns the continuous quantile element from the grouping
    /// (interpolated value between two closest values).
    /// ```sql
    /// SELECT QUANTILE_CONT(col1) FROM df;
    /// ```
    QuantileCont,
    /// SQL 'quantile_disc' function.
    /// Divides the [0, 1] interval into equal-length subintervals, each corresponding to a value,
    /// and returns the value associated with the subinterval where the quantile value falls.
    /// ```sql
    /// SELECT QUANTILE_DISC(col1) FROM df;
    /// ```
    QuantileDisc,
    /// SQL 'min' function.
    /// Returns the smallest (minimum) of all the elements in the grouping.
    /// ```sql
    /// SELECT MIN(col1) FROM df;
    /// ```
    Min,
    /// SQL 'stddev' function.
    /// Returns the standard deviation of all the elements in the grouping.
    /// ```sql
    /// SELECT STDDEV(col1) FROM df;
    /// ```
    StdDev,
    /// SQL 'string_agg' function (also known as `GROUP_CONCAT`).
    /// Concatenates the input string values into a single string,
    /// separated by the given delimiter (`,` if unspecified).
    /// ```sql
    /// SELECT STRING_AGG(col1) FROM df;
    /// SELECT STRING_AGG(col1, ',' ORDER BY col2 DESC) FROM df;
    /// SELECT STRING_AGG(DISTINCT col1, ',' ORDER BY col1) FROM df;
    /// ```
    StringAgg,
    /// SQL 'sum' function.
    /// Returns the sum of all the elements in the grouping.
    /// ```sql
    /// SELECT SUM(col1) FROM df;
    /// ```
    Sum,
    /// SQL 'variance' function.
    /// Returns the variance of all the elements in the grouping.
    /// ```sql
    /// SELECT VARIANCE(col1) FROM df;
    /// ```
    Variance,

    // ----
    // Array functions
    // ----
    /// SQL 'array_length' function.
    /// Returns the length of the array.
    /// ```sql
    /// SELECT ARRAY_LENGTH(col1) FROM df;
    /// ```
    ArrayLength,
    /// SQL 'array_lower' function.
    /// Returns the minimum value in an array; equivalent to `array_min`.
    /// ```sql
    /// SELECT ARRAY_LOWER(col1) FROM df;
    /// ```
    ArrayMin,
    /// SQL 'array_upper' function.
    /// Returns the maximum value in an array; equivalent to `array_max`.
    /// ```sql
    /// SELECT ARRAY_UPPER(col1) FROM df;
    /// ```
    ArrayMax,
    /// SQL 'array_sum' function.
    /// Returns the sum of all values in an array.
    /// ```sql
    /// SELECT ARRAY_SUM(col1) FROM df;
    /// ```
    ArraySum,
    /// SQL 'array_mean' function.
    /// Returns the mean of all values in an array.
    /// ```sql
    /// SELECT ARRAY_MEAN(col1) FROM df;
    /// ```
    ArrayMean,
    /// SQL 'array_reverse' function.
    /// Returns the array with the elements in reverse order.
    /// ```sql
    /// SELECT ARRAY_REVERSE(col1) FROM df;
    /// ```
    ArrayReverse,
    /// SQL 'array_unique' function.
    /// Returns the array with the unique elements.
    /// ```sql
    /// SELECT ARRAY_UNIQUE(col1) FROM df;
    /// ```
    ArrayUnique,
    /// SQL 'array_agg' function.
    /// Concatenates the input expressions, including nulls, into an array.
    /// ```sql
    /// SELECT ARRAY_AGG(col1, col2, ...) FROM df;
    /// ```
    ArrayAgg,
    /// SQL 'array_to_string' function.
    /// Takes all elements of the array and joins them into one string.
    /// ```sql
    /// SELECT ARRAY_TO_STRING(col1, ',') FROM df;
    /// SELECT ARRAY_TO_STRING(col1, ',', 'n/a') FROM df;
    /// ```
    ArrayToString,
    /// SQL 'array_get' function.
    /// Returns the value at the given index in the array.
    /// ```sql
    /// SELECT ARRAY_GET(col1, 1) FROM df;
    /// ```
    ArrayGet,
    /// SQL 'array_contains' function.
    /// Returns true if the array contains the value.
    /// ```sql
    /// SELECT ARRAY_CONTAINS(col1, 'foo') FROM df;
    /// ```
    ArrayContains,
    /// SQL 'unnest' function.
    /// Unnest/explodes an array column into multiple rows.
    /// ```sql
    /// SELECT UNNEST(col1) FROM df;
    /// ```
    Explode,

    // ----
    // Window functions
    // ----
    /// SQL 'first_value' window function.
    /// Returns the first value in an ordered set of values (respecting window frame).
    /// ```sql
    /// SELECT FIRST_VALUE(col1) OVER (PARTITION BY category ORDER BY id) FROM df;
    /// ```
    FirstValue,
    /// SQL 'last_value' window function.
    /// Returns the last value in an ordered set of values (respecting window frame).
    /// With default frame, returns the current row's value.
    /// ```sql
    /// SELECT LAST_VALUE(col1) OVER (PARTITION BY category ORDER BY id) FROM df;
    /// ```
    LastValue,
    /// SQL 'lag' function.
    /// Returns the value of the expression evaluated at the row n rows before the current row.
    /// ```sql
    /// SELECT lag(column_1, 1) OVER (PARTITION BY column_2 ORDER BY column_3) FROM df;
    /// ```
    Lag,
    /// SQL 'lead' function.
    /// Returns the value of the expression evaluated at the row n rows after the current row.
    /// ```sql
    /// SELECT lead(column_1, 1) OVER (PARTITION BY column_2 ORDER BY column_3) FROM df;
    /// ```
    Lead,
    /// SQL 'row_number' function.
    /// Returns the sequential row number within a window partition, starting from 1.
    /// ```sql
    /// SELECT ROW_NUMBER() OVER (ORDER BY col1) FROM df;
    /// SELECT ROW_NUMBER() OVER (PARTITION BY col1 ORDER BY col2) FROM df;
    /// ```
    RowNumber,
    /// SQL 'rank' function.
    /// Returns the rank of each row within a window partition, with gaps for ties.
    /// Rows with equal values receive the same rank, and the next rank skips numbers.
    /// ```sql
    /// SELECT RANK() OVER (ORDER BY col1) FROM df;
    /// SELECT RANK() OVER (PARTITION BY col1 ORDER BY col2 DESC) FROM df;
    /// ```
    #[cfg(feature = "rank")]
    Rank,
    /// SQL 'dense_rank' function.
    /// Returns the rank of each row within a window partition, without gaps for ties.
    /// Rows with equal values receive the same rank, and the next rank is consecutive.
    /// ```sql
    /// SELECT DENSE_RANK() OVER (ORDER BY col1) FROM df;
    /// SELECT DENSE_RANK() OVER (PARTITION BY col1 ORDER BY col2 DESC) FROM df;
    /// ```
    #[cfg(feature = "rank")]
    DenseRank,

    // ----
    // Column selection
    // ----
    /// SQL 'columns' function (column selection).
    // NOTE(review): exact selection semantics live in the visitor, outside this definition.
    Columns,

    // ----
    // User-defined
    // ----
    /// User-defined function, carrying a `String` payload.
    // NOTE(review): the payload is presumably the registered function name — confirm in `try_from_sql`.
    Udf(String),
}
797
798impl PolarsSQLFunctions {
799    pub(crate) fn keywords() -> &'static [&'static str] {
800        &[
801            "abs",
802            "acos",
803            "acosd",
804            "array_contains",
805            "array_get",
806            "array_length",
807            "array_lower",
808            "array_mean",
809            "array_reverse",
810            "array_sum",
811            "array_to_string",
812            "array_unique",
813            "array_upper",
814            "asin",
815            "asind",
816            "atan",
817            "atan2",
818            "atan2d",
819            "atand",
820            "avg",
821            "bit_and",
822            "bit_count",
823            "bit_length",
824            "bit_or",
825            "bit_xor",
826            "cbrt",
827            "ceil",
828            "ceiling",
829            "char_length",
830            "character_length",
831            "coalesce",
832            "columns",
833            "concat",
834            "concat_ws",
835            "corr",
836            "cos",
837            "cosd",
838            "cot",
839            "cotd",
840            "count",
841            "covar",
842            "covar_pop",
843            "covar_samp",
844            "date",
845            "date_part",
846            "degrees",
847            "dense_rank",
848            "ends_with",
849            "exp",
850            "first",
851            "first_value",
852            "floor",
853            "greatest",
854            "if",
855            "ifnull",
856            "initcap",
857            "lag",
858            "last",
859            "last_value",
860            "lead",
861            "least",
862            "left",
863            "length",
864            "ln",
865            "log",
866            "log10",
867            "log1p",
868            "log2",
869            "lower",
870            "lpad",
871            "ltrim",
872            "max",
873            "median",
874            "quantile_disc",
875            "min",
876            "mod",
877            "nullif",
878            "octet_length",
879            "pi",
880            "pow",
881            "power",
882            "quantile_cont",
883            "quantile_disc",
884            "radians",
885            "rank",
886            "regexp_like",
887            "replace",
888            "reverse",
889            "right",
890            "round",
891            "row_number",
892            "rpad",
893            "rtrim",
894            "sign",
895            "sin",
896            "sind",
897            "sqrt",
898            "starts_with",
899            "stddev",
900            "stddev_samp",
901            "stdev",
902            "stdev_samp",
903            "strftime",
904            "strpos",
905            "strptime",
906            "substr",
907            "sum",
908            "tan",
909            "tand",
910            "unnest",
911            "upper",
912            "var",
913            "var_samp",
914            "variance",
915        ]
916    }
917}
918
919impl PolarsSQLFunctions {
920    fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult<Self> {
921        let function_name = function.name.0[0].as_ident().unwrap().value.to_lowercase();
922        Ok(match function_name.as_str() {
923            // ----
924            // Bitwise functions
925            // ----
926            "bit_and" | "bitand" => Self::BitAnd,
927            #[cfg(feature = "bitwise")]
928            "bit_count" | "bitcount" => Self::BitCount,
929            "bit_not" | "bitnot" => Self::BitNot,
930            "bit_or" | "bitor" => Self::BitOr,
931            "bit_xor" | "bitxor" | "xor" => Self::BitXor,
932
933            // ----
934            // Math functions
935            // ----
936            "abs" => Self::Abs,
937            "cbrt" => Self::Cbrt,
938            "ceil" | "ceiling" => Self::Ceil,
939            "div" => Self::Div,
940            "exp" => Self::Exp,
941            "floor" => Self::Floor,
942            "ln" => Self::Ln,
943            "log" => Self::Log,
944            "log10" => Self::Log10,
945            "log1p" => Self::Log1p,
946            "log2" => Self::Log2,
947            "mod" => Self::Mod,
948            "pi" => Self::Pi,
949            "pow" | "power" => Self::Pow,
950            "round" => Self::Round,
951            "trunc" | "truncate" => Self::Truncate,
952            "sign" => Self::Sign,
953            "sqrt" => Self::Sqrt,
954
955            // ----
956            // Trig functions
957            // ----
958            "cos" => Self::Cos,
959            "cot" => Self::Cot,
960            "sin" => Self::Sin,
961            "tan" => Self::Tan,
962            "cosd" => Self::CosD,
963            "cotd" => Self::CotD,
964            "sind" => Self::SinD,
965            "tand" => Self::TanD,
966            "acos" => Self::Acos,
967            "asin" => Self::Asin,
968            "atan" => Self::Atan,
969            "atan2" => Self::Atan2,
970            "acosd" => Self::AcosD,
971            "asind" => Self::AsinD,
972            "atand" => Self::AtanD,
973            "atan2d" => Self::Atan2D,
974            "degrees" => Self::Degrees,
975            "radians" => Self::Radians,
976
977            // ----
978            // Conditional functions
979            // ----
980            "coalesce" => Self::Coalesce,
981            "greatest" => Self::Greatest,
982            "if" => Self::If,
983            "ifnull" => Self::IfNull,
984            "least" => Self::Least,
985            "nullif" => Self::NullIf,
986
987            // ----
988            // Temporal functions
989            // ----
990            "date" => Self::Date,
991            "date_part" => Self::DatePart,
992            "strftime" => Self::Strftime,
993            "timestamp" | "datetime" => Self::Timestamp,
994
995            // ----
996            // String functions
997            // ----
998            "bit_length" => Self::BitLength,
999            "concat" => Self::Concat,
1000            "concat_ws" => Self::ConcatWS,
1001            "ends_with" => Self::EndsWith,
1002            #[cfg(feature = "nightly")]
1003            "initcap" => Self::InitCap,
1004            "left" => Self::Left,
1005            "length" | "char_length" | "character_length" => Self::Length,
1006            "lower" => Self::Lower,
1007            "lpad" => Self::LeftPad,
1008            "ltrim" => Self::LeftTrim,
1009            "normalize" => Self::Normalize,
1010            "octet_length" => Self::OctetLength,
1011            "regexp_like" => Self::RegexpLike,
1012            "replace" => Self::Replace,
1013            "reverse" => Self::Reverse,
1014            "right" => Self::Right,
1015            "rpad" => Self::RightPad,
1016            "rtrim" => Self::RightTrim,
1017            "split_part" => Self::SplitPart,
1018            "starts_with" => Self::StartsWith,
1019            "string_to_array" => Self::StringToArray,
1020            "strpos" => Self::StrPos,
1021            "strptime" => Self::Strptime,
1022            "substr" => Self::Substring,
1023            "time" => Self::Time,
1024            "upper" => Self::Upper,
1025
1026            // ----
1027            // Aggregate functions
1028            // ----
1029            "avg" => Self::Avg,
1030            "corr" => Self::Corr,
1031            "count" => Self::Count,
1032            "covar_pop" => Self::CovarPop,
1033            "covar_samp" | "covar" => Self::CovarSamp,
1034            "first" => Self::First,
1035            "last" => Self::Last,
1036            "max" => Self::Max,
1037            "median" => Self::Median,
1038            "min" => Self::Min,
1039            "quantile_cont" => Self::QuantileCont,
1040            "quantile_disc" => Self::QuantileDisc,
1041            "stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev,
1042            "string_agg" | "listagg" | "group_concat" => Self::StringAgg,
1043            "sum" => Self::Sum,
1044            "var" | "variance" | "var_samp" => Self::Variance,
1045
1046            // ----
1047            // Array functions
1048            // ----
1049            "array_agg" => Self::ArrayAgg,
1050            "array_contains" => Self::ArrayContains,
1051            "array_get" => Self::ArrayGet,
1052            "array_length" => Self::ArrayLength,
1053            "array_lower" => Self::ArrayMin,
1054            "array_mean" => Self::ArrayMean,
1055            "array_reverse" => Self::ArrayReverse,
1056            "array_sum" => Self::ArraySum,
1057            "array_to_string" => Self::ArrayToString,
1058            "array_unique" => Self::ArrayUnique,
1059            "array_upper" => Self::ArrayMax,
1060            "unnest" => Self::Explode,
1061
1062            // ----
1063            // Window functions
1064            // ----
1065            #[cfg(feature = "rank")]
1066            "dense_rank" => Self::DenseRank,
1067            "first_value" => Self::FirstValue,
1068            "last_value" => Self::LastValue,
1069            "lag" => Self::Lag,
1070            "lead" => Self::Lead,
1071            #[cfg(feature = "rank")]
1072            "rank" => Self::Rank,
1073            "row_number" => Self::RowNumber,
1074
1075            // ----
1076            // Column selection
1077            // ----
1078            "columns" => Self::Columns,
1079
1080            other => {
1081                if ctx.function_registry.contains(other) {
1082                    Self::Udf(other.to_string())
1083                } else {
1084                    polars_bail!(SQLInterface: "unsupported function '{}'", other);
1085                }
1086            },
1087        })
1088    }
1089}
1090
1091impl SQLFunctionVisitor<'_> {
1092    pub(crate) fn visit_function(&mut self) -> PolarsResult<Expr> {
1093        use PolarsSQLFunctions::*;
1094        use polars_lazy::prelude::Literal;
1095
1096        let function_name = PolarsSQLFunctions::try_from_sql(self.func, self.ctx)?;
1097        let function = self.func;
1098
1099        // TODO: implement the following modifiers where possible
1100        if !function.within_group.is_empty() {
1101            polars_bail!(SQLInterface: "'WITHIN GROUP' is not currently supported")
1102        }
1103        if function.null_treatment.is_some() {
1104            polars_bail!(SQLInterface: "'IGNORE|RESPECT NULLS' is not currently supported")
1105        }
1106        if let Some(filter_expr) = &function.filter {
1107            if function.over.is_some() {
1108                polars_bail!(SQLInterface: "'FILTER' combined with 'OVER' is not supported")
1109            }
1110            self.filter = Some(parse_sql_expr(filter_expr, self.ctx, self.active_schema)?);
1111        }
1112
1113        let log_with_base =
1114            |e: Expr, base: f64| e.log(LiteralValue::Dyn(DynLiteralValue::Float(base)).lit());
1115
1116        match function_name {
1117            // ----
1118            // Bitwise functions
1119            // ----
1120            BitAnd => self.visit_binary::<Expr>(Expr::and),
1121            #[cfg(feature = "bitwise")]
1122            BitCount => self.visit_unary(Expr::bitwise_count_ones),
1123            BitNot => self.visit_unary(Expr::not),
1124            BitOr => self.visit_binary::<Expr>(Expr::or),
1125            BitXor => self.visit_binary::<Expr>(Expr::xor),
1126
1127            // ----
1128            // Math functions
1129            // ----
1130            Abs => self.visit_unary(Expr::abs),
1131            Cbrt => self.visit_unary(Expr::cbrt),
1132            Ceil => self.visit_unary(Expr::ceil),
1133            Div => self.visit_binary(|e, d| e.floor_div(d).cast(DataType::Int64)),
1134            Exp => self.visit_unary(Expr::exp),
1135            Floor => self.visit_unary(Expr::floor),
1136            Ln => self.visit_unary(|e| log_with_base(e, std::f64::consts::E)),
1137            Log => self.visit_binary(Expr::log),
1138            Log10 => self.visit_unary(|e| log_with_base(e, 10.0)),
1139            Log1p => self.visit_unary(Expr::log1p),
1140            Log2 => self.visit_unary(|e| log_with_base(e, 2.0)),
1141            Pi => self.visit_nullary(Expr::pi),
1142            Mod => self.visit_binary(|e1, e2| e1 % e2),
1143            Pow => self.visit_binary::<Expr>(Expr::pow),
1144            Round => {
1145                let args = extract_args(function)?;
1146                match args.len() {
1147                    1 => self.visit_unary(|e| e.round(0, RoundMode::default())),
1148                    2 => self.try_visit_binary(|e, decimals| {
1149                        Ok(e.round(match decimals {
1150                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1151                                if n >= 0 { n as u32 } else {
1152                                    polars_bail!(SQLInterface: "ROUND does not support negative decimals value ({})", args[1])
1153                                }
1154                            },
1155                            _ => polars_bail!(SQLSyntax: "invalid value for ROUND decimals ({})", args[1]),
1156                        }, RoundMode::default()))
1157                    }),
1158                    _ => polars_bail!(SQLSyntax: "ROUND expects 1-2 arguments (found {})", args.len()),
1159                }
1160            },
1161            Truncate => {
1162                let args = extract_args(function)?;
1163                match args.len() {
1164                    1 => self.visit_unary(|e| e.truncate(0)),
1165                    2 => self.try_visit_binary(|e, decimals| {
1166                        Ok(e.truncate(match decimals {
1167                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1168                                if n >= 0 { n as u32 } else {
1169                                    polars_bail!(SQLInterface: "TRUNCATE does not support negative decimals value ({})", args[1])
1170                                }
1171                            },
1172                            _ => polars_bail!(SQLSyntax: "invalid value for TRUNCATE decimals ({})", args[1]),
1173                        }))
1174                    }),
1175                    _ => polars_bail!(SQLSyntax: "TRUNCATE expects 1-2 arguments (found {})", args.len()),
1176                }
1177            },
1178            Sign => self.visit_unary(Expr::sign),
1179            Sqrt => self.visit_unary(Expr::sqrt),
1180
1181            // ----
1182            // Trig functions
1183            // ----
1184            Acos => self.visit_unary(Expr::arccos),
1185            AcosD => self.visit_unary(|e| e.arccos().degrees()),
1186            Asin => self.visit_unary(Expr::arcsin),
1187            AsinD => self.visit_unary(|e| e.arcsin().degrees()),
1188            Atan => self.visit_unary(Expr::arctan),
1189            Atan2 => self.visit_binary(Expr::arctan2),
1190            Atan2D => self.visit_binary(|e, s| e.arctan2(s).degrees()),
1191            AtanD => self.visit_unary(|e| e.arctan().degrees()),
1192            Cos => self.visit_unary(Expr::cos),
1193            CosD => self.visit_unary(|e| e.radians().cos()),
1194            Cot => self.visit_unary(Expr::cot),
1195            CotD => self.visit_unary(|e| e.radians().cot()),
1196            Degrees => self.visit_unary(Expr::degrees),
1197            Radians => self.visit_unary(Expr::radians),
1198            Sin => self.visit_unary(Expr::sin),
1199            SinD => self.visit_unary(|e| e.radians().sin()),
1200            Tan => self.visit_unary(Expr::tan),
1201            TanD => self.visit_unary(|e| e.radians().tan()),
1202
1203            // ----
1204            // Conditional functions
1205            // ----
1206            Coalesce => self.visit_variadic(coalesce),
1207            Greatest => self.visit_variadic(|exprs: &[Expr]| max_horizontal(exprs).unwrap()),
1208            If => {
1209                let args = extract_args(function)?;
1210                match args.len() {
1211                    3 => self.try_visit_ternary(|cond: Expr, expr1: Expr, expr2: Expr| {
1212                        Ok(when(cond).then(expr1).otherwise(expr2))
1213                    }),
1214                    _ => {
1215                        polars_bail!(SQLSyntax: "IF expects 3 arguments (found {})", args.len()
1216                        )
1217                    },
1218                }
1219            },
1220            IfNull => {
1221                let args = extract_args(function)?;
1222                match args.len() {
1223                    2 => self.visit_variadic(coalesce),
1224                    _ => {
1225                        polars_bail!(SQLSyntax: "IFNULL expects 2 arguments (found {})", args.len())
1226                    },
1227                }
1228            },
1229            Least => self.visit_variadic(|exprs: &[Expr]| min_horizontal(exprs).unwrap()),
1230            NullIf => {
1231                let args = extract_args(function)?;
1232                match args.len() {
1233                    2 => self.visit_binary(|l: Expr, r: Expr| {
1234                        when(l.clone().eq(r))
1235                            .then(lit(LiteralValue::untyped_null()))
1236                            .otherwise(l)
1237                    }),
1238                    _ => {
1239                        polars_bail!(SQLSyntax: "NULLIF expects 2 arguments (found {})", args.len())
1240                    },
1241                }
1242            },
1243
1244            // ----
1245            // Date functions
1246            // ----
1247            DatePart => self.try_visit_binary(|part, e| {
1248                match part {
1249                    Expr::Literal(p) if p.extract_str().is_some() => {
1250                        let p = p.extract_str().unwrap();
1251                        // note: 'DATE_PART' and 'EXTRACT' are minor syntactic
1252                        // variations on otherwise identical functionality
1253                        parse_extract_date_part(
1254                            e,
1255                            &DateTimeField::Custom(Ident {
1256                                value: p.to_string(),
1257                                quote_style: None,
1258                                span: Span::empty(),
1259                            }),
1260                        )
1261                    },
1262                    _ => {
1263                        polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART ({})", part);
1264                    },
1265                }
1266            }),
1267            Strftime => {
1268                let args = extract_args(function)?;
1269                match args.len() {
1270                    2 => self.visit_binary(|e, fmt: String| e.dt().strftime(fmt.as_str())),
1271                    _ => {
1272                        polars_bail!(SQLSyntax: "STRFTIME expects 2 arguments (found {})", args.len())
1273                    },
1274                }
1275            },
1276
1277            // ----
1278            // String functions
1279            // ----
1280            BitLength => self.visit_unary(|e| e.str().len_bytes() * lit(8)),
1281            Concat => {
1282                let args = extract_args(function)?;
1283                if args.is_empty() {
1284                    polars_bail!(SQLSyntax: "CONCAT expects at least 1 argument (found 0)");
1285                } else {
1286                    self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, "", true))
1287                }
1288            },
1289            ConcatWS => {
1290                let args = extract_args(function)?;
1291                if args.len() < 2 {
1292                    polars_bail!(SQLSyntax: "CONCAT_WS expects at least 2 arguments (found {})", args.len());
1293                } else {
1294                    self.try_visit_variadic(|exprs: &[Expr]| {
1295                        match &exprs[0] {
1296                            Expr::Literal(lv) if lv.extract_str().is_some() => Ok(concat_str(&exprs[1..], lv.extract_str().unwrap(), true)),
1297                            _ => polars_bail!(SQLSyntax: "CONCAT_WS 'separator' must be a literal string (found {:?})", exprs[0]),
1298                        }
1299                    })
1300                }
1301            },
1302            Date => {
1303                let args = extract_args(function)?;
1304                match args.len() {
1305                    1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
1306                    2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
1307                    _ => {
1308                        polars_bail!(SQLSyntax: "DATE expects 1-2 arguments (found {})", args.len())
1309                    },
1310                }
1311            },
1312            EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
1313            #[cfg(feature = "nightly")]
1314            InitCap => self.visit_unary(|e| e.str().to_titlecase()),
1315            Left => self.try_visit_binary(|e, length| {
1316                Ok(match length {
1317                    Expr::Literal(lv) if lv.is_null() => lit(lv),
1318                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => lit(""),
1319                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1320                        let len = if n > 0 {
1321                            lit(n)
1322                        } else {
1323                            (e.clone().str().len_chars() + lit(n)).clip_min(lit(0))
1324                        };
1325                        e.str().slice(lit(0), len)
1326                    },
1327                    Expr::Literal(v) => {
1328                        polars_bail!(SQLSyntax: "invalid 'n_chars' for LEFT ({:?})", v)
1329                    },
1330                    _ => when(length.clone().gt_eq(lit(0)))
1331                        .then(e.clone().str().slice(lit(0), length.clone().abs()))
1332                        .otherwise(e.clone().str().slice(
1333                            lit(0),
1334                            (e.str().len_chars() + length.clone()).clip_min(lit(0)),
1335                        )),
1336                })
1337            }),
1338            LeftPad | RightPad => {
1339                let is_lpad = matches!(function_name, LeftPad);
1340                let fname = if is_lpad { "LPAD" } else { "RPAD" };
1341                let args = extract_args(function)?;
1342                let pad = |e: Expr, length: Expr, fill_char: char| {
1343                    let padded = if is_lpad {
1344                        e.str().pad_start(length.clone(), fill_char)
1345                    } else {
1346                        e.str().pad_end(length.clone(), fill_char)
1347                    };
1348                    Ok(padded.str().slice(lit(0), length))
1349                };
1350                match args.len() {
1351                    2 => self.try_visit_binary(|e, length| pad(e, length, ' ')),
1352                    3 => self.try_visit_ternary(|e: Expr, length: Expr, fill: Expr| match fill {
1353                        Expr::Literal(lv) if lv.extract_str().is_some() => {
1354                            let s = lv.extract_str().unwrap();
1355                            let mut chars = s.chars();
1356                            match (chars.next(), chars.next()) {
1357                                (Some(c), None) => pad(e, length, c),
1358                                _ => polars_bail!(SQLSyntax: "{} fill value must be a single character (found '{}')", fname, s),
1359                            }
1360                        },
1361                        _ => polars_bail!(SQLSyntax: "{} fill value must be a string literal", fname),
1362                    }),
1363                    _ => polars_bail!(SQLSyntax: "{} expects 2-3 arguments (found {})", fname, args.len()),
1364                }
1365            },
1366            LeftTrim | RightTrim => {
1367                let is_ltrim = matches!(function_name, LeftTrim);
1368                let fname = if is_ltrim { "LTRIM" } else { "RTRIM" };
1369                let strip: fn(Expr, Expr) -> Expr = if is_ltrim {
1370                    |e, s| e.str().strip_chars_start(s)
1371                } else {
1372                    |e, s| e.str().strip_chars_end(s)
1373                };
1374                let args = extract_args(function)?;
1375                match args.len() {
1376                    1 => self.visit_unary(|e| strip(e, lit(LiteralValue::untyped_null()))),
1377                    2 => self.visit_binary(strip),
1378                    _ => {
1379                        polars_bail!(SQLSyntax: "{} expects 1-2 arguments (found {})", fname, args.len())
1380                    },
1381                }
1382            },
1383            Length => self.visit_unary(|e| e.str().len_chars()),
1384            Lower => self.visit_unary(|e| e.str().to_lowercase()),
1385            Normalize => {
1386                let args = extract_args(function)?;
1387                match args.len() {
1388                    1 => self.visit_unary(|e| e.str().normalize(UnicodeForm::NFC)),
1389                    2 => {
1390                        let form = if let FunctionArgExpr::Expr(SQLExpr::Identifier(Ident {
1391                            value: s,
1392                            quote_style: None,
1393                            span: _,
1394                        })) = args[1]
1395                        {
1396                            match s.to_uppercase().as_str() {
1397                                "NFC" => UnicodeForm::NFC,
1398                                "NFD" => UnicodeForm::NFD,
1399                                "NFKC" => UnicodeForm::NFKC,
1400                                "NFKD" => UnicodeForm::NFKD,
1401                                _ => {
1402                                    polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", s)
1403                                },
1404                            }
1405                        } else {
1406                            polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", args[1])
1407                        };
1408                        self.try_visit_binary(|e, _form: Expr| Ok(e.str().normalize(form.clone())))
1409                    },
1410                    _ => {
1411                        polars_bail!(SQLSyntax: "NORMALIZE expects 1-2 arguments (found {})", args.len())
1412                    },
1413                }
1414            },
1415            OctetLength => self.visit_unary(|e| e.str().len_bytes()),
1416            StrPos => {
1417                // note: SQL is 1-indexed; returns zero if no match found
1418                self.visit_binary(|expr, substring| {
1419                    (expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32))
1420                })
1421            },
1422            RegexpLike => {
1423                let args = extract_args(function)?;
1424                match args.len() {
1425                    2 => self.visit_binary(|e, s| e.str().contains(s, true)),
1426                    3 => self.try_visit_ternary(|e, pat, flags| {
1427                        Ok(e.str().contains(
1428                            match (pat, flags) {
1429                                (Expr::Literal(s_lv), Expr::Literal(f_lv)) if s_lv.extract_str().is_some() && f_lv.extract_str().is_some() => {
1430                                    let s = s_lv.extract_str().unwrap();
1431                                    let f = f_lv.extract_str().unwrap();
1432                                    if f.is_empty() {
1433                                        polars_bail!(SQLSyntax: "invalid/empty 'flags' for REGEXP_LIKE ({})", args[2]);
1434                                    };
1435                                    lit(format!("(?{f}){s}"))
1436                                },
1437                                _ => {
1438                                    polars_bail!(SQLSyntax: "invalid arguments for REGEXP_LIKE ({}, {})", args[1], args[2]);
1439                                },
1440                            },
1441                            true))
1442                    }),
1443                    _ => polars_bail!(SQLSyntax: "REGEXP_LIKE expects 2-3 arguments (found {})",args.len()),
1444                }
1445            },
1446            Replace => {
1447                let args = extract_args(function)?;
1448                match args.len() {
1449                    3 => self
1450                        .try_visit_ternary(|e, old, new| Ok(e.str().replace_all(old, new, true))),
1451                    _ => {
1452                        polars_bail!(SQLSyntax: "REPLACE expects 3 arguments (found {})", args.len())
1453                    },
1454                }
1455            },
1456            Reverse => self.visit_unary(|e| e.str().reverse()),
1457            Right => self.try_visit_binary(|e, length| {
1458                Ok(match length {
1459                    Expr::Literal(lv) if lv.is_null() => lit(lv),
1460                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(0))) => typed_lit(""),
1461                    Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1462                        let n: i64 = n.try_into().unwrap();
1463                        let offset = if n < 0 {
1464                            lit(n.abs())
1465                        } else {
1466                            e.clone().str().len_chars().cast(DataType::Int32) - lit(n)
1467                        };
1468                        e.str().slice(offset, lit(LiteralValue::untyped_null()))
1469                    },
1470                    Expr::Literal(v) => {
1471                        polars_bail!(SQLSyntax: "invalid 'n_chars' for RIGHT ({:?})", v)
1472                    },
1473                    _ => when(length.clone().lt(lit(0)))
1474                        .then(
1475                            e.clone()
1476                                .str()
1477                                .slice(length.clone().abs(), lit(LiteralValue::untyped_null())),
1478                        )
1479                        .otherwise(e.clone().str().slice(
1480                            e.str().len_chars().cast(DataType::Int32) - length.clone(),
1481                            lit(LiteralValue::untyped_null()),
1482                        )),
1483                })
1484            }),
1485            SplitPart => {
1486                let args = extract_args(function)?;
1487                match args.len() {
1488                    3 => self.try_visit_ternary(|e, sep, idx| {
1489                        let idx = adjust_one_indexed_param(idx, true);
1490                        Ok(when(e.clone().is_not_null())
1491                            .then(
1492                                e.clone()
1493                                    .str()
1494                                    .split(sep)
1495                                    .list()
1496                                    .get(idx, true)
1497                                    .fill_null(lit("")),
1498                            )
1499                            .otherwise(e))
1500                    }),
1501                    _ => {
1502                        polars_bail!(SQLSyntax: "SPLIT_PART expects 3 arguments (found {})", args.len())
1503                    },
1504                }
1505            },
1506            StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
1507            StringToArray => {
1508                let args = extract_args(function)?;
1509                match args.len() {
1510                    2 => self.visit_binary(|e, sep| e.str().split(sep)),
1511                    _ => {
1512                        polars_bail!(SQLSyntax: "STRING_TO_ARRAY expects 2 arguments (found {})", args.len())
1513                    },
1514                }
1515            },
1516            Strptime => {
1517                let args = extract_args(function)?;
1518                match args.len() {
1519                    2 => self.visit_binary(|e, fmt: String| {
1520                        e.str().strptime(
1521                            DataType::Datetime(TimeUnit::Microseconds, None),
1522                            StrptimeOptions {
1523                                format: Some(fmt.into()),
1524                                ..Default::default()
1525                            },
1526                            lit("latest"),
1527                        )
1528                    }),
1529                    _ => {
1530                        polars_bail!(SQLSyntax: "STRPTIME expects 2 arguments (found {})", args.len())
1531                    },
1532                }
1533            },
1534            Time => {
1535                let args = extract_args(function)?;
1536                match args.len() {
1537                    1 => self.visit_unary(|e| e.str().to_time(StrptimeOptions::default())),
1538                    2 => self.visit_binary(|e, fmt| e.str().to_time(fmt)),
1539                    _ => {
1540                        polars_bail!(SQLSyntax: "TIME expects 1-2 arguments (found {})", args.len())
1541                    },
1542                }
1543            },
1544            Timestamp => {
1545                let args = extract_args(function)?;
1546                match args.len() {
1547                    1 => self.visit_unary(|e| {
1548                        e.str()
1549                            .to_datetime(None, None, StrptimeOptions::default(), lit("latest"))
1550                    }),
1551                    2 => self
1552                        .visit_binary(|e, fmt| e.str().to_datetime(None, None, fmt, lit("latest"))),
1553                    _ => {
1554                        polars_bail!(SQLSyntax: "DATETIME expects 1-2 arguments (found {})", args.len())
1555                    },
1556                }
1557            },
1558            Substring => {
1559                let args = extract_args(function)?;
1560                match args.len() {
1561                    // note: SQL is 1-indexed, hence the need for adjustments
1562                    2 => self.try_visit_binary(|e, start| {
1563                        Ok(match start {
1564                            Expr::Literal(lv) if lv.is_null() => lit(lv),
1565                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n <= 0 => e,
1566                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => e.str().slice(lit(n - 1), lit(LiteralValue::untyped_null())),
1567                            Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1568                            _ => start.clone() + lit(1),
1569                        })
1570                    }),
1571                    3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| {
1572                        Ok(match (start.clone(), length.clone()) {
1573                            (Expr::Literal(lv), _) | (_, Expr::Literal(lv)) if lv.is_null() => lit(lv),
1574                            (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n)))) if n < 0 => {
1575                                polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", args[2])
1576                            },
1577                            (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) if n > 0 => e.str().slice(lit(n - 1), length),
1578                            (Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))), _) => {
1579                                e.str().slice(lit(0), (length + lit(n - 1)).clip_min(lit(0)))
1580                            },
1581                            (Expr::Literal(_), _) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]),
1582                            (_, Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(_)))) => {
1583                                polars_bail!(SQLSyntax: "invalid 'length' for SUBSTR ({})", args[1])
1584                            },
1585                            _ => {
1586                                let adjusted_start = start - lit(1);
1587                                when(adjusted_start.clone().lt(lit(0)))
1588                                    .then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0))))
1589                                    .otherwise(e.str().slice(adjusted_start, length))
1590                            }
1591                        })
1592                    }),
1593                    _ => polars_bail!(SQLSyntax: "SUBSTR expects 2-3 arguments (found {})", args.len()),
1594                }
1595            },
1596            Upper => self.visit_unary(|e| e.str().to_uppercase()),
1597
1598            // ----
1599            // Aggregate functions
1600            // ----
1601            Avg => self.visit_unary(Expr::mean),
1602            Corr => self.visit_binary(polars_lazy::dsl::pearson_corr),
1603            Count => self.visit_count(),
1604            CovarPop => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 0)),
1605            CovarSamp => self.visit_binary(|a, b| polars_lazy::dsl::cov(a, b, 1)),
1606            First => self.visit_unary(Expr::first),
1607            Last => self.visit_unary(Expr::last),
1608            Max => self.visit_unary_with_opt_cumulative(Expr::max, Expr::cum_max),
1609            Median => self.visit_unary(Expr::median),
1610            QuantileCont | QuantileDisc => {
1611                let (fname, method) = if matches!(function_name, QuantileCont) {
1612                    ("QUANTILE_CONT", QuantileMethod::Linear)
1613                } else {
1614                    ("QUANTILE_DISC", QuantileMethod::Equiprobable)
1615                };
1616                let args = extract_args(function)?;
1617                match args.len() {
1618                    2 => self.try_visit_binary(|e, q| {
1619                        let value = match q {
1620                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Float(f))) => {
1621                                if (0.0..=1.0).contains(&f) {
1622                                    Expr::from(f)
1623                                } else {
1624                                    polars_bail!(SQLSyntax: "{} value must be between 0 and 1 ({})", fname, args[1])
1625                                }
1626                            },
1627                            Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) => {
1628                                if (0..=1).contains(&n) {
1629                                    Expr::from(n as f64)
1630                                } else {
1631                                    polars_bail!(SQLSyntax: "{} value must be between 0 and 1 ({})", fname, args[1])
1632                                }
1633                            },
1634                            _ => polars_bail!(SQLSyntax: "invalid value for {} ({})", fname, args[1])
1635                        };
1636                        Ok(e.quantile(value, method))
1637                    }),
1638                    _ => polars_bail!(SQLSyntax: "{} expects 2 arguments (found {})", fname, args.len()),
1639                }
1640            },
1641            Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min),
1642            StdDev => self.visit_unary(|e| e.std(1)),
1643            StringAgg => self.visit_string_agg(),
1644            Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum),
1645            Variance => self.visit_unary(|e| e.var(1)),
1646
1647            // ----
1648            // Array functions
1649            // ----
1650            ArrayAgg => self.visit_arr_agg(),
1651            ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s, true)),
1652            ArrayGet => {
1653                // note: SQL is 1-indexed, not 0-indexed
1654                self.visit_binary(|e, idx: Expr| {
1655                    let idx = adjust_one_indexed_param(idx, true);
1656                    e.list().get(idx, true)
1657                })
1658            },
1659            ArrayLength => self.visit_unary(|e| e.list().len()),
1660            ArrayMax => self.visit_unary(|e| e.list().max()),
1661            ArrayMean => self.visit_unary(|e| e.list().mean()),
1662            ArrayMin => self.visit_unary(|e| e.list().min()),
1663            ArrayReverse => self.visit_unary(|e| e.list().eval(element().reverse())),
1664            ArraySum => self.visit_unary(|e| e.list().sum()),
1665            ArrayToString => self.visit_arr_to_string(),
1666            ArrayUnique => self.visit_unary(|e| e.list().eval(element().unique_stable())),
1667            Explode => self.visit_unary(|e| {
1668                e.explode(ExplodeOptions {
1669                    empty_as_null: true,
1670                    keep_nulls: true,
1671                })
1672            }),
1673
1674            // ----
1675            // Column selection
1676            // ----
1677            Columns => {
1678                let active_schema = self.active_schema;
1679                self.try_visit_unary(|e: Expr| match e {
1680                    Expr::Literal(lv) if lv.extract_str().is_some() => {
1681                        let pat = lv.extract_str().unwrap();
1682                        if pat == "*" {
1683                            polars_bail!(
1684                                SQLSyntax: "COLUMNS('*') is not a valid regex; \
1685                                did you mean COLUMNS(*)?"
1686                            )
1687                        };
1688                        let pat = match pat {
1689                            _ if pat.starts_with('^') && pat.ends_with('$') => pat.to_string(),
1690                            _ if pat.starts_with('^') => format!("{pat}.*$"),
1691                            _ if pat.ends_with('$') => format!("^.*{pat}"),
1692                            _ => format!("^.*{pat}.*$"),
1693                        };
1694                        if let Some(active_schema) = &active_schema {
1695                            let rx = polars_utils::regex_cache::compile_regex(&pat).unwrap();
1696                            let col_names = active_schema
1697                                .iter_names()
1698                                .filter(|name| rx.is_match(name))
1699                                .cloned()
1700                                .collect::<Vec<_>>();
1701
1702                            Ok(if col_names.len() == 1 {
1703                                col(col_names.into_iter().next().unwrap())
1704                            } else {
1705                                cols(col_names).as_expr()
1706                            })
1707                        } else {
1708                            Ok(col(pat.as_str()))
1709                        }
1710                    },
1711                    Expr::Selector(s) => Ok(s.as_expr()),
1712                    _ => polars_bail!(SQLSyntax: "COLUMNS expects a regex; found {:?}", e),
1713                })
1714            },
1715
1716            // ----
1717            // Window functions
1718            // ----
1719            FirstValue => self.visit_unary(Expr::first),
1720            LastValue => {
1721                // With the default window frame (ROWS UNBOUNDED PRECEDING TO CURRENT ROW),
1722                // LAST_VALUE returns the last value from the start of the partition up
1723                // to the current row - which is simply the current row's value.
1724                let args = extract_args(function)?;
1725                match args.as_slice() {
1726                    [FunctionArgExpr::Expr(sql_expr)] => {
1727                        parse_sql_expr(sql_expr, self.ctx, self.active_schema)
1728                    },
1729                    _ => polars_bail!(
1730                        SQLSyntax: "LAST_VALUE expects exactly 1 argument (found {})",
1731                        args.len()
1732                    ),
1733                }
1734            },
1735            Lag => self.visit_window_offset_function(1),
1736            Lead => self.visit_window_offset_function(-1),
1737            #[cfg(feature = "rank")]
1738            Rank | DenseRank => {
1739                let (func_name, rank_method) = match function_name {
1740                    Rank => ("RANK", RankMethod::Min),
1741                    DenseRank => ("DENSE_RANK", RankMethod::Dense),
1742                    _ => unreachable!(),
1743                };
1744                let args = extract_args(function)?;
1745                if !args.is_empty() {
1746                    polars_bail!(SQLSyntax: "{} expects 0 arguments (found {})", func_name, args.len());
1747                }
1748                let window_spec = match &self.func.over {
1749                    Some(WindowType::WindowSpec(spec)) if !spec.order_by.is_empty() => spec,
1750                    _ => {
1751                        polars_bail!(SQLSyntax: "{} requires an OVER clause with ORDER BY", func_name)
1752                    },
1753                };
1754                let (order_exprs, all_desc) =
1755                    self.parse_order_by_in_window(&window_spec.order_by)?;
1756                let rank_expr = if order_exprs.len() == 1 {
1757                    order_exprs[0].clone().rank(
1758                        RankOptions {
1759                            method: rank_method,
1760                            descending: all_desc,
1761                        },
1762                        None,
1763                    )
1764                } else {
1765                    as_struct(order_exprs).rank(
1766                        RankOptions {
1767                            method: rank_method,
1768                            descending: all_desc,
1769                        },
1770                        None,
1771                    )
1772                };
1773                self.apply_window_spec(rank_expr, &self.func.over)
1774            },
1775            RowNumber => {
1776                let args = extract_args(function)?;
1777                if !args.is_empty() {
1778                    polars_bail!(SQLSyntax: "ROW_NUMBER expects 0 arguments (found {})", args.len());
1779                }
1780                // note: SQL is 1-indexed
1781                let row_num_expr = int_range(lit(0i64), len(), 1, DataType::UInt32) + lit(1u32);
1782                self.apply_window_spec(row_num_expr, &self.func.over)
1783            },
1784
1785            // ----
1786            // User-defined
1787            // ----
1788            Udf(func_name) => self.visit_udf(&func_name),
1789        }
1790    }
1791
1792    fn visit_window_offset_function(&mut self, offset_multiplier: i64) -> PolarsResult<Expr> {
1793        // LAG/LEAD require an OVER clause
1794        if self.func.over.is_none() {
1795            polars_bail!(SQLSyntax: "{} requires an OVER clause", self.func.name);
1796        }
1797
1798        // LAG/LEAD require ORDER BY in the OVER clause
1799        let window_type = self.func.over.as_ref().unwrap();
1800        let window_spec = self.resolve_window_spec(window_type)?;
1801        if window_spec.order_by.is_empty() {
1802            polars_bail!(SQLSyntax: "{} requires an ORDER BY in the OVER clause", self.func.name);
1803        }
1804
1805        let args = extract_args(self.func)?;
1806
1807        match args.as_slice() {
1808            [FunctionArgExpr::Expr(sql_expr)] => {
1809                let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1810                Ok(expr.shift(offset_multiplier.into()))
1811            },
1812            [FunctionArgExpr::Expr(sql_expr), FunctionArgExpr::Expr(offset_expr)] => {
1813                let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
1814                let offset = parse_sql_expr(offset_expr, self.ctx, self.active_schema)?;
1815                if let Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) = offset {
1816                    if n <= 0 {
1817                        polars_bail!(SQLSyntax: "offset must be positive (found {})", n)
1818                    }
1819                    Ok(expr.shift((offset_multiplier * n as i64).into()))
1820                } else {
1821                    polars_bail!(SQLSyntax: "offset must be an integer (found {:?})", offset)
1822                }
1823            },
1824            _ => polars_bail!(SQLSyntax: "{} expects 1 or 2 arguments (found {})", self.func.name, args.len()),
1825        }.and_then(|e| self.apply_window_spec(e, &self.func.over))
1826    }
1827
1828    fn visit_udf(&mut self, func_name: &str) -> PolarsResult<Expr> {
1829        let args = extract_args(self.func)?
1830            .into_iter()
1831            .map(|arg| {
1832                if let FunctionArgExpr::Expr(e) = arg {
1833                    parse_sql_expr(e, self.ctx, self.active_schema)
1834                } else {
1835                    polars_bail!(SQLInterface: "only expressions are supported in UDFs")
1836                }
1837            })
1838            .collect::<PolarsResult<Vec<_>>>()?;
1839
1840        Ok(self
1841            .ctx
1842            .function_registry
1843            .get_udf(func_name)?
1844            .ok_or_else(|| polars_err!(SQLInterface: "UDF {} not found", func_name))?
1845            .call(args))
1846    }
1847
1848    /// Validate window frame specifications.
1849    ///
1850    /// Polars only supports ROWS frame semantics, and does
1851    /// not currently support customising the window.
1852    ///
1853    /// **Supported Frame Spec**
1854    /// - `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`
1855    ///
1856    /// **Unsupported Frame Spec**
1857    /// - `RANGE ...` (peer group semantics not implemented)
1858    /// - `GROUPS ...` (peer group semantics not implemented)
1859    /// - `ROWS` with other bounds (e.g., `<n> PRECEDING`, `FOLLOWING`, etc)
1860    fn validate_window_frame(&self, window_frame: &Option<WindowFrame>) -> PolarsResult<()> {
1861        if let Some(frame) = window_frame {
1862            match frame.units {
1863                WindowFrameUnits::Range => {
1864                    polars_bail!(
1865                        SQLInterface:
1866                        "RANGE-based window frames are not supported"
1867                    );
1868                },
1869                WindowFrameUnits::Groups => {
1870                    polars_bail!(
1871                        SQLInterface:
1872                        "GROUPS-based window frames are not supported"
1873                    );
1874                },
1875                WindowFrameUnits::Rows => {
1876                    if !matches!(
1877                        (&frame.start_bound, &frame.end_bound),
1878                        (
1879                            WindowFrameBound::Preceding(None),         // UNBOUNDED PRECEDING
1880                            None | Some(WindowFrameBound::CurrentRow)  // CURRENT ROW
1881                        )
1882                    ) {
1883                        polars_bail!(
1884                            SQLInterface:
1885                            "only 'ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW' is currently supported; found 'ROWS BETWEEN {} AND {}'",
1886                            frame.start_bound,
1887                            frame.end_bound.as_ref().map_or("CURRENT ROW", |b| {
1888                                match b {
1889                                    WindowFrameBound::CurrentRow => "CURRENT ROW",
1890                                    WindowFrameBound::Preceding(_) => "N PRECEDING",
1891                                    WindowFrameBound::Following(_) => "N FOLLOWING",
1892                                }
1893                            })
1894                        );
1895                    }
1896                },
1897            }
1898        }
1899        Ok(())
1900    }
1901
1902    /// Window specs that map to cumulative functions.
1903    ///
1904    /// Converts SQL window functions with ORDER BY to compatible cumulative ops:
1905    /// - `SUM(a) OVER (ORDER BY b)` → `a.cum_sum().over(order_by=b)`
1906    /// - `MAX(a) OVER (ORDER BY b)` → `a.cum_max().over(order_by=b)`
1907    /// - `MIN(a) OVER (ORDER BY b)` → `a.cum_min().over(order_by=b)`
1908    ///
1909    /// ROWS vs RANGE Semantics (show default behaviour if no frame spec):
1910    ///
1911    /// **Polars (ROWS)**
1912    /// Each row gets its own cumulative value row-by-row.
1913    /// ```text
1914    /// Data: [(A,X,10), (A,X,15), (A,Y,20)]
1915    /// Query: SUM(value) OVER (ORDER BY category, subcategory)
1916    /// Result: [10, 25, 45]  ← row-by-row cumulative
1917    /// ```
1918    ///
1919    /// **SQL (RANGE)**
1920    /// Rows with identical ORDER BY values (peers) get the same result.
1921    /// ```text
1922    /// Same data, query with RANGE (eg: using a relational DB):
1923    /// Result: [25, 25, 45]  ← both (A,X) rows get 25
1924    /// ```
1925    fn apply_cumulative_window(
1926        &mut self,
1927        f: impl Fn(Expr) -> Expr,
1928        cumulative_fn: impl Fn(Expr, bool) -> Expr,
1929        WindowSpec {
1930            partition_by,
1931            order_by,
1932            window_frame,
1933            ..
1934        }: &WindowSpec,
1935    ) -> PolarsResult<Expr> {
1936        self.validate_window_frame(window_frame)?;
1937
1938        if !order_by.is_empty() {
1939            // Extract ORDER BY exprs and sort direction
1940            let (order_by_exprs, all_desc) = self.parse_order_by_in_window(order_by)?;
1941
1942            // Get the base expr/column
1943            let args = extract_args(self.func)?;
1944            let base_expr = match args.as_slice() {
1945                [FunctionArgExpr::Expr(sql_expr)] => {
1946                    parse_sql_expr(sql_expr, self.ctx, self.active_schema)?
1947                },
1948                _ => return self.not_supported_error(),
1949            };
1950            let partition_by_exprs = if partition_by.is_empty() {
1951                None
1952            } else {
1953                Some(
1954                    partition_by
1955                        .iter()
1956                        .map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
1957                        .collect::<PolarsResult<Vec<_>>>()?,
1958                )
1959            };
1960
1961            // Apply cumulative function; the forward-fill ensures we match SQL semantics
1962            let cumulative_expr = cumulative_fn(base_expr, false)
1963                .fill_null_with_strategy(FillNullStrategy::Forward(None));
1964            let sort_opts = SortOptions::default().with_order_descending(all_desc);
1965            cumulative_expr.over_with_options(
1966                partition_by_exprs,
1967                Some((order_by_exprs, sort_opts)),
1968                Default::default(),
1969            )
1970        } else {
1971            self.visit_unary(f)
1972        }
1973    }
1974
1975    /// Parse an argument of the function currently being visited.
1976    ///
1977    /// Behaves like [`parse_sql_expr`] but also accounts for any
1978    /// active `FILTER (WHERE …)` clause from the surrounding call.
1979    fn parse_sql_arg(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
1980        let parsed = parse_sql_expr(expr, self.ctx, self.active_schema)?;
1981        Ok(match &self.filter {
1982            Some(pred) => parsed.filter(pred.clone()),
1983            None => parsed,
1984        })
1985    }
1986
1987    fn visit_unary(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult<Expr> {
1988        self.try_visit_unary(|e| Ok(f(e)))
1989    }
1990
1991    fn try_visit_unary(&mut self, f: impl Fn(Expr) -> PolarsResult<Expr>) -> PolarsResult<Expr> {
1992        let args = extract_args(self.func)?;
1993        match args.as_slice() {
1994            [FunctionArgExpr::Expr(sql_expr)] => f(self.parse_sql_arg(sql_expr)?),
1995            [FunctionArgExpr::Wildcard] => {
1996                f(self.parse_sql_arg(&SQLExpr::Wildcard(AttachedToken::empty()))?)
1997            },
1998            _ => self.not_supported_error(),
1999        }
2000        .and_then(|e| self.apply_window_spec(e, &self.func.over))
2001    }
2002
2003    /// Resolve a WindowType to a concrete WindowSpec (handles named window references)
2004    fn resolve_window_spec(&self, window_type: &WindowType) -> PolarsResult<WindowSpec> {
2005        match window_type {
2006            WindowType::WindowSpec(spec) => Ok(spec.clone()),
2007            WindowType::NamedWindow(name) => self
2008                .ctx
2009                .named_windows
2010                .get(&name.value)
2011                .cloned()
2012                .ok_or_else(|| {
2013                    polars_err!(
2014                        SQLInterface:
2015                        "named window '{}' was not found",
2016                        name.value
2017                    )
2018                }),
2019        }
2020    }
2021
2022    /// Some functions have cumulative equivalents that can be applied to window specs
2023    /// e.g. SUM(a) OVER (ORDER BY b DESC) -> CUMSUM(a, false)
2024    fn visit_unary_with_opt_cumulative(
2025        &mut self,
2026        f: impl Fn(Expr) -> Expr,
2027        cumulative_fn: impl Fn(Expr, bool) -> Expr,
2028    ) -> PolarsResult<Expr> {
2029        match self.func.over.as_ref() {
2030            Some(window_type) => {
2031                let spec = self.resolve_window_spec(window_type)?;
2032                self.apply_cumulative_window(f, cumulative_fn, &spec)
2033            },
2034            None => self.visit_unary(f),
2035        }
2036    }
2037
2038    fn visit_binary<Arg: FromSQLExpr>(
2039        &mut self,
2040        f: impl Fn(Expr, Arg) -> Expr,
2041    ) -> PolarsResult<Expr> {
2042        self.try_visit_binary(|e, a| Ok(f(e, a)))
2043    }
2044
2045    fn try_visit_binary<Arg: FromSQLExpr>(
2046        &mut self,
2047        f: impl Fn(Expr, Arg) -> PolarsResult<Expr>,
2048    ) -> PolarsResult<Expr> {
2049        let args = extract_args(self.func)?;
2050        match args.as_slice() {
2051            [
2052                FunctionArgExpr::Expr(sql_expr1),
2053                FunctionArgExpr::Expr(sql_expr2),
2054            ] => {
2055                let expr1 = self.parse_sql_arg(sql_expr1)?;
2056                let expr2 = Arg::from_sql_arg(sql_expr2, self)?;
2057                f(expr1, expr2)
2058            },
2059            _ => self.not_supported_error(),
2060        }
2061    }
2062
2063    fn visit_variadic(&mut self, f: impl Fn(&[Expr]) -> Expr) -> PolarsResult<Expr> {
2064        self.try_visit_variadic(|e| Ok(f(e)))
2065    }
2066
2067    fn try_visit_variadic(
2068        &mut self,
2069        f: impl Fn(&[Expr]) -> PolarsResult<Expr>,
2070    ) -> PolarsResult<Expr> {
2071        let args = extract_args(self.func)?;
2072        let mut expr_args = vec![];
2073        for arg in args {
2074            if let FunctionArgExpr::Expr(sql_expr) = arg {
2075                expr_args.push(self.parse_sql_arg(sql_expr)?);
2076            } else {
2077                return self.not_supported_error();
2078            };
2079        }
2080        f(&expr_args)
2081    }
2082
2083    fn try_visit_ternary<Arg: FromSQLExpr>(
2084        &mut self,
2085        f: impl Fn(Expr, Arg, Arg) -> PolarsResult<Expr>,
2086    ) -> PolarsResult<Expr> {
2087        let args = extract_args(self.func)?;
2088        match args.as_slice() {
2089            [
2090                FunctionArgExpr::Expr(sql_expr1),
2091                FunctionArgExpr::Expr(sql_expr2),
2092                FunctionArgExpr::Expr(sql_expr3),
2093            ] => {
2094                let expr1 = self.parse_sql_arg(sql_expr1)?;
2095                let expr2 = Arg::from_sql_arg(sql_expr2, self)?;
2096                let expr3 = Arg::from_sql_arg(sql_expr3, self)?;
2097                f(expr1, expr2, expr3)
2098            },
2099            _ => self.not_supported_error(),
2100        }
2101    }
2102
2103    fn visit_nullary(&self, f: impl Fn() -> Expr) -> PolarsResult<Expr> {
2104        let args = extract_args(self.func)?;
2105        if !args.is_empty() {
2106            return self.not_supported_error();
2107        }
2108        Ok(f())
2109    }
2110
2111    /// Apply in-arg "aggregate modifiers" inside an aggregate's argument
2112    /// list, eg: `ARRAY_AGG(DISTINCT x ORDER BY y LIMIT 5)`. Composes
2113    /// with the visitor-level `FILTER (WHERE …)` clause.
2114    fn apply_aggregate_clauses(
2115        &mut self,
2116        mut base: Expr,
2117        is_distinct: bool,
2118        clauses: &[FunctionArgumentClause],
2119        base_sql_expr: &SQLExpr,
2120        func_name: &str,
2121    ) -> PolarsResult<Expr> {
2122        let mut order_by_clause = None;
2123        let mut limit_clause = None;
2124        for clause in clauses {
2125            match clause {
2126                FunctionArgumentClause::OrderBy(order_exprs) => {
2127                    order_by_clause = Some(order_exprs.as_slice());
2128                },
2129                FunctionArgumentClause::Limit(limit_expr) => {
2130                    limit_clause = Some(limit_expr);
2131                },
2132                _ => {},
2133            }
2134        }
2135        if is_distinct {
2136            // DISTINCT: apply unique first, then sort the deduplicated result.
2137            base = base.unique_stable();
2138            if let Some(order_by) = order_by_clause {
2139                base = self.apply_order_by_to_distinct_array(base, order_by, base_sql_expr)?;
2140            }
2141        } else if let Some(order_by) = order_by_clause {
2142            base = self.apply_order_by(base, order_by)?;
2143        }
2144        if let Some(limit_expr) = limit_clause {
2145            let limit = parse_sql_expr(limit_expr, self.ctx, self.active_schema)?;
2146            match limit {
2147                Expr::Literal(LiteralValue::Dyn(DynLiteralValue::Int(n))) if n >= 0 => {
2148                    base = base.head(Some(n as usize))
2149                },
2150                _ => {
2151                    polars_bail!(SQLSyntax: "LIMIT in {} must be a positive integer", func_name)
2152                },
2153            };
2154        }
2155        Ok(base)
2156    }
2157
2158    fn visit_arr_agg(&mut self) -> PolarsResult<Expr> {
2159        let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?;
2160        match args.as_slice() {
2161            [FunctionArgExpr::Expr(sql_expr)] => {
2162                let base = self.parse_sql_arg(sql_expr)?;
2163                let base = self.apply_aggregate_clauses(
2164                    base,
2165                    is_distinct,
2166                    &clauses,
2167                    sql_expr,
2168                    "ARRAY_AGG",
2169                )?;
2170                Ok(base.implode(true))
2171            },
2172            _ => {
2173                polars_bail!(SQLSyntax: "ARRAY_AGG must have exactly one argument; found {}", args.len())
2174            },
2175        }
2176    }
2177
2178    fn visit_string_agg(&mut self) -> PolarsResult<Expr> {
2179        let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?;
2180        let (sql_expr, separator) = match args.as_slice() {
2181            [FunctionArgExpr::Expr(sql_expr)] => (sql_expr, lit(",")),
2182            [
2183                FunctionArgExpr::Expr(sql_expr),
2184                FunctionArgExpr::Expr(sep_sql_expr),
2185            ] => (
2186                sql_expr,
2187                parse_sql_expr(sep_sql_expr, self.ctx, self.active_schema)?,
2188            ),
2189            _ => polars_bail!(
2190                SQLSyntax: "STRING_AGG expects 1-2 arguments (found {})",
2191                args.len()
2192            ),
2193        };
2194        let base = self.parse_sql_arg(sql_expr)?;
2195        let base =
2196            self.apply_aggregate_clauses(base, is_distinct, &clauses, sql_expr, "STRING_AGG")?;
2197        Ok(base
2198            .cast(DataType::String)
2199            .implode(true)
2200            .list()
2201            .join(separator, true))
2202    }
2203
2204    fn visit_arr_to_string(&mut self) -> PolarsResult<Expr> {
2205        let args = extract_args(self.func)?;
2206        match args.len() {
2207            2 => self.try_visit_binary(|e, sep| {
2208                Ok(e.cast(DataType::List(Box::from(DataType::String)))
2209                    .list()
2210                    .join(sep, true))
2211            }),
2212            #[cfg(feature = "list_eval")]
2213            3 => self.try_visit_ternary(|e, sep, null_value| match null_value {
2214                Expr::Literal(lv) if lv.extract_str().is_some() => {
2215                    Ok(if lv.extract_str().unwrap().is_empty() {
2216                        e.cast(DataType::List(Box::from(DataType::String)))
2217                            .list()
2218                            .join(sep, true)
2219                    } else {
2220                        e.cast(DataType::List(Box::from(DataType::String)))
2221                            .list()
2222                            .eval(element().fill_null(lit(lv.extract_str().unwrap())))
2223                            .list()
2224                            .join(sep, false)
2225                    })
2226                },
2227                _ => {
2228                    polars_bail!(SQLSyntax: "invalid null value for ARRAY_TO_STRING ({})", args[2])
2229                },
2230            }),
2231            _ => {
2232                polars_bail!(SQLSyntax: "ARRAY_TO_STRING expects 2-3 arguments (found {})", args.len())
2233            },
2234        }
2235    }
2236
2237    fn visit_count(&mut self) -> PolarsResult<Expr> {
2238        let (args, is_distinct) = extract_args_distinct(self.func)?;
2239
2240        // Window function with an ORDER BY clause?
2241        let has_order_by = match &self.func.over {
2242            Some(WindowType::WindowSpec(spec)) => !spec.order_by.is_empty(),
2243            _ => false,
2244        };
2245        if has_order_by && !is_distinct {
2246            if let Some(WindowType::WindowSpec(spec)) = &self.func.over {
2247                self.validate_window_frame(&spec.window_frame)?;
2248
2249                let is_count_star = match args.as_slice() {
2250                    [FunctionArgExpr::Wildcard] | [] => true,
2251                    [FunctionArgExpr::Expr(e)] => is_non_null_literal(e),
2252                    _ => false,
2253                };
2254                match args.as_slice() {
2255                    _ if is_count_star => {
2256                        // COUNT(*) / COUNT(1) with ORDER BY -> map to `int_range`
2257                        let (order_by_exprs, all_desc) =
2258                            self.parse_order_by_in_window(&spec.order_by)?;
2259                        let partition_by_exprs = if spec.partition_by.is_empty() {
2260                            None
2261                        } else {
2262                            Some(
2263                                spec.partition_by
2264                                    .iter()
2265                                    .map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
2266                                    .collect::<PolarsResult<Vec<_>>>()?,
2267                            )
2268                        };
2269                        let sort_opts = SortOptions::default().with_order_descending(all_desc);
2270                        let row_number = int_range(lit(0), len(), 1, DataType::Int64).add(lit(1)); // SQL is 1-indexed
2271
2272                        return row_number.over_with_options(
2273                            partition_by_exprs,
2274                            Some((order_by_exprs, sort_opts)),
2275                            Default::default(),
2276                        );
2277                    },
2278                    [FunctionArgExpr::Expr(_)] => {
2279                        // COUNT(column) with ORDER BY -> use cum_count
2280                        return self.visit_unary_with_opt_cumulative(
2281                            |e| e.count(),
2282                            |e, reverse| e.cum_count(reverse),
2283                        );
2284                    },
2285                    _ => {},
2286                }
2287            }
2288        }
2289        // COUNT(*), COUNT(1) with FILTER: count rows where the predicate is true.
2290        let count_star = || match &self.filter {
2291            Some(pred) => pred.clone().sum(),
2292            None => len(),
2293        };
2294        let count_expr = match (is_distinct, args.as_slice()) {
2295            // COUNT(*), COUNT()
2296            (false, [FunctionArgExpr::Wildcard] | []) => count_star(),
2297            // COUNT(<non-null literal>) is equivalent to COUNT(*)
2298            (false, [FunctionArgExpr::Expr(sql_expr)]) if is_non_null_literal(sql_expr) => {
2299                count_star()
2300            },
2301            // COUNT(col)
2302            (false, [FunctionArgExpr::Expr(sql_expr)]) => {
2303                let expr = self.parse_sql_arg(sql_expr)?;
2304                expr.count()
2305            },
2306            // COUNT(DISTINCT col)
2307            (true, [FunctionArgExpr::Expr(sql_expr)]) => {
2308                let expr = self.parse_sql_arg(sql_expr)?;
2309                expr.clone().n_unique().sub(expr.null_count().gt(lit(0)))
2310            },
2311            _ => self.not_supported_error()?,
2312        };
2313        self.apply_window_spec(count_expr, &self.func.over)
2314    }
2315
2316    fn apply_order_by(&mut self, expr: Expr, order_by: &[OrderByExpr]) -> PolarsResult<Expr> {
2317        let mut by = Vec::with_capacity(order_by.len());
2318        let mut descending = Vec::with_capacity(order_by.len());
2319        let mut nulls_last = Vec::with_capacity(order_by.len());
2320
2321        for ob in order_by {
2322            // Note: if not specified 'NULLS FIRST' is default for DESC, 'NULLS LAST' otherwise
2323            // https://www.postgresql.org/docs/current/queries-order.html. Also: ORDER BY exprs
2324            // share their length with the (possibly filtered) base, so they have to go through
2325            // `parse_sql_arg` to apply any active FILTER.
2326            let desc_order = !ob.options.asc.unwrap_or(true);
2327            by.push(self.parse_sql_arg(&ob.expr)?);
2328            nulls_last.push(!ob.options.nulls_first.unwrap_or(desc_order));
2329            descending.push(desc_order);
2330        }
2331        Ok(expr.sort_by(
2332            by,
2333            SortMultipleOptions::default()
2334                .with_order_descending_multi(descending)
2335                .with_nulls_last_multi(nulls_last),
2336        ))
2337    }
2338
2339    fn apply_order_by_to_distinct_array(
2340        &mut self,
2341        expr: Expr,
2342        order_by: &[OrderByExpr],
2343        base_sql_expr: &SQLExpr,
2344    ) -> PolarsResult<Expr> {
2345        // If ORDER BY references the base expression, use .sort() directly
2346        if order_by.len() == 1 && order_by[0].expr == *base_sql_expr {
2347            let desc_order = !order_by[0].options.asc.unwrap_or(true);
2348            let nulls_last = !order_by[0].options.nulls_first.unwrap_or(desc_order);
2349            return Ok(expr.sort(
2350                SortOptions::default()
2351                    .with_order_descending(desc_order)
2352                    .with_nulls_last(nulls_last),
2353            ));
2354        }
2355        // Otherwise, fall back to `sort_by` (may need to handle further edge-cases later)
2356        self.apply_order_by(expr, order_by)
2357    }
2358
2359    /// Parse ORDER BY (in OVER clause), validating uniform direction.
2360    fn parse_order_by_in_window(
2361        &mut self,
2362        order_by: &[OrderByExpr],
2363    ) -> PolarsResult<(Vec<Expr>, bool)> {
2364        if order_by.is_empty() {
2365            return Ok((Vec::new(), false));
2366        }
2367        // Parse expressions and validate uniform direction
2368        let all_ascending = order_by[0].options.asc.unwrap_or(true);
2369        let mut exprs = Vec::with_capacity(order_by.len());
2370        for o in order_by {
2371            if all_ascending != o.options.asc.unwrap_or(true) {
2372                // TODO: mixed sort directions are not currently supported; we
2373                //  need to enhance `over_with_options` to take SortMultipleOptions
2374                polars_bail!(
2375                    SQLSyntax:
2376                    "OVER does not (yet) support mixed asc/desc directions for ORDER BY"
2377                )
2378            }
2379            let expr = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?;
2380            exprs.push(expr);
2381        }
2382        Ok((exprs, !all_ascending))
2383    }
2384
2385    fn apply_window_spec(
2386        &mut self,
2387        expr: Expr,
2388        window_type: &Option<WindowType>,
2389    ) -> PolarsResult<Expr> {
2390        let Some(window_type) = window_type else {
2391            return Ok(expr);
2392        };
2393        let window_spec = self.resolve_window_spec(window_type)?;
2394        self.validate_window_frame(&window_spec.window_frame)?;
2395
2396        let partition_by = if window_spec.partition_by.is_empty() {
2397            None
2398        } else {
2399            Some(
2400                window_spec
2401                    .partition_by
2402                    .iter()
2403                    .map(|p| parse_sql_expr(p, self.ctx, self.active_schema))
2404                    .collect::<PolarsResult<Vec<_>>>()?,
2405            )
2406        };
2407        let order_by = if window_spec.order_by.is_empty() {
2408            None
2409        } else {
2410            let (order_exprs, all_desc) = self.parse_order_by_in_window(&window_spec.order_by)?;
2411            let sort_opts = SortOptions::default().with_order_descending(all_desc);
2412            Some((order_exprs, sort_opts))
2413        };
2414
2415        // Apply window spec
2416        Ok(match (partition_by, order_by) {
2417            (None, None) => expr,
2418            (Some(part), None) => expr.over(part)?,
2419            (part, Some(order)) => expr.over_with_options(part, Some(order), Default::default())?,
2420        })
2421    }
2422
2423    fn not_supported_error(&self) -> PolarsResult<Expr> {
2424        polars_bail!(
2425            SQLInterface:
2426            "no function matches the given name and arguments: `{}`",
2427            self.func.to_string()
2428        );
2429    }
2430}
2431
2432/// Returns true if the SQL expression is a non-null literal value (e.g. `1`, `'hello'`, `TRUE`).
2433fn is_non_null_literal(expr: &SQLExpr) -> bool {
2434    matches!(
2435        expr,
2436        SQLExpr::Value(ValueWithSpan {
2437            value: v,
2438            ..
2439        }) if !matches!(v, SQLValue::Null)
2440    )
2441}
2442
2443fn extract_args(func: &SQLFunction) -> PolarsResult<Vec<&FunctionArgExpr>> {
2444    let (args, _, _) = _extract_func_args(func, false, false)?;
2445    Ok(args)
2446}
2447
2448fn extract_args_distinct(func: &SQLFunction) -> PolarsResult<(Vec<&FunctionArgExpr>, bool)> {
2449    let (args, is_distinct, _) = _extract_func_args(func, true, false)?;
2450    Ok((args, is_distinct))
2451}
2452
2453fn extract_args_and_clauses(
2454    func: &SQLFunction,
2455) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
2456    _extract_func_args(func, true, true)
2457}
2458
2459fn _extract_func_args(
2460    func: &SQLFunction,
2461    get_distinct: bool,
2462    get_clauses: bool,
2463) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec<FunctionArgumentClause>)> {
2464    match &func.args {
2465        FunctionArguments::List(FunctionArgumentList {
2466            args,
2467            duplicate_treatment,
2468            clauses,
2469        }) => {
2470            let is_distinct = matches!(duplicate_treatment, Some(DuplicateTreatment::Distinct));
2471            if !(get_clauses || get_distinct) && is_distinct {
2472                polars_bail!(SQLSyntax: "unexpected use of DISTINCT found in '{}'", func.name)
2473            } else if !get_clauses && !clauses.is_empty() {
2474                polars_bail!(SQLSyntax: "unexpected clause found in '{}' ({})", func.name, clauses[0])
2475            } else {
2476                let unpacked_args = args
2477                    .iter()
2478                    .map(|arg| match arg {
2479                        FunctionArg::Named { arg, .. } => arg,
2480                        FunctionArg::ExprNamed { arg, .. } => arg,
2481                        FunctionArg::Unnamed(arg) => arg,
2482                    })
2483                    .collect();
2484                Ok((unpacked_args, is_distinct, clauses.clone()))
2485            }
2486        },
2487        FunctionArguments::Subquery { .. } => {
2488            Err(polars_err!(SQLInterface: "subquery not expected in {}", func.name))
2489        },
2490        FunctionArguments::None => Ok((vec![], false, vec![])),
2491    }
2492}
2493
2494pub(crate) trait FromSQLExpr {
2495    fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
2496    where
2497        Self: Sized;
2498
2499    /// Parse SQL expression as an argument of the function being visited, taking
2500    /// the surrounding visitor into account. Allows active `FILTER (WHERE …)`
2501    /// clauses to be applied to all args without each call knowing about FILTER.
2502    fn from_sql_arg(expr: &SQLExpr, visitor: &mut SQLFunctionVisitor<'_>) -> PolarsResult<Self>
2503    where
2504        Self: Sized,
2505    {
2506        Self::from_sql_expr(expr, visitor.ctx)
2507    }
2508}
2509
2510impl FromSQLExpr for f64 {
2511    fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
2512    where
2513        Self: Sized,
2514    {
2515        match expr {
2516            SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2517                SQLValue::Number(s, _) => s
2518                    .parse()
2519                    .map_err(|_| polars_err!(SQLInterface: "cannot parse literal {:?}", s)),
2520                _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2521            },
2522            _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2523        }
2524    }
2525}
2526
2527impl FromSQLExpr for bool {
2528    fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
2529    where
2530        Self: Sized,
2531    {
2532        match expr {
2533            SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2534                SQLValue::Boolean(v) => Ok(*v),
2535                _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", v),
2536            },
2537            _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", expr),
2538        }
2539    }
2540}
2541
2542impl FromSQLExpr for String {
2543    fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2544    where
2545        Self: Sized,
2546    {
2547        match expr {
2548            SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2549                SQLValue::SingleQuotedString(s) => Ok(s.clone()),
2550                _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2551            },
2552            _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2553        }
2554    }
2555}
2556
2557impl FromSQLExpr for StrptimeOptions {
2558    fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
2559    where
2560        Self: Sized,
2561    {
2562        match expr {
2563            SQLExpr::Value(ValueWithSpan { value: v, .. }) => match v {
2564                SQLValue::SingleQuotedString(s) => Ok(StrptimeOptions {
2565                    format: Some(PlSmallStr::from_str(s)),
2566                    ..StrptimeOptions::default()
2567                }),
2568                _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v),
2569            },
2570            _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr),
2571        }
2572    }
2573}
2574
2575impl FromSQLExpr for Expr {
2576    fn from_sql_expr(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Self>
2577    where
2578        Self: Sized,
2579    {
2580        parse_sql_expr(expr, ctx, None)
2581    }
2582
2583    fn from_sql_arg(expr: &SQLExpr, visitor: &mut SQLFunctionVisitor<'_>) -> PolarsResult<Self> {
2584        visitor.parse_sql_arg(expr)
2585    }
2586}