polars_lazy/physical_plan/streaming/
convert_alp.rs

use polars_core::prelude::*;
use polars_pipe::pipeline::swap_join_order;
use polars_plan::prelude::*;

use super::checks::*;
use crate::physical_plan::streaming::tree::*;

// The index of the pipeline tree we are building at the moment.
// If we hit a node that cannot be streamed, we finish that pipeline tree
// and start a new one.
type CurrentIdx = usize;

// A frame in the stack of logical plans to process while inserting streaming nodes.
struct StackFrame {
    node: Node, // LogicalPlan node
    state: Branch,
    current_idx: CurrentIdx,
    insert_sink: bool,
}

impl StackFrame {
    fn root(node: Node) -> StackFrame {
        StackFrame {
            node,
            state: Branch::default(),
            current_idx: 0,
            insert_sink: false,
        }
    }

    fn new(node: Node, state: Branch, current_idx: CurrentIdx) -> StackFrame {
        StackFrame {
            node,
            state,
            current_idx,
            insert_sink: false,
        }
    }

    // Create a new streaming subtree below a non-streaming node.
    fn new_subtree(node: Node, current_idx: CurrentIdx) -> StackFrame {
        StackFrame {
            node,
            state: Branch::default(),
            current_idx,
            insert_sink: true,
        }
    }
}

fn process_non_streamable_node(
    current_idx: &mut CurrentIdx,
    state: &mut Branch,
    stack: &mut Vec<StackFrame>,
    scratch: &mut Vec<Node>,
    pipeline_trees: &mut Vec<Vec<Branch>>,
    lp: &IR,
) {
    lp.copy_inputs(scratch);
    while let Some(input) = scratch.pop() {
        if state.streamable {
            *current_idx += 1;
            // Start a completely new pipeline tree: even though this node
            // cannot be streamed, maybe a subsection of the plan below it can.
            pipeline_trees.push(vec![]);
        }
        stack.push(StackFrame::new_subtree(input, *current_idx));
    }
    state.streamable = false;
}
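
// Illustration of the splitting performed by `process_non_streamable_node`
// above (not exercised by this module): suppose the plan hits a
// non-streamable node with two streamable inputs,
//
//        *non-streamable
//        /\
//       /  \
//   (scan) (scan)
//
// Each input is pushed as a fresh subtree (`insert_sink: true`), and because
// the branch so far was streamable, `current_idx` is bumped per input so each
// subtree accumulates its branches into its own pipeline tree.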

fn insert_file_sink(mut root: Node, lp_arena: &mut Arena<IR>) -> Node {
    // The pipelines need a final sink; we insert that here.
    // This allows us to split at joins/unions and share a sink.
    if !matches!(lp_arena.get(root), IR::Sink { .. }) {
        root = lp_arena.add(IR::Sink {
            input: root,
            payload: SinkTypeIR::Memory,
        })
    }
    root
}
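
// A minimal, self-contained sketch (test-only; toy types standing in for the
// real `Arena<IR>` and `IR`) of the "wrap the root in a sink unless it already
// is one" pattern used above.
#[cfg(test)]
mod insert_sink_sketch {
    #[derive(Debug, PartialEq)]
    enum ToyIr {
        Scan,
        Sink { input: usize },
    }

    // A toy arena: nodes are plain indices into a Vec.
    fn insert_sink(root: usize, arena: &mut Vec<ToyIr>) -> usize {
        if matches!(arena[root], ToyIr::Sink { .. }) {
            root
        } else {
            arena.push(ToyIr::Sink { input: root });
            arena.len() - 1
        }
    }

    #[test]
    fn wrapping_is_idempotent() {
        let mut arena = vec![ToyIr::Scan];
        let sink = insert_sink(0, &mut arena);
        assert_eq!(arena[sink], ToyIr::Sink { input: 0 });
        // A second call leaves the plan unchanged.
        assert_eq!(insert_sink(sink, &mut arena), sink);
    }
}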

pub(crate) fn insert_streaming_nodes(
    root: Node,
    lp_arena: &mut Arena<IR>,
    expr_arena: &mut Arena<AExpr>,
    scratch: &mut Vec<Node>,
    fmt: bool,
    // Whether the full plan needs to be translated to streaming.
    allow_partial: bool,
    row_estimate: bool,
) -> PolarsResult<bool> {
    scratch.clear();

    // This is needed to determine which side of the joins should be traversed
    // first: we want to keep the smallest table in the build phase, as that
    // keeps the most data in memory.
    if row_estimate {
        set_estimated_row_counts(root, lp_arena, expr_arena, 0, scratch);
    }

    scratch.clear();

    // The pipelines always need to end in a SINK; we insert that here.
    // This allows us to split at joins/unions and share a sink.
    let root = insert_file_sink(root, lp_arena);

    // We use a bool flag in the stack to communicate when we need to insert a
    // file sink. This happens, for instance, when a streamable section sits
    // below a non-streamable part of the query:
    //
    //     ________*non-streamable part of query
    //   /\
    //     ________*streamable below this line so we must insert
    //    /\        a file sink here so the pipeline can be built
    //     /\

    let mut stack = Vec::with_capacity(16);

    stack.push(StackFrame::root(root));

    // A state holds a full pipeline until a pipeline breaker:
    //  1/\
    //   2/\
    //     3\
    //
    // So 1 and 2 are short pipelines and 3 goes all the way to the root,
    // but 3 can only run once 1 and 2 have finished and set the join as an
    // operator in 3. The states are filled with pipelines 1, 2, 3 in that order.
    //
    //     / \
    //  /\  3/\
    //  1 2    4\
    // In this case the order is 1, 2, 3, 4.
    // Every inner vec contains a branch/pipeline of a complete pipeline tree;
    // the outer vec contains whole pipeline trees.
    //
    // # Execution order
    // Trees can have arbitrary splits via joins and unions, and the branches
    // we accumulate are flattened into a single Vec<Branch>, which loses the
    // tree structure and with it the order in which the branches must be
    // executed. For this reason we keep track of an `execution_id` that is
    // incremented on every stack operation: it records the order in which the
    // stack/tree was traversed, and that determines the execution order of the
    // individual branches/pipelines (see the test-only sketch at the bottom of
    // this file).
    let mut pipeline_trees: Vec<Tree> = vec![vec![]];
    // Keep the counter global so that the order matches traversal order.
    let mut execution_id = 0;

    use IR::*;
    while let Some(StackFrame {
        node: mut root,
        mut state,
        mut current_idx,
        insert_sink,
    }) = stack.pop()
    {
        if insert_sink {
            root = insert_file_sink(root, lp_arena);
        }
        state.execution_id = execution_id;
        execution_id += 1;
        match lp_arena.get(root) {
            Filter { input, predicate } if is_elementwise_rec(predicate.node(), expr_arena) => {
                state.streamable = true;
                state.operators_sinks.push(PipelineNode::Operator(root));
                stack.push(StackFrame::new(*input, state, current_idx))
            },
            HStack { input, exprs, .. } if all_elementwise(exprs, expr_arena) => {
                state.streamable = true;
                state.operators_sinks.push(PipelineNode::Operator(root));
                stack.push(StackFrame::new(*input, state, current_idx))
            },
            Slice { input, offset, .. } if *offset >= 0 => {
                state.streamable = true;
                state.operators_sinks.push(PipelineNode::Sink(root));
                stack.push(StackFrame::new(*input, state, current_idx))
            },
            Sink { input, .. } => {
                state.streamable = true;
                state.operators_sinks.push(PipelineNode::Sink(root));
                stack.push(StackFrame::new(*input, state, current_idx))
            },
            Sort {
                input,
                by_column,
                slice,
                sort_options,
            } if is_streamable_sort(slice, sort_options) && all_column(by_column, expr_arena) => {
                state.streamable = true;
                state.operators_sinks.push(PipelineNode::Sink(root));
                stack.push(StackFrame::new(*input, state, current_idx))
            },
            Select { input, expr, .. } if all_elementwise(expr, expr_arena) => {
                state.streamable = true;
                state.operators_sinks.push(PipelineNode::Operator(root));
                stack.push(StackFrame::new(*input, state, current_idx))
            },
            SimpleProjection { input, .. } => {
                state.streamable = true;
                state.operators_sinks.push(PipelineNode::Operator(root));
                stack.push(StackFrame::new(*input, state, current_idx))
            },
            // Rechunks are ignored.
            MapFunction {
                input,
                function: FunctionIR::Rechunk,
            } => {
                state.streamable = true;
                stack.push(StackFrame::new(*input, state, current_idx))
            },
            // Streamable functions will be converted.
            lp @ MapFunction { input, function } => {
                if function.is_streamable() {
                    state.streamable = true;
                    state.operators_sinks.push(PipelineNode::Operator(root));
                    stack.push(StackFrame::new(*input, state, current_idx))
                } else {
                    process_non_streamable_node(
                        &mut current_idx,
                        &mut state,
                        &mut stack,
                        scratch,
                        &mut pipeline_trees,
                        lp,
                    )
                }
            },
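            // A scan is streamable if the scan type supports it and any
            // pre-slice has a non-negative offset (a negative offset would
            // slice from the end, which cannot be done while streaming).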
            Scan {
                scan_type,
                file_options,
                ..
            } if scan_type.streamable()
                && file_options
                    .pre_slice
                    .map(|slice| slice.0 >= 0)
                    .unwrap_or(true) =>
            {
                if state.streamable {
                    state.sources.push(root);
                    pipeline_trees[current_idx].push(state)
                }
            },
            DataFrameScan { .. } => {
                if state.streamable {
                    state.sources.push(root);
                    pipeline_trees[current_idx].push(state)
                }
            },
            Join {
                input_left,
                input_right,
                options,
                ..
            } if streamable_join(&options.args) => {
                let input_left = *input_left;
                let input_right = *input_right;
                state.streamable = true;
                state.join_count += 1;

                // We swap so that the build phase contains the smallest table,
                // and then we stream the larger table
                // *except* for a left join. In a left join we use the right
                // table as the build table and stream the left table. This way
                // we maintain order in the left join.
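                // For example (illustrative sizes): an inner join of a 10-row
                // table with a 1M-row table builds the hash table on the 10
                // rows and streams the million; a left join always builds on
                // the right table, whatever its size.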
                let (input_left, input_right) = if swap_join_order(options) {
                    (input_right, input_left)
                } else {
                    (input_left, input_right)
                };
                let mut state_left = state.split();

                // Rhs is second, so that is first on the stack.
                let mut state_right = state;
                state_right.join_count = 0;
                state_right
                    .operators_sinks
                    .push(PipelineNode::RhsJoin(root));

                // We want to traverse the lhs last, so push it first on the
                // stack; the rhs is a new pipeline.
                state_left.operators_sinks.push(PipelineNode::Sink(root));
                stack.push(StackFrame::new(input_left, state_left, current_idx));
                stack.push(StackFrame::new(input_right, state_right, current_idx));
            },
            // A union of scans (e.g. from expanded globbing patterns) is
            // pushed as a single source.
            #[cfg(any(feature = "csv", feature = "parquet"))]
            Union { inputs, options }
                if options.slice.is_none()
                    && inputs.iter().all(|node| match lp_arena.get(*node) {
                        Scan { .. } => true,
                        MapFunction {
                            input,
                            function: FunctionIR::Rechunk,
                        } => matches!(lp_arena.get(*input), Scan { .. }),
                        _ => false,
                    }) =>
            {
                state.sources.push(root);
                pipeline_trees[current_idx].push(state);
            },
            Union { inputs, .. } => {
                state.streamable = true;
                for (i, input) in inputs.iter().enumerate() {
                    let mut state = if i == 0 {
                        // Note the clone: the first input continues the
                        // accumulated state and accounts for the other
                        // inputs in its join count.
                        let mut state = state.clone();
                        state.join_count += inputs.len() as u32 - 1;
                        state
                    } else {
                        let mut state = state.split_from_sink();
                        state.join_count = 0;
                        state
                    };
                    state.operators_sinks.push(PipelineNode::Union(root));
                    stack.push(StackFrame::new(*input, state, current_idx));
                }
            },
            Distinct { input, options }
                if !options.maintain_order
                    && !matches!(options.keep_strategy, UniqueKeepStrategy::None) =>
            {
                state.streamable = true;
                state.operators_sinks.push(PipelineNode::Sink(root));
                stack.push(StackFrame::new(*input, state, current_idx))
            },
            #[allow(unused_variables)]
            lp @ GroupBy {
                input,
                keys,
                aggs,
                maintain_order: false,
                apply: None,
                schema: output_schema,
                options,
                ..
            } => {
                #[cfg(feature = "dtype-categorical")]
                let string_cache = polars_core::using_string_cache();
                #[cfg(not(feature = "dtype-categorical"))]
                let string_cache = true;

                #[allow(unused_variables)]
                fn allowed_dtype(dt: &DataType, string_cache: bool) -> bool {
                    match dt {
                        #[cfg(feature = "object")]
                        DataType::Object(_) => false,
                        #[cfg(feature = "dtype-categorical")]
                        DataType::Categorical(_, _) => string_cache,
                        DataType::List(inner) => allowed_dtype(inner, string_cache),
                        #[cfg(feature = "dtype-struct")]
                        DataType::Struct(fields) => fields
                            .iter()
                            .all(|fld| allowed_dtype(fld.dtype(), string_cache)),
                        // We need to be able to sink to disk or produce the aggregate return dtype.
                        DataType::Unknown(_) => false,
                        #[cfg(feature = "dtype-decimal")]
                        DataType::Decimal(_, _) => false,
                        DataType::Int128 => false,
                        _ => true,
                    }
                }
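                // For example, a `List(Categorical)` column is only allowed
                // while the global string cache is active: the recursion above
                // reaches the `Categorical` arm through the `List` arm.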
                let input_schema = lp_arena.get(*input).schema(lp_arena);
                #[allow(unused_mut)]
                let mut can_stream = true;

                #[cfg(feature = "dynamic_group_by")]
                {
                    if options.rolling.is_some() || options.dynamic.is_some() {
                        can_stream = false
                    }
                }

                let valid_agg = || {
                    aggs.iter().all(|e| {
                        polars_pipe::pipeline::can_convert_to_hash_agg(
                            e.node(),
                            expr_arena,
                            &input_schema,
                        )
                    })
                };

                let valid_key = || {
                    keys.iter().all(|e| {
                        output_schema
                            .get(e.output_name())
                            .map(|dt| !matches!(dt, DataType::List(_)))
                            .unwrap_or(false)
                    })
                };

                let valid_types = || {
                    output_schema
                        .iter_values()
                        .all(|dt| allowed_dtype(dt, string_cache))
                };

                if can_stream && valid_agg() && valid_key() && valid_types() {
                    state.streamable = true;
                    state.operators_sinks.push(PipelineNode::Sink(root));
                    stack.push(StackFrame::new(*input, state, current_idx))
                } else if allow_partial {
                    process_non_streamable_node(
                        &mut current_idx,
                        &mut state,
                        &mut stack,
                        scratch,
                        &mut pipeline_trees,
                        lp,
                    )
                } else {
                    return Ok(false);
                }
            },
            lp => {
                if allow_partial {
                    process_non_streamable_node(
                        &mut current_idx,
                        &mut state,
                        &mut stack,
                        scratch,
                        &mut pipeline_trees,
                        lp,
                    )
                } else {
                    return Ok(false);
                }
            },
        }
    }

    let mut inserted = false;
    for tree in pipeline_trees {
        if is_valid_tree(&tree)
            && super::construct_pipeline::construct(tree, lp_arena, expr_arena, fmt)?.is_some()
        {
            inserted = true;
        }
    }

    Ok(inserted)
}
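
// A minimal, self-contained sketch (test-only; a toy tree instead of the real
// `Arena<IR>`) of the `execution_id` bookkeeping above: every pop stamps the
// frame with a monotonically increasing id, so the flattened branches can be
// sorted back into traversal order.
#[cfg(test)]
mod execution_order_sketch {
    #[test]
    fn ids_follow_traversal_order() {
        // A tiny tree: node 0 has children [1, 2]; node 1 has child [3].
        let children: Vec<Vec<usize>> = vec![vec![1, 2], vec![3], vec![], vec![]];
        let mut execution_id = 0usize;
        let mut order = Vec::new();

        let mut stack = vec![0usize];
        while let Some(node) = stack.pop() {
            order.push((node, execution_id));
            execution_id += 1;
            // Children are pushed in order, so the last-pushed child (cf. the
            // rhs of a join above) is popped, and thus stamped, first.
            for &child in &children[node] {
                stack.push(child);
            }
        }

        // Ids reflect the traversal: 0 first, then the right subtree (2),
        // then the left subtree (1, then 3).
        assert_eq!(order, vec![(0, 0), (2, 1), (1, 2), (3, 3)]);
    }
}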