polars_io/predicates.rs

use std::fmt;
use std::sync::Arc;

use arrow::array::Array;
use arrow::bitmap::{Bitmap, BitmapBuilder};
use polars_core::prelude::*;
#[cfg(feature = "parquet")]
use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, SpecializedParquetColumnExpr};
use polars_utils::format_pl_smallstr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

pub trait PhysicalIoExpr: Send + Sync {
    /// Takes a [`DataFrame`] and produces a boolean [`Series`] that serves
    /// as a predicate mask.
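    ///
    /// A minimal sketch of an implementor (hypothetical `KeepAll` predicate
    /// that keeps every row):
    ///
    /// ```ignore
    /// struct KeepAll;
    ///
    /// impl PhysicalIoExpr for KeepAll {
    ///     fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
    ///         // An all-true mask: no row is filtered out.
    ///         Ok(BooleanChunked::full("mask".into(), true, df.height()).into_series())
    ///     }
    /// }
    /// ```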
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series>;
}

/// A column predicate with a specialized, cheap evaluation path.
#[derive(Debug, Clone)]
pub enum SpecializedColumnPredicate {
    /// The value equals the given [`Scalar`].
    Equal(Scalar),

    /// The value lies between the given lower and upper [`Scalar`] bounds.
    Between(Scalar, Scalar),

    /// The value equals one of the given [`Scalar`]s.
    EqualOneOf(Box<[Scalar]>),

    /// The value starts with the given bytes.
    StartsWith(Box<[u8]>),
    /// The value ends with the given bytes.
    EndsWith(Box<[u8]>),
    /// The value starts with the first and ends with the second byte sequence.
    StartEndsWith(Box<[u8]>, Box<[u8]>),
}

#[derive(Clone)]
pub struct ColumnPredicateExpr {
    column_name: PlSmallStr,
    dtype: DataType,
    #[cfg(feature = "parquet")]
    specialized: Option<SpecializedParquetColumnExpr>,
    expr: Arc<dyn PhysicalIoExpr>,
}

impl ColumnPredicateExpr {
    pub fn new(
        column_name: PlSmallStr,
        dtype: DataType,
        expr: Arc<dyn PhysicalIoExpr>,
        specialized: Option<SpecializedColumnPredicate>,
    ) -> Self {
        use SpecializedColumnPredicate as S;
        #[cfg(feature = "parquet")]
        use SpecializedParquetColumnExpr as P;
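        // Translate the engine-agnostic specialization into its Parquet
        // counterpart. If any scalar cannot be represented as a
        // `ParquetScalar`, the whole specialization is dropped (`None`) and
        // only the generic expression path remains.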
        #[cfg(feature = "parquet")]
        let specialized = specialized.and_then(|s| {
            Some(match s {
                S::Equal(s) => P::Equal(cast_to_parquet_scalar(s)?),
                S::Between(low, high) => {
                    P::Between(cast_to_parquet_scalar(low)?, cast_to_parquet_scalar(high)?)
                },
                S::EqualOneOf(scalars) => P::EqualOneOf(
                    scalars
                        .into_iter()
                        .map(|s| cast_to_parquet_scalar(s).ok_or(()))
                        .collect::<Result<Box<_>, ()>>()
                        .ok()?,
                ),
                S::StartsWith(s) => P::StartsWith(s),
                S::EndsWith(s) => P::EndsWith(s),
                S::StartEndsWith(start, end) => P::StartEndsWith(start, end),
            })
        });

        Self {
            column_name,
            dtype,
            #[cfg(feature = "parquet")]
            specialized,
            expr,
        }
    }
}

#[cfg(feature = "parquet")]
impl ParquetColumnExpr for ColumnPredicateExpr {
    fn evaluate_mut(&self, values: &dyn Array, bm: &mut BitmapBuilder) {
        // We should never evaluate nulls with this.
        assert!(values.validity().is_none_or(|v| v.set_bits() == 0));

        // @TODO: Probably these unwraps should be removed.
        let series =
            Series::from_chunk_and_dtype(self.column_name.clone(), values.to_boxed(), &self.dtype)
                .unwrap();
        let column = series.into_column();
        let df = unsafe { DataFrame::new_no_checks(values.len(), vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        bm.reserve(true_mask.len());
        for chunk in true_mask.downcast_iter() {
            // A null in the mask counts as `false`: AND the values with the
            // validity bitmap.
            match chunk.validity() {
                None => bm.extend_from_bitmap(chunk.values()),
                Some(v) => bm.extend_from_bitmap(&(chunk.values() & v)),
            }
        }
    }

    fn evaluate_null(&self) -> bool {
        let column = Column::full_null(self.column_name.clone(), 1, &self.dtype);
        let df = unsafe { DataFrame::new_no_checks(1, vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        true_mask.get(0).unwrap_or(false)
    }

    fn as_specialized(&self) -> Option<&SpecializedParquetColumnExpr> {
        self.specialized.as_ref()
    }
}

#[cfg(feature = "parquet")]
fn cast_to_parquet_scalar(scalar: Scalar) -> Option<ParquetScalar> {
    use {AnyValue as A, ParquetScalar as P};

    Some(match scalar.into_value() {
        A::Null => P::Null,
        A::Boolean(v) => P::Boolean(v),

        A::UInt8(v) => P::UInt8(v),
        A::UInt16(v) => P::UInt16(v),
        A::UInt32(v) => P::UInt32(v),
        A::UInt64(v) => P::UInt64(v),

        A::Int8(v) => P::Int8(v),
        A::Int16(v) => P::Int16(v),
        A::Int32(v) => P::Int32(v),
        A::Int64(v) => P::Int64(v),

        #[cfg(feature = "dtype-date")]
        A::Date(v) => P::Int32(v),
        #[cfg(feature = "dtype-datetime")]
        A::Datetime(v, _, _) | A::DatetimeOwned(v, _, _) => P::Int64(v),
        #[cfg(feature = "dtype-duration")]
        A::Duration(v, _) => P::Int64(v),
        #[cfg(feature = "dtype-time")]
        A::Time(v) => P::Int64(v),

        A::Float32(v) => P::Float32(v),
        A::Float64(v) => P::Float64(v),

        // @TODO: Cast to string
        #[cfg(feature = "dtype-categorical")]
        A::Categorical(_, _) | A::CategoricalOwned(_, _) | A::Enum(_, _) | A::EnumOwned(_, _) => {
            return None;
        },

        A::String(v) => P::String(v.into()),
        A::StringOwned(v) => P::String(v.as_str().into()),
        A::Binary(v) => P::Binary(v.into()),
        A::BinaryOwned(v) => P::Binary(v.into()),
        _ => return None,
    })
}

/// Applies an optional predicate to `df`, filtering its rows in place.
#[cfg(any(feature = "parquet", feature = "ipc"))]
pub fn apply_predicate(
    df: &mut DataFrame,
    predicate: Option<&dyn PhysicalIoExpr>,
    parallel: bool,
) -> PolarsResult<()> {
    if let (Some(predicate), false) = (&predicate, df.get_columns().is_empty()) {
        let s = predicate.evaluate_io(df)?;
        let mask = s.bool().expect("filter predicate was not of type boolean");

        if parallel {
            *df = df.filter(mask)?;
        } else {
            *df = df._filter_seq(mask)?;
        }
    }
    Ok(())
}
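
// Usage sketch (`read_batch` and `predicate` are hypothetical): filter a
// freshly decoded batch with a pushed-down predicate.
//
//     let mut df = read_batch()?;
//     apply_predicate(&mut df, Some(predicate.as_ref()), true)?;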

/// Statistics of the values in a column.
///
/// The following statistics are tracked for each row group:
/// - Null count
/// - Minimum value
/// - Maximum value
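///
/// Readers can use these statistics to prune row groups. For example
/// (sketch, hypothetical values): if a column's tracked maximum is 5, a
/// predicate like `col > 10` can never match, so the row group may be
/// skipped without reading it.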
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ColumnStats {
    field: Field,
    // Each Series holds one entry per row group.
    null_count: Option<Series>,
    min_value: Option<Series>,
    max_value: Option<Series>,
}

impl ColumnStats {
    /// Constructs a new [`ColumnStats`].
    pub fn new(
        field: Field,
        null_count: Option<Series>,
        min_value: Option<Series>,
        max_value: Option<Series>,
    ) -> Self {
        Self {
            field,
            null_count,
            min_value,
            max_value,
        }
    }

    /// Constructs a new [`ColumnStats`] with only the [`Field`] information and no statistics.
    pub fn from_field(field: Field) -> Self {
        Self {
            field,
            null_count: None,
            min_value: None,
            max_value: None,
        }
    }

    /// Constructs a new [`ColumnStats`] from a single-value [`Series`].
    pub fn from_column_literal(s: Series) -> Self {
        debug_assert_eq!(s.len(), 1);
        Self {
            field: s.field().into_owned(),
            null_count: None,
            min_value: Some(s.clone()),
            max_value: Some(s),
        }
    }

    /// Returns the name of the column.
    pub fn field_name(&self) -> &PlSmallStr {
        self.field.name()
    }

    /// Returns the [`DataType`] of the column.
    pub fn dtype(&self) -> &DataType {
        self.field.dtype()
    }

    /// Returns the null count of each row group of the column.
    pub fn get_null_count_state(&self) -> Option<&Series> {
        self.null_count.as_ref()
    }

    /// Returns the minimum value of each row group of the column.
    pub fn get_min_state(&self) -> Option<&Series> {
        self.min_value.as_ref()
    }

    /// Returns the maximum value of each row group of the column.
    pub fn get_max_state(&self) -> Option<&Series> {
        self.max_value.as_ref()
    }

    /// Returns the total null count of the column across all row groups.
    pub fn null_count(&self) -> Option<usize> {
        match self.dtype() {
            #[cfg(feature = "dtype-struct")]
            DataType::Struct(_) => None,
            _ => {
                let s = self.get_null_count_state()?;
                // If all values are null, there are no statistics.
                if s.null_count() != s.len() {
                    s.sum().ok()
                } else {
                    None
                }
            },
        }
    }

    /// Returns the minimum and maximum values of the column as a single [`Series`].
    pub fn to_min_max(&self) -> Option<Series> {
        let min_val = self.get_min_state()?;
        let max_val = self.get_max_state()?;
        let dtype = self.dtype();

        if !use_min_max(dtype) {
            return None;
        }

        let mut min_max_values = min_val.clone();
        min_max_values.append(max_val).unwrap();
        if min_max_values.null_count() > 0 {
            None
        } else {
            Some(min_max_values)
        }
    }

    /// Returns the minimum value of the column as a single-value [`Series`].
    ///
    /// Returns `None` if no minimum value is available.
    pub fn to_min(&self) -> Option<&Series> {
        // @scalar-opt
        let min_val = self.min_value.as_ref()?;
        let dtype = min_val.dtype();

        if !use_min_max(dtype) || min_val.len() != 1 {
            return None;
        }

        if min_val.null_count() > 0 {
            None
        } else {
            Some(min_val)
        }
    }

    /// Returns the maximum value of the column as a single-value [`Series`].
    ///
    /// Returns `None` if no maximum value is available.
    pub fn to_max(&self) -> Option<&Series> {
        // @scalar-opt
        let max_val = self.max_value.as_ref()?;
        let dtype = max_val.dtype();

        if !use_min_max(dtype) || max_val.len() != 1 {
            return None;
        }

        if max_val.null_count() > 0 {
            None
        } else {
            Some(max_val)
        }
    }
}
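
// A minimal usage sketch (`stats` is a hypothetical `ColumnStats` value): a
// reader can compare a predicate literal against the single-value min/max to
// decide whether any row in the group can match.
//
//     if let (Some(min), Some(max)) = (stats.to_min(), stats.to_max()) {
//         // With predicate `col == 7`: skip the group if `7 < min` or `7 > max`.
//     }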

/// Returns whether the [`DataType`] supports minimum/maximum operations.
fn use_min_max(dtype: &DataType) -> bool {
    dtype.is_primitive_numeric()
        || dtype.is_temporal()
        || matches!(
            dtype,
            DataType::String | DataType::Binary | DataType::Boolean
        )
}

/// Statistics for the values of a column within a single batch.
pub struct ColumnStatistics {
    pub dtype: DataType,
    pub min: AnyValue<'static>,
    pub max: AnyValue<'static>,
    pub null_count: Option<IdxSize>,
}

/// A predicate that, given batch-level column statistics, decides whether a
/// whole batch can be skipped without reading it.
pub trait SkipBatchPredicate: Send + Sync {
    /// The schema the statistics columns refer to.
    fn schema(&self) -> &SchemaRef;

    /// Returns whether a batch with the given size and column statistics can
    /// be skipped entirely.
    fn can_skip_batch(
        &self,
        batch_size: IdxSize,
        live_columns: &PlIndexSet<PlSmallStr>,
        mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
    ) -> PolarsResult<bool> {
        let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);

        columns.push(Column::new_scalar(
            PlSmallStr::from_static("len"),
            Scalar::new(IDX_DTYPE, batch_size.into()),
            1,
        ));

        for col in live_columns.iter() {
            let dtype = self.schema().get(col).unwrap();
            let (min, max, nc) = match statistics.swap_remove(col) {
                None => (
                    Scalar::null(dtype.clone()),
                    Scalar::null(dtype.clone()),
                    Scalar::null(IDX_DTYPE),
                ),
                Some(stat) => (
                    Scalar::new(dtype.clone(), stat.min),
                    Scalar::new(dtype.clone(), stat.max),
                    Scalar::new(
                        IDX_DTYPE,
                        stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
                    ),
                ),
            };
            columns.extend([
                Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
            ]);
        }
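
        // For live columns `a` and `b` (where `b` has no statistics), the
        // resulting one-row frame looks like this (sketch, hypothetical
        // values):
        //
        //   len | a_min | a_max | a_nc | b_min | b_max | b_nc
        //   100 |     1 |    50 |    0 |  null |  null | null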

        // SAFETY:
        // * Each column has length 1.
        // * We have an `IndexSet`, so each column name is unique.
        let df = unsafe { DataFrame::new_no_checks(1, columns) };
        Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
    }

    /// Evaluates the predicate on a one-row [`DataFrame`] of statistics
    /// columns (`len`, plus `{col}_min` / `{col}_max` / `{col}_nc` for each
    /// live column).
    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
}

/// Per-column predicates, keyed by column name.
#[derive(Clone)]
pub struct ColumnPredicates {
    pub predicates:
        PlHashMap<PlSmallStr, (Arc<dyn PhysicalIoExpr>, Option<SpecializedColumnPredicate>)>,
    pub is_sumwise_complete: bool,
}

// I want to be explicit here.
#[allow(clippy::derivable_impls)]
impl Default for ColumnPredicates {
    fn default() -> Self {
        Self {
            predicates: PlHashMap::default(),
            is_sumwise_complete: false,
        }
    }
}

/// Wraps a predicate, prepending a set of constant columns to any
/// [`DataFrame`] it is evaluated on before delegating to the child.
pub struct PhysicalExprWithConstCols<T> {
    constants: Vec<(PlSmallStr, Scalar)>,
    child: T,
}

impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
    fn schema(&self) -> &SchemaRef {
        self.child.schema()
    }

    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }
        self.child.evaluate_with_stat_df(&df)
    }
}

impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate_io(&df)
    }
}
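
// Wrapping sketch: evaluating the wrapper is equivalent to evaluating the
// child on `df` extended with the constant columns, e.g. hive partition
// values that are not materialized in the file itself (see
// `set_external_constant_columns` below).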

#[derive(Clone)]
pub struct ScanIOPredicate {
    pub predicate: Arc<dyn PhysicalIoExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate that gets given statistics and evaluates whether a batch can be skipped.
    pub skip_batch_predicate: Option<Arc<dyn SkipBatchPredicate>>,

    /// Predicates that apply to individual columns.
    pub column_predicates: Arc<ColumnPredicates>,

    /// Predicate parts only referring to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalIoExpr>>,

    /// Whether `hive_predicate` covers the full predicate.
    pub hive_predicate_is_full_predicate: bool,
}

impl ScanIOPredicate {
    pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
        if constant_columns.is_empty() {
            return;
        }

        // Constant columns are no longer live: their values are known up front.
        let mut live_columns = self.live_columns.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            live_columns.swap_remove(c);
        }
        self.live_columns = Arc::new(live_columns);

        if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
            // Mirror the `{col}_min` / `{col}_max` / `{col}_nc` naming used
            // when building the statistics frame.
            let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
            for (c, v) in constant_columns.iter() {
                sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
                sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
                let nc = if v.is_null() {
                    AnyValue::Null
                } else {
                    (0 as IdxSize).into()
                };
                sbp_constant_columns
                    .push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
            }
            self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
                constants: sbp_constant_columns,
                child: skip_batch_predicate,
            }));
        }

        let mut column_predicates = self.column_predicates.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            column_predicates.predicates.remove(c);
        }
        self.column_predicates = Arc::new(column_predicates);

        self.predicate = Arc::new(PhysicalExprWithConstCols {
            constants: constant_columns,
            child: self.predicate.clone(),
        });
    }
}

impl fmt::Debug for ScanIOPredicate {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("scan_io_predicate")
    }
}