//! Predicate expressions and column statistics used while scanning data
//! during I/O (`polars_io/predicates.rs`).
1use std::fmt;
2
3use arrow::array::Array;
4use arrow::bitmap::{Bitmap, BitmapBuilder};
5use polars_core::prelude::*;
6#[cfg(feature = "parquet")]
7use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, SpecializedParquetColumnExpr};
8use polars_utils::format_pl_smallstr;
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
/// A physical expression that can be evaluated against a [`DataFrame`] during
/// I/O, producing a boolean filter mask.
pub trait PhysicalIoExpr: Send + Sync {
    /// Takes a [`DataFrame`] and produces a boolean [`Series`] that serves
    /// as a predicate mask.
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series>;
}
17
/// Fast-path forms of single-column predicates that a reader can evaluate
/// directly on decoded values, without running the full physical expression.
#[derive(Debug, Clone)]
pub enum SpecializedColumnPredicate {
    /// Value equals the given scalar.
    Equal(Scalar),
    /// A closed (inclusive) range.
    Between(Scalar, Scalar),
    /// Value equals one of the given scalars.
    EqualOneOf(Box<[Scalar]>),
    /// Binary/string value starts with the given bytes.
    StartsWith(Box<[u8]>),
    /// Binary/string value ends with the given bytes.
    EndsWith(Box<[u8]>),
    /// Binary/string value matches the given byte-oriented regex.
    RegexMatch(regex::bytes::Regex),
}
28
/// A [`PhysicalIoExpr`] bound to a single column, optionally carrying a
/// specialized parquet fast-path form.
#[derive(Clone)]
pub struct ColumnPredicateExpr {
    // Name of the column the predicate is evaluated against.
    column_name: PlSmallStr,
    // Logical dtype the raw values are reinterpreted as before evaluation.
    dtype: DataType,
    #[cfg(feature = "parquet")]
    // Fast-path form derived at construction time, when representable.
    specialized: Option<SpecializedParquetColumnExpr>,
    // The full predicate; used when no specialization applies.
    expr: Arc<dyn PhysicalIoExpr>,
}
37
38impl ColumnPredicateExpr {
39    pub fn new(
40        column_name: PlSmallStr,
41        dtype: DataType,
42        expr: Arc<dyn PhysicalIoExpr>,
43        specialized: Option<SpecializedColumnPredicate>,
44    ) -> Self {
45        use SpecializedColumnPredicate as S;
46        #[cfg(feature = "parquet")]
47        use SpecializedParquetColumnExpr as P;
48        #[cfg(feature = "parquet")]
49        let specialized = specialized.and_then(|s| {
50            Some(match s {
51                S::Equal(s) => P::Equal(cast_to_parquet_scalar(s)?),
52                S::Between(low, high) => {
53                    P::Between(cast_to_parquet_scalar(low)?, cast_to_parquet_scalar(high)?)
54                },
55                S::EqualOneOf(scalars) => P::EqualOneOf(
56                    scalars
57                        .into_iter()
58                        .map(|s| cast_to_parquet_scalar(s).ok_or(()))
59                        .collect::<Result<Box<_>, ()>>()
60                        .ok()?,
61                ),
62                S::StartsWith(s) => P::StartsWith(s),
63                S::EndsWith(s) => P::EndsWith(s),
64                S::RegexMatch(s) => P::RegexMatch(s),
65            })
66        });
67
68        Self {
69            column_name,
70            dtype,
71            #[cfg(feature = "parquet")]
72            specialized,
73            expr,
74        }
75    }
76}
77
#[cfg(feature = "parquet")]
impl ParquetColumnExpr for ColumnPredicateExpr {
    /// Evaluates the predicate over decoded parquet values, appending one bit
    /// per row to `bm` (`true` = the row survives the filter).
    fn evaluate_mut(&self, values: &dyn Array, bm: &mut BitmapBuilder) {
        // We should never evaluate nulls with this.
        // NOTE(review): the assert accepts "no validity" OR "zero set bits"
        // (an all-null validity). The latter seems at odds with the comment
        // above — confirm the intended invariant.
        assert!(values.validity().is_none_or(|v| v.set_bits() == 0));

        // Reinterpret the raw arrow values as a Series of the column's
        // logical dtype so the physical expression can be run on it.
        // @TODO: Probably these unwraps should be removed.
        let series =
            Series::from_chunk_and_dtype(self.column_name.clone(), values.to_boxed(), &self.dtype)
                .unwrap();
        let column = series.into_column();
        // SAFETY: single column whose length equals the stated height.
        let df = unsafe { DataFrame::new_unchecked(values.len(), vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        bm.reserve(true_mask.len());
        for chunk in true_mask.downcast_iter() {
            match chunk.validity() {
                None => bm.extend_from_bitmap(chunk.values()),
                // A null in the mask counts as "does not pass", so AND the
                // values with the validity bitmap.
                Some(v) => bm.extend_from_bitmap(&(chunk.values() & v)),
            }
        }
    }
    /// Evaluates the predicate for a null value: returns whether nulls pass.
    fn evaluate_null(&self) -> bool {
        let column = Column::full_null(self.column_name.clone(), 1, &self.dtype);
        // SAFETY: one column of height 1.
        let df = unsafe { DataFrame::new_unchecked(1, vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        // An empty or null mask result means nulls do not pass.
        true_mask.get(0).unwrap_or(false)
    }

    /// Returns the specialized parquet form of this predicate, if one was
    /// derivable at construction time.
    fn as_specialized(&self) -> Option<&SpecializedParquetColumnExpr> {
        self.specialized.as_ref()
    }
}
118
/// Converts a polars [`Scalar`] into a [`ParquetScalar`], returning `None`
/// when the value has no parquet-predicate representation (e.g. categorical
/// or nested values).
#[cfg(feature = "parquet")]
fn cast_to_parquet_scalar(scalar: Scalar) -> Option<ParquetScalar> {
    use AnyValue as A;
    use ParquetScalar as P;

    Some(match scalar.into_value() {
        A::Null => P::Null,
        A::Boolean(v) => P::Boolean(v),

        A::UInt8(v) => P::UInt8(v),
        A::UInt16(v) => P::UInt16(v),
        A::UInt32(v) => P::UInt32(v),
        A::UInt64(v) => P::UInt64(v),

        A::Int8(v) => P::Int8(v),
        A::Int16(v) => P::Int16(v),
        A::Int32(v) => P::Int32(v),
        A::Int64(v) => P::Int64(v),

        // Temporal values map to their physical integer representation.
        // BUGFIX: `AnyValue::Date` is gated behind `dtype-date`, not
        // `dtype-time`; the previous gate either broke the build (only
        // `dtype-time` enabled) or silently dropped the Date specialization
        // (only `dtype-date` enabled).
        #[cfg(feature = "dtype-date")]
        A::Date(v) => P::Int32(v),
        #[cfg(feature = "dtype-datetime")]
        A::Datetime(v, _, _) | A::DatetimeOwned(v, _, _) => P::Int64(v),
        #[cfg(feature = "dtype-duration")]
        A::Duration(v, _) => P::Int64(v),
        #[cfg(feature = "dtype-time")]
        A::Time(v) => P::Int64(v),

        A::Float32(v) => P::Float32(v),
        A::Float64(v) => P::Float64(v),

        // @TODO: Cast to string
        #[cfg(feature = "dtype-categorical")]
        A::Categorical(_, _) | A::CategoricalOwned(_, _) | A::Enum(_, _) | A::EnumOwned(_, _) => {
            return None;
        },

        A::String(v) => P::String(v.into()),
        A::StringOwned(v) => P::String(v.as_str().into()),
        A::Binary(v) => P::Binary(v.into()),
        A::BinaryOwned(v) => P::Binary(v.into()),
        // Anything else (nested, decimal, object, ...) is unsupported.
        _ => return None,
    })
}
163
/// Filters `df` in place with the given predicate, if any.
///
/// A frame without columns is left untouched. `parallel` selects between the
/// parallel and sequential filter kernels.
#[cfg(any(feature = "parquet", feature = "ipc"))]
pub fn apply_predicate(
    df: &mut DataFrame,
    predicate: Option<&dyn PhysicalIoExpr>,
    parallel: bool,
) -> PolarsResult<()> {
    let Some(predicate) = predicate else {
        return Ok(());
    };
    if df.columns().is_empty() {
        return Ok(());
    }

    let s = predicate.evaluate_io(df)?;
    let mask = s.bool().expect("filter predicates was not of type boolean");

    *df = if parallel {
        df.filter(mask)?
    } else {
        df.filter_seq(mask)?
    };
    Ok(())
}
182
/// Statistics of the values in a column.
///
/// The following statistics are tracked for each row group:
/// - Null count
/// - Minimum value
/// - Maximum value
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ColumnStats {
    // Name and logical dtype of the column.
    field: Field,
    // Each Series contains the stats for each row group, i.e. element `i`
    // holds the statistic of row group `i`. `None` = statistic unavailable.
    null_count: Option<Series>,
    min_value: Option<Series>,
    max_value: Option<Series>,
}
198
199impl ColumnStats {
200    /// Constructs a new [`ColumnStats`].
201    pub fn new(
202        field: Field,
203        null_count: Option<Series>,
204        min_value: Option<Series>,
205        max_value: Option<Series>,
206    ) -> Self {
207        Self {
208            field,
209            null_count,
210            min_value,
211            max_value,
212        }
213    }
214
215    /// Constructs a new [`ColumnStats`] with only the [`Field`] information and no statistics.
216    pub fn from_field(field: Field) -> Self {
217        Self {
218            field,
219            null_count: None,
220            min_value: None,
221            max_value: None,
222        }
223    }
224
225    /// Constructs a new [`ColumnStats`] from a single-value Series.
226    pub fn from_column_literal(s: Series) -> Self {
227        debug_assert_eq!(s.len(), 1);
228        Self {
229            field: s.field().into_owned(),
230            null_count: None,
231            min_value: Some(s.clone()),
232            max_value: Some(s),
233        }
234    }
235
236    pub fn field_name(&self) -> &PlSmallStr {
237        self.field.name()
238    }
239
240    /// Returns the [`DataType`] of the column.
241    pub fn dtype(&self) -> &DataType {
242        self.field.dtype()
243    }
244
245    /// Returns the null count of each row group of the column.
246    pub fn get_null_count_state(&self) -> Option<&Series> {
247        self.null_count.as_ref()
248    }
249
250    /// Returns the minimum value of each row group of the column.
251    pub fn get_min_state(&self) -> Option<&Series> {
252        self.min_value.as_ref()
253    }
254
255    /// Returns the maximum value of each row group of the column.
256    pub fn get_max_state(&self) -> Option<&Series> {
257        self.max_value.as_ref()
258    }
259
260    /// Returns the null count of the column.
261    pub fn null_count(&self) -> Option<usize> {
262        match self.dtype() {
263            #[cfg(feature = "dtype-struct")]
264            DataType::Struct(_) => None,
265            _ => {
266                let s = self.get_null_count_state()?;
267                // if all null, there are no statistics.
268                if s.null_count() != s.len() {
269                    s.sum().ok()
270                } else {
271                    None
272                }
273            },
274        }
275    }
276
277    /// Returns the minimum and maximum values of the column as a single [`Series`].
278    pub fn to_min_max(&self) -> Option<Series> {
279        let min_val = self.get_min_state()?;
280        let max_val = self.get_max_state()?;
281        let dtype = self.dtype();
282
283        if !use_min_max(dtype) {
284            return None;
285        }
286
287        let mut min_max_values = min_val.clone();
288        min_max_values.append(max_val).unwrap();
289        if min_max_values.null_count() > 0 {
290            None
291        } else {
292            Some(min_max_values)
293        }
294    }
295
296    /// Returns the minimum value of the column as a single-value [`Series`].
297    ///
298    /// Returns `None` if no maximum value is available.
299    pub fn to_min(&self) -> Option<&Series> {
300        // @scalar-opt
301        let min_val = self.min_value.as_ref()?;
302        let dtype = min_val.dtype();
303
304        if !use_min_max(dtype) || min_val.len() != 1 {
305            return None;
306        }
307
308        if min_val.null_count() > 0 {
309            None
310        } else {
311            Some(min_val)
312        }
313    }
314
315    /// Returns the maximum value of the column as a single-value [`Series`].
316    ///
317    /// Returns `None` if no maximum value is available.
318    pub fn to_max(&self) -> Option<&Series> {
319        // @scalar-opt
320        let max_val = self.max_value.as_ref()?;
321        let dtype = max_val.dtype();
322
323        if !use_min_max(dtype) || max_val.len() != 1 {
324            return None;
325        }
326
327        if max_val.null_count() > 0 {
328            None
329        } else {
330            Some(max_val)
331        }
332    }
333}
334
335/// Returns whether the [`DataType`] supports minimum/maximum operations.
336fn use_min_max(dtype: &DataType) -> bool {
337    dtype.is_primitive_numeric()
338        || dtype.is_temporal()
339        || matches!(
340            dtype,
341            DataType::String | DataType::Binary | DataType::Boolean
342        )
343}
344
/// Statistics for a single column of one batch.
pub struct ColumnStatistics {
    /// Logical dtype of the column.
    pub dtype: DataType,
    /// Minimum value in the batch; `AnyValue::Null` when unknown.
    pub min: AnyValue<'static>,
    /// Maximum value in the batch; `AnyValue::Null` when unknown.
    pub max: AnyValue<'static>,
    /// Number of nulls in the batch; `None` when unknown.
    pub null_count: Option<IdxSize>,
}
351
/// A predicate that, given batch statistics, decides whether a whole batch
/// can be skipped without reading it.
pub trait SkipBatchPredicate: Send + Sync {
    /// The schema of the underlying data, used to look up column dtypes.
    fn schema(&self) -> &SchemaRef;

    /// Returns whether a batch with the given statistics can be skipped.
    ///
    /// Builds a one-row statistics [`DataFrame`] with the columns:
    /// - `len`: the batch size,
    /// - `{col}_min`, `{col}_max`, `{col}_nc`: min, max and null count for
    ///   each live column (null scalars when no statistics are available),
    ///
    /// and delegates to [`Self::evaluate_with_stat_df`].
    fn can_skip_batch(
        &self,
        batch_size: IdxSize,
        live_columns: &PlIndexSet<PlSmallStr>,
        mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
    ) -> PolarsResult<bool> {
        // 1 `len` column + 3 statistic columns per live column.
        let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);

        columns.push(Column::new_scalar(
            PlSmallStr::from_static("len"),
            Scalar::new(IDX_DTYPE, batch_size.into()),
            1,
        ));

        for col in live_columns.iter() {
            let dtype = self.schema().get(col).unwrap();
            let (min, max, nc) = match statistics.swap_remove(col) {
                // No statistics for this column: encode all three as nulls.
                None => (
                    Scalar::null(dtype.clone()),
                    Scalar::null(dtype.clone()),
                    Scalar::null(IDX_DTYPE),
                ),
                Some(stat) => (
                    Scalar::new(dtype.clone(), stat.min),
                    Scalar::new(dtype.clone(), stat.max),
                    // Unknown null count is encoded as a null scalar.
                    Scalar::new(
                        IDX_DTYPE,
                        stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
                    ),
                ),
            };
            columns.extend([
                Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
            ]);
        }

        // SAFETY:
        // * Each column is length = 1
        // * We have an IndexSet, so each column name is unique
        let df = unsafe { DataFrame::new_unchecked(1, columns) };
        Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
    }

    /// Evaluates this predicate on a statistics [`DataFrame`] laid out as
    /// described in [`Self::can_skip_batch`]; bit 0 of the result indicates
    /// whether the batch can be skipped.
    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
}
401
/// A set of single-column predicates, keyed by column name.
#[derive(Clone)]
pub struct ColumnPredicates {
    /// Maps each column name to its physical predicate and, when derivable,
    /// a specialized fast-path form.
    pub predicates:
        PlHashMap<PlSmallStr, (Arc<dyn PhysicalIoExpr>, Option<SpecializedColumnPredicate>)>,
    /// NOTE(review): presumably means the per-column predicates together are
    /// equivalent to the full predicate — inferred from the name, confirm.
    pub is_sumwise_complete: bool,
}
408
409// I want to be explicit here.
410#[allow(clippy::derivable_impls)]
411impl Default for ColumnPredicates {
412    fn default() -> Self {
413        Self {
414            predicates: PlHashMap::default(),
415            is_sumwise_complete: false,
416        }
417    }
418}
419
/// Wraps a predicate `child` together with columns of known constant values
/// that are appended to the input [`DataFrame`] before `child` is evaluated.
pub struct PhysicalExprWithConstCols<T> {
    // (name, value) pairs, broadcast to the frame's height at evaluation time.
    constants: Vec<(PlSmallStr, Scalar)>,
    child: T,
}
424
425impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
426    fn schema(&self) -> &SchemaRef {
427        self.child.schema()
428    }
429
430    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
431        let mut df = df.clone();
432        for (name, scalar) in self.constants.iter() {
433            df.with_column(Column::new_scalar(
434                name.clone(),
435                scalar.clone(),
436                df.height(),
437            ))?;
438        }
439        self.child.evaluate_with_stat_df(&df)
440    }
441}
442
443impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
444    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
445        let mut df = df.clone();
446        for (name, scalar) in self.constants.iter() {
447            df.with_column(Column::new_scalar(
448                name.clone(),
449                scalar.clone(),
450                df.height(),
451            ))?;
452        }
453
454        self.child.evaluate_io(&df)
455    }
456}
457
/// A fully assembled predicate for an I/O scan, split into parts that can be
/// pushed down to different stages of the reader.
#[derive(Clone)]
pub struct ScanIOPredicate {
    /// The full predicate, evaluated on materialized rows.
    pub predicate: Arc<dyn PhysicalIoExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate that gets given statistics and evaluates whether a batch can be skipped.
    pub skip_batch_predicate: Option<Arc<dyn SkipBatchPredicate>>,

