polars_io/predicates.rs

use std::fmt;

use arrow::array::Array;
use arrow::bitmap::{Bitmap, BitmapBuilder};
use polars_core::prelude::*;
#[cfg(feature = "parquet")]
use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, SpecializedParquetColumnExpr};
use polars_utils::format_pl_smallstr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

pub trait PhysicalIoExpr: Send + Sync {
    /// Takes a [`DataFrame`] and produces a boolean [`Series`] that serves
    /// as a predicate mask.
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series>;
}

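// A minimal sketch of a `PhysicalIoExpr` implementation (illustrative only;
// `Positive` and the `value` column name are assumptions, not part of Polars):
//
//     struct Positive;
//
//     impl PhysicalIoExpr for Positive {
//         fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
//             // Keep rows where the `value` column is strictly positive.
//             let ca = df.column("value")?.as_materialized_series().i64()?;
//             Ok(ca.gt(0).into_series())
//         }
//     }
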
#[derive(Debug, Clone)]
pub enum SpecializedColumnPredicate {
    Equal(Scalar),
    /// A closed (inclusive) range.
    Between(Scalar, Scalar),
    EqualOneOf(Box<[Scalar]>),
    StartsWith(Box<[u8]>),
    EndsWith(Box<[u8]>),
    RegexMatch(regex::bytes::Regex),
}

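// For example, an inclusive-range predicate equivalent to `1 <= x <= 10` on an
// `Int64` column could be expressed as (a sketch, not from the original file):
//
//     let low = Scalar::new(DataType::Int64, AnyValue::Int64(1));
//     let high = Scalar::new(DataType::Int64, AnyValue::Int64(10));
//     let pred = SpecializedColumnPredicate::Between(low, high);
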
#[derive(Clone)]
pub struct ColumnPredicateExpr {
    column_name: PlSmallStr,
    dtype: DataType,
    #[cfg(feature = "parquet")]
    specialized: Option<SpecializedParquetColumnExpr>,
    expr: Arc<dyn PhysicalIoExpr>,
}

impl ColumnPredicateExpr {
    pub fn new(
        column_name: PlSmallStr,
        dtype: DataType,
        expr: Arc<dyn PhysicalIoExpr>,
        specialized: Option<SpecializedColumnPredicate>,
    ) -> Self {
        use SpecializedColumnPredicate as S;
        #[cfg(feature = "parquet")]
        use SpecializedParquetColumnExpr as P;
        #[cfg(feature = "parquet")]
        let specialized = specialized.and_then(|s| {
            Some(match s {
                S::Equal(s) => P::Equal(cast_to_parquet_scalar(s)?),
                S::Between(low, high) => {
                    P::Between(cast_to_parquet_scalar(low)?, cast_to_parquet_scalar(high)?)
                },
                S::EqualOneOf(scalars) => P::EqualOneOf(
                    scalars
                        .into_iter()
                        .map(|s| cast_to_parquet_scalar(s).ok_or(()))
                        .collect::<Result<Box<_>, ()>>()
                        .ok()?,
                ),
                S::StartsWith(s) => P::StartsWith(s),
                S::EndsWith(s) => P::EndsWith(s),
                S::RegexMatch(s) => P::RegexMatch(s),
            })
        });

        Self {
            column_name,
            dtype,
            #[cfg(feature = "parquet")]
            specialized,
            expr,
        }
    }
}

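// Tying the pieces together (a sketch; `Positive` is the illustrative
// `PhysicalIoExpr` from above, and the column name is an assumption):
//
//     let expr: Arc<dyn PhysicalIoExpr> = Arc::new(Positive);
//     let pred = ColumnPredicateExpr::new(
//         PlSmallStr::from_static("value"),
//         DataType::Int64,
//         expr,
//         Some(SpecializedColumnPredicate::Equal(Scalar::new(
//             DataType::Int64,
//             AnyValue::Int64(1),
//         ))),
//     );
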
#[cfg(feature = "parquet")]
impl ParquetColumnExpr for ColumnPredicateExpr {
    fn evaluate_mut(&self, values: &dyn Array, bm: &mut BitmapBuilder) {
        // We should never evaluate nulls with this.
        assert!(values.validity().is_none_or(|v| v.unset_bits() == 0));

        // @TODO: Probably these unwraps should be removed.
        let series =
            Series::from_chunk_and_dtype(self.column_name.clone(), values.to_boxed(), &self.dtype)
                .unwrap();
        let column = series.into_column();
        let df = unsafe { DataFrame::new_no_checks(values.len(), vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        bm.reserve(true_mask.len());
        for chunk in true_mask.downcast_iter() {
            match chunk.validity() {
                None => bm.extend_from_bitmap(chunk.values()),
                Some(v) => bm.extend_from_bitmap(&(chunk.values() & v)),
            }
        }
    }
    fn evaluate_null(&self) -> bool {
        let column = Column::full_null(self.column_name.clone(), 1, &self.dtype);
        let df = unsafe { DataFrame::new_no_checks(1, vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        true_mask.get(0).unwrap_or(false)
    }

    fn as_specialized(&self) -> Option<&SpecializedParquetColumnExpr> {
        self.specialized.as_ref()
    }
}

#[cfg(feature = "parquet")]
fn cast_to_parquet_scalar(scalar: Scalar) -> Option<ParquetScalar> {
    use {AnyValue as A, ParquetScalar as P};

    Some(match scalar.into_value() {
        A::Null => P::Null,
        A::Boolean(v) => P::Boolean(v),

        A::UInt8(v) => P::UInt8(v),
        A::UInt16(v) => P::UInt16(v),
        A::UInt32(v) => P::UInt32(v),
        A::UInt64(v) => P::UInt64(v),

        A::Int8(v) => P::Int8(v),
        A::Int16(v) => P::Int16(v),
        A::Int32(v) => P::Int32(v),
        A::Int64(v) => P::Int64(v),

        #[cfg(feature = "dtype-date")]
        A::Date(v) => P::Int32(v),
        #[cfg(feature = "dtype-datetime")]
        A::Datetime(v, _, _) | A::DatetimeOwned(v, _, _) => P::Int64(v),
        #[cfg(feature = "dtype-duration")]
        A::Duration(v, _) => P::Int64(v),
        #[cfg(feature = "dtype-time")]
        A::Time(v) => P::Int64(v),

        A::Float32(v) => P::Float32(v),
        A::Float64(v) => P::Float64(v),

        // @TODO: Cast to string
        #[cfg(feature = "dtype-categorical")]
        A::Categorical(_, _) | A::CategoricalOwned(_, _) | A::Enum(_, _) | A::EnumOwned(_, _) => {
            return None;
        },

        A::String(v) => P::String(v.into()),
        A::StringOwned(v) => P::String(v.as_str().into()),
        A::Binary(v) => P::Binary(v.into()),
        A::BinaryOwned(v) => P::Binary(v.into()),
        _ => return None,
    })
}

#[cfg(any(feature = "parquet", feature = "ipc"))]
pub fn apply_predicate(
    df: &mut DataFrame,
    predicate: Option<&dyn PhysicalIoExpr>,
    parallel: bool,
) -> PolarsResult<()> {
    if let (Some(predicate), false) = (&predicate, df.get_columns().is_empty()) {
        let s = predicate.evaluate_io(df)?;
        let mask = s.bool().expect("filter predicate was not of type boolean");

        if parallel {
            *df = df.filter(mask)?;
        } else {
            *df = df._filter_seq(mask)?;
        }
    }
    Ok(())
}

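// A typical call site (a sketch; `pred` is an `Arc<dyn PhysicalIoExpr>` such
// as the one built above, applied to a freshly decoded batch):
//
//     apply_predicate(&mut df, Some(pred.as_ref()), true)?;
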
/// Statistics of the values in a column.
///
/// The following statistics are tracked for each row group:
/// - Null count
/// - Minimum value
/// - Maximum value
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ColumnStats {
    field: Field,
    // Each Series contains the stats for each row group.
    null_count: Option<Series>,
    min_value: Option<Series>,
    max_value: Option<Series>,
}

impl ColumnStats {
    /// Constructs a new [`ColumnStats`].
    pub fn new(
        field: Field,
        null_count: Option<Series>,
        min_value: Option<Series>,
        max_value: Option<Series>,
    ) -> Self {
        Self {
            field,
            null_count,
            min_value,
            max_value,
        }
    }

    /// Constructs a new [`ColumnStats`] with only the [`Field`] information and no statistics.
    pub fn from_field(field: Field) -> Self {
        Self {
            field,
            null_count: None,
            min_value: None,
            max_value: None,
        }
    }

    /// Constructs a new [`ColumnStats`] from a single-value Series.
    pub fn from_column_literal(s: Series) -> Self {
        debug_assert_eq!(s.len(), 1);
        Self {
            field: s.field().into_owned(),
            null_count: None,
            min_value: Some(s.clone()),
            max_value: Some(s),
        }
    }

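    // For a literal column the minimum and maximum coincide, e.g. (a sketch):
    //
    //     let s = Series::new(PlSmallStr::from_static("x"), &[1i64]);
    //     let stats = ColumnStats::from_column_literal(s);
    //     assert_eq!(stats.get_min_state(), stats.get_max_state());
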
    pub fn field_name(&self) -> &PlSmallStr {
        self.field.name()
    }

    /// Returns the [`DataType`] of the column.
    pub fn dtype(&self) -> &DataType {
        self.field.dtype()
    }

    /// Returns the null count of each row group of the column.
    pub fn get_null_count_state(&self) -> Option<&Series> {
        self.null_count.as_ref()
    }

    /// Returns the minimum value of each row group of the column.
    pub fn get_min_state(&self) -> Option<&Series> {
        self.min_value.as_ref()
    }

    /// Returns the maximum value of each row group of the column.
    pub fn get_max_state(&self) -> Option<&Series> {
        self.max_value.as_ref()
    }

    /// Returns the null count of the column.
    pub fn null_count(&self) -> Option<usize> {
        match self.dtype() {
            #[cfg(feature = "dtype-struct")]
            DataType::Struct(_) => None,
            _ => {
                let s = self.get_null_count_state()?;
                // If all values are null, there are no statistics.
                if s.null_count() != s.len() {
                    s.sum().ok()
                } else {
                    None
                }
            },
        }
    }

    /// Returns the minimum and maximum values of the column as a single [`Series`].
    pub fn to_min_max(&self) -> Option<Series> {
        let min_val = self.get_min_state()?;
        let max_val = self.get_max_state()?;
        let dtype = self.dtype();

        if !use_min_max(dtype) {
            return None;
        }

        let mut min_max_values = min_val.clone();
        min_max_values.append(max_val).unwrap();
        if min_max_values.null_count() > 0 {
            None
        } else {
            Some(min_max_values)
        }
    }

    /// Returns the minimum value of the column as a single-value [`Series`].
    ///
    /// Returns `None` if no minimum value is available.
    pub fn to_min(&self) -> Option<&Series> {
        // @scalar-opt
        let min_val = self.min_value.as_ref()?;
        let dtype = min_val.dtype();

        if !use_min_max(dtype) || min_val.len() != 1 {
            return None;
        }

        if min_val.null_count() > 0 {
            None
        } else {
            Some(min_val)
        }
    }

    /// Returns the maximum value of the column as a single-value [`Series`].
    ///
    /// Returns `None` if no maximum value is available.
    pub fn to_max(&self) -> Option<&Series> {
        // @scalar-opt
        let max_val = self.max_value.as_ref()?;
        let dtype = max_val.dtype();

        if !use_min_max(dtype) || max_val.len() != 1 {
            return None;
        }

        if max_val.null_count() > 0 {
            None
        } else {
            Some(max_val)
        }
    }
}

/// Returns whether the [`DataType`] supports minimum/maximum operations.
fn use_min_max(dtype: &DataType) -> bool {
    dtype.is_primitive_numeric()
        || dtype.is_temporal()
        || matches!(
            dtype,
            DataType::String | DataType::Binary | DataType::Boolean
        )
}

pub struct ColumnStatistics {
    pub dtype: DataType,
    pub min: AnyValue<'static>,
    pub max: AnyValue<'static>,
    pub null_count: Option<IdxSize>,
}

pub trait SkipBatchPredicate: Send + Sync {
    fn schema(&self) -> &SchemaRef;

    fn can_skip_batch(
        &self,
        batch_size: IdxSize,
        live_columns: &PlIndexSet<PlSmallStr>,
        mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
    ) -> PolarsResult<bool> {
        let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);

        columns.push(Column::new_scalar(
            PlSmallStr::from_static("len"),
            Scalar::new(IDX_DTYPE, batch_size.into()),
            1,
        ));

        for col in live_columns.iter() {
            let dtype = self.schema().get(col).unwrap();
            let (min, max, nc) = match statistics.swap_remove(col) {
                None => (
                    Scalar::null(dtype.clone()),
                    Scalar::null(dtype.clone()),
                    Scalar::null(IDX_DTYPE),
                ),
                Some(stat) => (
                    Scalar::new(dtype.clone(), stat.min),
                    Scalar::new(dtype.clone(), stat.max),
                    Scalar::new(
                        IDX_DTYPE,
                        stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
                    ),
                ),
            };
            columns.extend([
                Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
            ]);
        }

        // SAFETY:
        // * Each column is length = 1
        // * We have an IndexSet, so each column name is unique
        let df = unsafe { DataFrame::new_no_checks(1, columns) };
        Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
    }
    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
}

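// The synthetic statistics frame built by `can_skip_batch` always has a single
// row. For one live column `a`, it would look like this (illustrative values):
//
//     ┌─────┬───────┬───────┬──────┐
//     │ len ┆ a_min ┆ a_max ┆ a_nc │
//     ├─────┼───────┼───────┼──────┤
//     │ 100 ┆ 1     ┆ 10    ┆ 0    │
//     └─────┴───────┴───────┴──────┘
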
#[derive(Clone)]
pub struct ColumnPredicates {
    pub predicates:
        PlHashMap<PlSmallStr, (Arc<dyn PhysicalIoExpr>, Option<SpecializedColumnPredicate>)>,
    pub is_sumwise_complete: bool,
}

// I want to be explicit here.
#[allow(clippy::derivable_impls)]
impl Default for ColumnPredicates {
    fn default() -> Self {
        Self {
            predicates: PlHashMap::default(),
            is_sumwise_complete: false,
        }
    }
}

pub struct PhysicalExprWithConstCols<T> {
    constants: Vec<(PlSmallStr, Scalar)>,
    child: T,
}

impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
    fn schema(&self) -> &SchemaRef {
        self.child.schema()
    }

    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }
        self.child.evaluate_with_stat_df(&df)
    }
}

impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate_io(&df)
    }
}

#[derive(Clone)]
pub struct ScanIOPredicate {
    pub predicate: Arc<dyn PhysicalIoExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate that gets given statistics and evaluates whether a batch can be skipped.
    pub skip_batch_predicate: Option<Arc<dyn SkipBatchPredicate>>,

    /// Per-column predicates that can be evaluated while decoding columnar formats.
    pub column_predicates: Arc<ColumnPredicates>,

    /// Predicate parts only referring to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalIoExpr>>,

    pub hive_predicate_is_full_predicate: bool,
}

impl ScanIOPredicate {
    pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
        if constant_columns.is_empty() {
            return;
        }

        let mut live_columns = self.live_columns.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            live_columns.swap_remove(c);
        }
        self.live_columns = Arc::new(live_columns);

        if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
            let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
            for (c, v) in constant_columns.iter() {
                sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
                sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
                let nc = if v.is_null() {
                    AnyValue::Null
                } else {
                    (0 as IdxSize).into()
                };
                sbp_constant_columns
                    .push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
            }
            self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
                constants: sbp_constant_columns,
                child: skip_batch_predicate,
            }));
        }

        let mut column_predicates = self.column_predicates.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            column_predicates.predicates.remove(c);
        }
        self.column_predicates = Arc::new(column_predicates);

        self.predicate = Arc::new(PhysicalExprWithConstCols {
            constants: constant_columns,
            child: self.predicate.clone(),
        });
    }
}

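// A sketch of how a scan might inject hive-partition values as constant
// columns (the column name and value here are illustrative):
//
//     let part = (
//         PlSmallStr::from_static("year"),
//         Scalar::new(DataType::Int32, AnyValue::Int32(2024)),
//     );
//     scan_predicate.set_external_constant_columns(vec![part]);
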
impl fmt::Debug for ScanIOPredicate {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("scan_io_predicate")
    }
}