polars_io/predicates.rs

use std::fmt;

use arrow::array::Array;
use arrow::bitmap::{Bitmap, MutableBitmap};
use polars_core::prelude::*;
#[cfg(feature = "parquet")]
use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, ParquetScalarRange};
use polars_utils::format_pl_smallstr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

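/// A physical expression that can be evaluated against a [`DataFrame`] during
/// I/O to produce a boolean filter mask.
///
/// A minimal sketch of an implementor (the `KeepAll` type here is
/// hypothetical, not part of this crate):
///
/// ```ignore
/// struct KeepAll;
///
/// impl PhysicalIoExpr for KeepAll {
///     fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
///         // An all-true mask keeps every row of the scanned frame.
///         Ok(BooleanChunked::full("mask".into(), true, df.height()).into_series())
///     }
/// }
/// ```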
pub trait PhysicalIoExpr: Send + Sync {
    /// Takes a [`DataFrame`] and produces a boolean [`Series`] that serves
    /// as a predicate mask.
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series>;

    /// Returns a [`StatsEvaluator`] that inspects statistics to determine
    /// whether a file has to be read (`true`) or can be skipped (`false`).
    fn as_stats_evaluator(&self) -> Option<&dyn StatsEvaluator> {
        None
    }
}

#[derive(Debug, Clone)]
pub enum SpecializedColumnPredicateExpr {
    Eq(Scalar),
    EqMissing(Scalar),
}

#[derive(Clone)]
pub struct ColumnPredicateExpr {
    column_name: PlSmallStr,
    dtype: DataType,
    specialized: Option<SpecializedColumnPredicateExpr>,
    expr: Arc<dyn PhysicalIoExpr>,
}

impl ColumnPredicateExpr {
    pub fn new(
        column_name: PlSmallStr,
        dtype: DataType,
        expr: Arc<dyn PhysicalIoExpr>,
        specialized: Option<SpecializedColumnPredicateExpr>,
    ) -> Self {
        Self {
            column_name,
            dtype,
            specialized,
            expr,
        }
    }

    pub fn is_eq_scalar(&self) -> bool {
        self.to_eq_scalar().is_some()
    }

    pub fn to_eq_scalar(&self) -> Option<&Scalar> {
        match &self.specialized {
            // `Eq` with a null scalar can never match, so only non-null
            // scalars are returned. `EqMissing` also matches missing values,
            // so a null scalar is meaningful there.
            Some(SpecializedColumnPredicateExpr::Eq(sc)) if !sc.is_null() => Some(sc),
            Some(SpecializedColumnPredicateExpr::EqMissing(sc)) => Some(sc),
            _ => None,
        }
    }
}

#[cfg(feature = "parquet")]
impl ParquetColumnExpr for ColumnPredicateExpr {
    fn evaluate_mut(&self, values: &dyn Array, bm: &mut MutableBitmap) {
        // We should never evaluate nulls with this, so the array must be
        // all-valid.
        assert!(values.validity().is_none_or(|v| v.unset_bits() == 0));

        // @NOTE: This is okay because we don't have Enums, Categoricals, or Decimals.
        let series =
            Series::from_arrow_chunks(self.column_name.clone(), vec![values.to_boxed()]).unwrap();
        let column = series.into_column();
        // SAFETY: A single column always forms a valid DataFrame.
        let df = unsafe { DataFrame::new_no_checks(values.len(), vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        // Fold the mask's own validity into its values: a null in the mask
        // counts as `false`.
        bm.reserve(true_mask.len());
        for chunk in true_mask.downcast_iter() {
            match chunk.validity() {
                None => bm.extend(chunk.values()),
                Some(v) => bm.extend(chunk.values() & v),
            }
        }
    }

    fn evaluate_null(&self) -> bool {
        let column = Column::full_null(self.column_name.clone(), 1, &self.dtype);
        // SAFETY: A single column of height 1 always forms a valid DataFrame.
        let df = unsafe { DataFrame::new_no_checks(1, vec![column]) };

        // @TODO: Probably these unwraps should be removed.
        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        true_mask.get(0).unwrap_or(false)
    }

    fn to_equals_scalar(&self) -> Option<ParquetScalar> {
        self.to_eq_scalar()
            .and_then(|s| cast_to_parquet_scalar(s.clone()))
    }

    fn to_range_scalar(&self) -> Option<ParquetScalarRange> {
        None
    }
}

#[cfg(feature = "parquet")]
fn cast_to_parquet_scalar(scalar: Scalar) -> Option<ParquetScalar> {
    use {AnyValue as A, ParquetScalar as P};

    Some(match scalar.into_value() {
        A::Null => P::Null,
        A::Boolean(v) => P::Boolean(v),

        A::UInt8(v) => P::UInt8(v),
        A::UInt16(v) => P::UInt16(v),
        A::UInt32(v) => P::UInt32(v),
        A::UInt64(v) => P::UInt64(v),

        A::Int8(v) => P::Int8(v),
        A::Int16(v) => P::Int16(v),
        A::Int32(v) => P::Int32(v),
        A::Int64(v) => P::Int64(v),

        #[cfg(feature = "dtype-date")]
        A::Date(v) => P::Int32(v),
        #[cfg(feature = "dtype-datetime")]
        A::Datetime(v, _, _) | A::DatetimeOwned(v, _, _) => P::Int64(v),
        #[cfg(feature = "dtype-duration")]
        A::Duration(v, _) => P::Int64(v),
        #[cfg(feature = "dtype-time")]
        A::Time(v) => P::Int64(v),

        A::Float32(v) => P::Float32(v),
        A::Float64(v) => P::Float64(v),

        // @TODO: Cast to string.
        #[cfg(feature = "dtype-categorical")]
        A::Categorical(_, _, _)
        | A::CategoricalOwned(_, _, _)
        | A::Enum(_, _, _)
        | A::EnumOwned(_, _, _) => return None,

        A::String(v) => P::String(v.into()),
        A::StringOwned(v) => P::String(v.as_str().into()),
        A::Binary(v) => P::Binary(v.into()),
        A::BinaryOwned(v) => P::Binary(v.into()),
        _ => return None,
    })
}

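/// Evaluates batch statistics to decide whether a batch needs to be read.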
pub trait StatsEvaluator {
    fn should_read(&self, stats: &BatchStats) -> PolarsResult<bool>;
}

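/// Filters `df` in place with the given predicate, if any.
///
/// A minimal usage sketch (`df` and `expr` are assumed to be in scope):
///
/// ```ignore
/// // `expr` is an Arc<dyn PhysicalIoExpr>; `parallel = true` uses the
/// // parallel filter path.
/// apply_predicate(&mut df, Some(expr.as_ref()), true)?;
/// ```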
#[cfg(any(feature = "parquet", feature = "ipc"))]
pub fn apply_predicate(
    df: &mut DataFrame,
    predicate: Option<&dyn PhysicalIoExpr>,
    parallel: bool,
) -> PolarsResult<()> {
    if let (Some(predicate), false) = (&predicate, df.get_columns().is_empty()) {
        let s = predicate.evaluate_io(df)?;
        let mask = s.bool().expect("filter predicate was not of type boolean");

        if parallel {
            *df = df.filter(mask)?;
        } else {
            *df = df._filter_seq(mask)?;
        }
    }
    Ok(())
}

/// Statistics of the values in a column.
///
/// The following statistics are tracked for each row group:
/// - Null count
/// - Minimum value
/// - Maximum value
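///
/// A sketch of constructing stats for two row groups by hand (the values are
/// illustrative only):
///
/// ```ignore
/// let stats = ColumnStats::new(
///     Field::new("a".into(), DataType::Int64),
///     Some(Series::new("null_count".into(), &[0u32, 2])), // nulls per row group
///     Some(Series::new("min".into(), &[1i64, 5])),        // minima per row group
///     Some(Series::new("max".into(), &[4i64, 9])),        // maxima per row group
/// );
/// assert_eq!(stats.null_count(), Some(2));
/// ```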
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ColumnStats {
    field: Field,
    // Each `Series` holds one entry per row group.
    null_count: Option<Series>,
    min_value: Option<Series>,
    max_value: Option<Series>,
}

impl ColumnStats {
    /// Constructs a new [`ColumnStats`].
    pub fn new(
        field: Field,
        null_count: Option<Series>,
        min_value: Option<Series>,
        max_value: Option<Series>,
    ) -> Self {
        Self {
            field,
            null_count,
            min_value,
            max_value,
        }
    }

    /// Constructs a new [`ColumnStats`] with only the [`Field`] information and no statistics.
    pub fn from_field(field: Field) -> Self {
        Self {
            field,
            null_count: None,
            min_value: None,
            max_value: None,
        }
    }

    /// Constructs a new [`ColumnStats`] from a single-value [`Series`].
    pub fn from_column_literal(s: Series) -> Self {
        debug_assert_eq!(s.len(), 1);
        Self {
            field: s.field().into_owned(),
            null_count: None,
            min_value: Some(s.clone()),
            max_value: Some(s),
        }
    }

    /// Returns the name of the column.
    pub fn field_name(&self) -> &PlSmallStr {
        self.field.name()
    }

    /// Returns the [`DataType`] of the column.
    pub fn dtype(&self) -> &DataType {
        self.field.dtype()
    }

    /// Returns the null count of each row group of the column.
    pub fn get_null_count_state(&self) -> Option<&Series> {
        self.null_count.as_ref()
    }

    /// Returns the minimum value of each row group of the column.
    pub fn get_min_state(&self) -> Option<&Series> {
        self.min_value.as_ref()
    }

    /// Returns the maximum value of each row group of the column.
    pub fn get_max_state(&self) -> Option<&Series> {
        self.max_value.as_ref()
    }

    /// Returns the total null count of the column across all row groups.
    pub fn null_count(&self) -> Option<usize> {
        match self.dtype() {
            #[cfg(feature = "dtype-struct")]
            DataType::Struct(_) => None,
            _ => {
                let s = self.get_null_count_state()?;
                // If all entries are null, there are no statistics.
                if s.null_count() != s.len() {
                    s.sum().ok()
                } else {
                    None
                }
            },
        }
    }

    /// Returns the minimum and maximum values of the column as a single [`Series`].
    pub fn to_min_max(&self) -> Option<Series> {
        let min_val = self.get_min_state()?;
        let max_val = self.get_max_state()?;
        let dtype = self.dtype();

        if !use_min_max(dtype) {
            return None;
        }

        let mut min_max_values = min_val.clone();
        min_max_values.append(max_val).unwrap();
        if min_max_values.null_count() > 0 {
            None
        } else {
            Some(min_max_values)
        }
    }

    /// Returns the minimum value of the column as a single-value [`Series`].
    ///
    /// Returns `None` if no minimum value is available.
    pub fn to_min(&self) -> Option<&Series> {
        // @scalar-opt
        let min_val = self.min_value.as_ref()?;
        let dtype = min_val.dtype();

        if !use_min_max(dtype) || min_val.len() != 1 {
            return None;
        }

        if min_val.null_count() > 0 {
            None
        } else {
            Some(min_val)
        }
    }

    /// Returns the maximum value of the column as a single-value [`Series`].
    ///
    /// Returns `None` if no maximum value is available.
    pub fn to_max(&self) -> Option<&Series> {
        // @scalar-opt
        let max_val = self.max_value.as_ref()?;
        let dtype = max_val.dtype();

        if !use_min_max(dtype) || max_val.len() != 1 {
            return None;
        }

        if max_val.null_count() > 0 {
            None
        } else {
            Some(max_val)
        }
    }
}

/// Returns whether the [`DataType`] supports minimum/maximum operations.
fn use_min_max(dtype: &DataType) -> bool {
    dtype.is_primitive_numeric()
        || dtype.is_temporal()
        || matches!(
            dtype,
            DataType::String | DataType::Binary | DataType::Boolean
        )
}

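/// Statistics for a single column of a single batch, in the lowered form
/// consumed by [`SkipBatchPredicate::can_skip_batch`].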
pub struct ColumnStatistics {
    pub dtype: DataType,
    pub min: AnyValue<'static>,
    pub max: AnyValue<'static>,
    pub null_count: Option<IdxSize>,
}

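/// A predicate that decides, from statistics alone, whether an entire batch
/// (e.g. a parquet row group) can be skipped without reading its data.
///
/// The provided `can_skip_batch` lowers the statistics into a one-row
/// [`DataFrame`] with a `len` column plus `{col}_min`, `{col}_max`, and
/// `{col}_nc` (null count) columns for every live column, then delegates to
/// [`Self::evaluate_with_stat_df`]. For a single live column `a`, that frame
/// has the shape:
///
/// ```text
/// ┌─────┬───────┬───────┬──────┐
/// │ len ┆ a_min ┆ a_max ┆ a_nc │
/// └─────┴───────┴───────┴──────┘
/// ```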
pub trait SkipBatchPredicate: Send + Sync {
    /// Returns the schema that the batch statistics are resolved against.
    fn schema(&self) -> &SchemaRef;

    /// Returns whether a batch with the given size and statistics can be
    /// skipped entirely.
    fn can_skip_batch(
        &self,
        batch_size: IdxSize,
        live_columns: &PlIndexSet<PlSmallStr>,
        mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
    ) -> PolarsResult<bool> {
        // One `len` column plus `min`, `max`, and `nc` columns per live column.
        let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);

        columns.push(Column::new_scalar(
            PlSmallStr::from_static("len"),
            Scalar::new(IDX_DTYPE, batch_size.into()),
            1,
        ));

        for col in live_columns.iter() {
            let dtype = self.schema().get(col).unwrap();
            // Missing statistics are lowered as all-null scalars.
            let (min, max, nc) = match statistics.swap_remove(col) {
                None => (
                    Scalar::null(dtype.clone()),
                    Scalar::null(dtype.clone()),
                    Scalar::null(IDX_DTYPE),
                ),
                Some(stat) => (
                    Scalar::new(dtype.clone(), stat.min),
                    Scalar::new(dtype.clone(), stat.max),
                    Scalar::new(
                        IDX_DTYPE,
                        stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
                    ),
                ),
            };
            columns.extend([
                Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
            ]);
        }

        // SAFETY:
        // * Each column has length 1.
        // * We have an IndexSet, so each column name is unique.
        let df = unsafe { DataFrame::new_no_checks(1, columns) };
        Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
    }

    /// Evaluates the predicate on a statistics [`DataFrame`] as built by
    /// [`Self::can_skip_batch`], returning one skip-bit per row.
    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
}

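/// The scan predicate decomposed into per-column predicates, keyed by column
/// name.
///
/// Each entry holds the physical expression for that column and, when
/// available, a specialized form (e.g. an equality with a scalar) that
/// readers can push down further.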
#[derive(Clone)]
pub struct ColumnPredicates {
    pub predicates: PlHashMap<
        PlSmallStr,
        (
            Arc<dyn PhysicalIoExpr>,
            Option<SpecializedColumnPredicateExpr>,
        ),
    >,
    pub is_sumwise_complete: bool,
}

// I want to be explicit here.
#[allow(clippy::derivable_impls)]
impl Default for ColumnPredicates {
    fn default() -> Self {
        Self {
            predicates: PlHashMap::default(),
            is_sumwise_complete: false,
        }
    }
}

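/// Wraps a child expression and a set of constant columns that are
/// materialized into the [`DataFrame`] (at the frame's height) before the
/// child is evaluated.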
pub struct PhysicalExprWithConstCols<T> {
    constants: Vec<(PlSmallStr, Scalar)>,
    child: T,
}

impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
    fn schema(&self) -> &SchemaRef {
        self.child.schema()
    }

    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
        // Materialize each constant as a scalar column before delegating.
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }
        self.child.evaluate_with_stat_df(&df)
    }
}

impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
        // Materialize each constant as a scalar column before delegating.
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate_io(&df)
    }
}

#[derive(Clone)]
pub struct ScanIOPredicate {
    pub predicate: Arc<dyn PhysicalIoExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// A predicate that gets given statistics and evaluates whether a batch can be skipped.
    pub skip_batch_predicate: Option<Arc<dyn SkipBatchPredicate>>,

    /// The predicate decomposed into per-column predicates, where possible.
    pub column_predicates: Arc<ColumnPredicates>,
}

impl ScanIOPredicate {
    pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
        if constant_columns.is_empty() {
            return;
        }

        // The constant columns are no longer live: their values are known.
        let mut live_columns = self.live_columns.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            live_columns.swap_remove(c);
        }
        self.live_columns = Arc::new(live_columns);

        if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
            // Lower each constant into the `{c}_min` / `{c}_max` / `{c}_nc`
            // statistics columns that the skip-batch predicate expects.
            let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
            for (c, v) in constant_columns.iter() {
                sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
                sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
                let nc = if v.is_null() {
                    AnyValue::Null
                } else {
                    (0 as IdxSize).into()
                };
                sbp_constant_columns
                    .push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
            }
            self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
                constants: sbp_constant_columns,
                child: skip_batch_predicate,
            }));
        }

        let mut column_predicates = self.column_predicates.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            column_predicates.predicates.remove(c);
        }
        self.column_predicates = Arc::new(column_predicates);

        // Provide the constants to the main predicate as injected columns.
        self.predicate = Arc::new(PhysicalExprWithConstCols {
            constants: constant_columns,
            child: self.predicate.clone(),
        });
    }
}

impl fmt::Debug for ScanIOPredicate {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("scan_io_predicate")
    }
}

/// A collection of column stats with a known schema.
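///
/// A construction sketch, assuming a single column `a` (values illustrative):
///
/// ```ignore
/// let mut schema = Schema::default();
/// schema.with_column("a".into(), DataType::Int64);
/// let schema: SchemaRef = Arc::new(schema);
///
/// // One `ColumnStats` per schema column, in schema order.
/// let stats = vec![ColumnStats::from_field(Field::new("a".into(), DataType::Int64))];
/// let batch = BatchStats::new(schema, stats, Some(1024));
/// assert_eq!(batch.num_rows(), Some(1024));
/// ```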
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct BatchStats {
    schema: SchemaRef,
    stats: Vec<ColumnStats>,
    // This might not be available, e.g. when pruning hive partitions.
    num_rows: Option<usize>,
}

impl Default for BatchStats {
    fn default() -> Self {
        Self {
            schema: Arc::new(Schema::default()),
            stats: Vec::new(),
            num_rows: None,
        }
    }
}

impl BatchStats {
    /// Constructs a new [`BatchStats`].
    ///
    /// The `stats` should match the order of the `schema`.
    pub fn new(schema: SchemaRef, stats: Vec<ColumnStats>, num_rows: Option<usize>) -> Self {
        Self {
            schema,
            stats,
            num_rows,
        }
    }

    /// Returns the [`Schema`] of the batch.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Returns the [`ColumnStats`] of all columns in the batch.
    pub fn column_stats(&self) -> &[ColumnStats] {
        self.stats.as_ref()
    }

    /// Returns the [`ColumnStats`] of a single column in the batch.
    ///
    /// Returns an `Err` if the column is not present in the schema.
    pub fn get_stats(&self, column: &str) -> PolarsResult<&ColumnStats> {
        self.schema.try_index_of(column).map(|i| &self.stats[i])
    }

    /// Returns the number of rows in the batch.
    ///
    /// Returns `None` if the number of rows is unknown.
    pub fn num_rows(&self) -> Option<usize> {
        self.num_rows
    }

    /// Replaces the schema of the batch.
    pub fn with_schema(&mut self, schema: SchemaRef) {
        self.schema = schema;
    }

    /// Keeps only the stats at the given indices, in the given order.
    pub fn take_indices(&mut self, indices: &[usize]) {
        self.stats = indices.iter().map(|&i| self.stats[i].clone()).collect();
    }
}