    /// Predicates that apply to individual columns.
    /// (The previous comment here was a copy-paste of the one above.)
    pub column_predicates: Arc<ColumnPredicates>,

    /// Predicate parts only referring to hive columns.
    pub hive_predicate: Option<Arc<dyn PhysicalIoExpr>>,

    /// NOTE(review): presumably indicates `hive_predicate` covers the entire
    /// predicate — inferred from the name, confirm against callers.
    pub hive_predicate_is_full_predicate: bool,
}
476
impl ScanIOPredicate {
    /// Binds columns with externally known constant values (e.g. hive
    /// partition values) into every sub-predicate, so they no longer need to
    /// be provided by the reader.
    ///
    /// Removes the constant columns from `live_columns` and from the
    /// per-column predicates, and wraps `predicate` / `skip_batch_predicate`
    /// in [`PhysicalExprWithConstCols`] so the constants are injected into the
    /// frame before evaluation.
    pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
        if constant_columns.is_empty() {
            return;
        }

        // Constant columns no longer have to be read from the file.
        let mut live_columns = self.live_columns.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            live_columns.swap_remove(c);
        }
        self.live_columns = Arc::new(live_columns);

        if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
            // The skip-batch predicate sees statistics columns, so each
            // constant `c` expands to `{c}_min`, `{c}_max` and `{c}_nc`
            // (matching the layout built by `SkipBatchPredicate::can_skip_batch`).
            let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
            for (c, v) in constant_columns.iter() {
                sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
                sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
                // A non-null constant has a null count of 0; for a null
                // constant the null count is encoded as a null scalar.
                let nc = if v.is_null() {
                    AnyValue::Null
                } else {
                    (0 as IdxSize).into()
                };
                sbp_constant_columns
                    .push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
            }
            self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
                constants: sbp_constant_columns,
                child: skip_batch_predicate,
            }));
        }

        // Per-column predicates on constant columns are dropped entirely.
        let mut column_predicates = self.column_predicates.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            column_predicates.predicates.remove(c);
        }
        self.column_predicates = Arc::new(column_predicates);

        self.predicate = Arc::new(PhysicalExprWithConstCols {
            constants: constant_columns,
            child: self.predicate.clone(),
        });
    }
}
520
521impl fmt::Debug for ScanIOPredicate {
522    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
523        f.write_str("scan_io_predicate")
524    }
525}