use std::fmt;

use arrow::array::Array;
use arrow::bitmap::{Bitmap, MutableBitmap};
use polars_core::prelude::*;
#[cfg(feature = "parquet")]
use polars_parquet::read::expr::{ParquetColumnExpr, ParquetScalar, ParquetScalarRange};
use polars_utils::format_pl_smallstr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

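/// A physical expression that can be evaluated against a [`DataFrame`] during IO,
/// e.g. to filter rows while scanning a file.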
pub trait PhysicalIoExpr: Send + Sync {
    /// Take a [`DataFrame`] and produce a boolean [`Series`] that can be used
    /// as a filter mask.
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series>;

    /// Return a [`StatsEvaluator`] that can decide from batch statistics
    /// whether a file (or part of it) needs to be read at all.
    fn as_stats_evaluator(&self) -> Option<&dyn StatsEvaluator> {
        None
    }
}

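/// Specialized (equality) forms of a column predicate that readers can handle
/// without evaluating the full expression.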
#[derive(Debug, Clone)]
pub enum SpecializedColumnPredicateExpr {
    Eq(Scalar),
    EqMissing(Scalar),
}

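/// A predicate that refers to a single column, optionally with a specialized
/// (equality) form that readers can use as a fast path.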
#[derive(Clone)]
pub struct ColumnPredicateExpr {
    column_name: PlSmallStr,
    dtype: DataType,
    specialized: Option<SpecializedColumnPredicateExpr>,
    expr: Arc<dyn PhysicalIoExpr>,
}

impl ColumnPredicateExpr {
    pub fn new(
        column_name: PlSmallStr,
        dtype: DataType,
        expr: Arc<dyn PhysicalIoExpr>,
        specialized: Option<SpecializedColumnPredicateExpr>,
    ) -> Self {
        Self {
            column_name,
            dtype,
            specialized,
            expr,
        }
    }

    pub fn is_eq_scalar(&self) -> bool {
        self.to_eq_scalar().is_some()
    }

    /// Returns the scalar this predicate compares against if it is a plain
    /// equality. An `Eq` against a null scalar is excluded (it can never
    /// match), while `EqMissing` (null-safe equality) always qualifies.
    pub fn to_eq_scalar(&self) -> Option<&Scalar> {
        match &self.specialized {
            Some(SpecializedColumnPredicateExpr::Eq(sc)) if !sc.is_null() => Some(sc),
            Some(SpecializedColumnPredicateExpr::EqMissing(sc)) => Some(sc),
            _ => None,
        }
    }
}

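// Allows the parquet reader to evaluate this column predicate directly on
// decoded leaf values (predicate pushdown into the decoder).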
#[cfg(feature = "parquet")]
impl ParquetColumnExpr for ColumnPredicateExpr {
    fn evaluate_mut(&self, values: &dyn Array, bm: &mut MutableBitmap) {
        // Nulls are handled separately through `evaluate_null`, so the values
        // evaluated here should not contain any.
        assert!(values.validity().is_none_or(|v| v.unset_bits() == 0));

        // Wrap the arrow values in a single-column DataFrame so the physical
        // expression can be evaluated against it.
        let series =
            Series::from_arrow_chunks(self.column_name.clone(), vec![values.to_boxed()]).unwrap();
        let column = series.into_column();
        let df = unsafe { DataFrame::new_no_checks(values.len(), vec![column]) };

        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        // Append the boolean mask to the bitmap; missing mask values count as `false`.
        bm.reserve(true_mask.len());
        for chunk in true_mask.downcast_iter() {
            match chunk.validity() {
                None => bm.extend_from_bitmap(chunk.values()),
                Some(v) => bm.extend_from_bitmap(&(chunk.values() & v)),
            }
        }
    }

    fn evaluate_null(&self) -> bool {
        // Evaluate the predicate on a single null value to decide whether
        // null entries satisfy it.
        let column = Column::full_null(self.column_name.clone(), 1, &self.dtype);
        let df = unsafe { DataFrame::new_no_checks(1, vec![column]) };

        let true_mask = self.expr.evaluate_io(&df).unwrap();
        let true_mask = true_mask.bool().unwrap();

        true_mask.get(0).unwrap_or(false)
    }

    fn to_equals_scalar(&self) -> Option<ParquetScalar> {
        self.to_eq_scalar()
            .and_then(|s| cast_to_parquet_scalar(s.clone()))
    }

    fn to_range_scalar(&self) -> Option<ParquetScalarRange> {
        None
    }
}

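/// Converts a Polars [`Scalar`] into a [`ParquetScalar`], returning `None` for
/// values that have no direct parquet scalar representation.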
#[cfg(feature = "parquet")]
fn cast_to_parquet_scalar(scalar: Scalar) -> Option<ParquetScalar> {
    use {AnyValue as A, ParquetScalar as P};

    Some(match scalar.into_value() {
        A::Null => P::Null,
        A::Boolean(v) => P::Boolean(v),

        A::UInt8(v) => P::UInt8(v),
        A::UInt16(v) => P::UInt16(v),
        A::UInt32(v) => P::UInt32(v),
        A::UInt64(v) => P::UInt64(v),

        A::Int8(v) => P::Int8(v),
        A::Int16(v) => P::Int16(v),
        A::Int32(v) => P::Int32(v),
        A::Int64(v) => P::Int64(v),

        // Temporal types map to their physical integer representation.
        #[cfg(feature = "dtype-date")]
        A::Date(v) => P::Int32(v),
        #[cfg(feature = "dtype-datetime")]
        A::Datetime(v, _, _) | A::DatetimeOwned(v, _, _) => P::Int64(v),
        #[cfg(feature = "dtype-duration")]
        A::Duration(v, _) => P::Int64(v),
        #[cfg(feature = "dtype-time")]
        A::Time(v) => P::Int64(v),

        A::Float32(v) => P::Float32(v),
        A::Float64(v) => P::Float64(v),

        // Categoricals and enums carry a mapping that has no direct parquet
        // scalar representation.
        #[cfg(feature = "dtype-categorical")]
        A::Categorical(_, _, _)
        | A::CategoricalOwned(_, _, _)
        | A::Enum(_, _, _)
        | A::EnumOwned(_, _, _) => return None,

        A::String(v) => P::String(v.into()),
        A::StringOwned(v) => P::String(v.as_str().into()),
        A::Binary(v) => P::Binary(v.into()),
        A::BinaryOwned(v) => P::Binary(v.into()),
        _ => return None,
    })
}

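/// Decides from [`BatchStats`] whether a file or row group has to be read.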
pub trait StatsEvaluator {
    fn should_read(&self, stats: &BatchStats) -> PolarsResult<bool>;
}

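/// Applies a predicate to a [`DataFrame`] in place, keeping only the rows for
/// which the predicate mask is `true`.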
#[cfg(any(feature = "parquet", feature = "ipc"))]
pub fn apply_predicate(
    df: &mut DataFrame,
    predicate: Option<&dyn PhysicalIoExpr>,
    parallel: bool,
) -> PolarsResult<()> {
    if let (Some(predicate), false) = (&predicate, df.get_columns().is_empty()) {
        let s = predicate.evaluate_io(df)?;
        let mask = s.bool().expect("filter predicate was not of type boolean");

        if parallel {
            *df = df.filter(mask)?;
        } else {
            *df = df._filter_seq(mask)?;
        }
    }
    Ok(())
}

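/// Statistics gathered for a single column, typically a null count and
/// minimum/maximum values.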
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ColumnStats {
    field: Field,
    // Each `Series` holds the statistic per evaluated unit, e.g. per row group.
    null_count: Option<Series>,
    min_value: Option<Series>,
    max_value: Option<Series>,
}

impl ColumnStats {
    /// Constructs a new [`ColumnStats`].
    pub fn new(
        field: Field,
        null_count: Option<Series>,
        min_value: Option<Series>,
        max_value: Option<Series>,
    ) -> Self {
        Self {
            field,
            null_count,
            min_value,
            max_value,
        }
    }

    /// Constructs a new [`ColumnStats`] with only the field information and no statistics.
    pub fn from_field(field: Field) -> Self {
        Self {
            field,
            null_count: None,
            min_value: None,
            max_value: None,
        }
    }

    /// Constructs a new [`ColumnStats`] from a single-value [`Series`], used as
    /// both the minimum and the maximum.
    pub fn from_column_literal(s: Series) -> Self {
        debug_assert_eq!(s.len(), 1);
        Self {
            field: s.field().into_owned(),
            null_count: None,
            min_value: Some(s.clone()),
            max_value: Some(s),
        }
    }

    pub fn field_name(&self) -> &PlSmallStr {
        self.field.name()
    }

    /// Returns the [`DataType`] of the column.
    pub fn dtype(&self) -> &DataType {
        self.field.dtype()
    }

    /// Returns the null-count statistic, if available.
    pub fn get_null_count_state(&self) -> Option<&Series> {
        self.null_count.as_ref()
    }

    /// Returns the minimum-value statistic, if available.
    pub fn get_min_state(&self) -> Option<&Series> {
        self.min_value.as_ref()
    }

    /// Returns the maximum-value statistic, if available.
    pub fn get_max_state(&self) -> Option<&Series> {
        self.max_value.as_ref()
    }

    /// Returns the total null count, if known.
    pub fn null_count(&self) -> Option<usize> {
        match self.dtype() {
            #[cfg(feature = "dtype-struct")]
            DataType::Struct(_) => None,
            _ => {
                let s = self.get_null_count_state()?;
                // If every entry is missing, the total null count is unknown.
                if s.null_count() != s.len() {
                    s.sum().ok()
                } else {
                    None
                }
            },
        }
    }

    /// Returns the minimum and maximum values appended into a single [`Series`],
    /// or `None` if either is missing, contains nulls, or the dtype has no
    /// meaningful min/max.
    pub fn to_min_max(&self) -> Option<Series> {
        let min_val = self.get_min_state()?;
        let max_val = self.get_max_state()?;
        let dtype = self.dtype();

        if !use_min_max(dtype) {
            return None;
        }

        let mut min_max_values = min_val.clone();
        min_max_values.append(max_val).unwrap();
        if min_max_values.null_count() > 0 {
            None
        } else {
            Some(min_max_values)
        }
    }

    /// Returns the minimum value as a single-value [`Series`], if available.
    pub fn to_min(&self) -> Option<&Series> {
        let min_val = self.min_value.as_ref()?;
        let dtype = min_val.dtype();

        if !use_min_max(dtype) || min_val.len() != 1 {
            return None;
        }

        if min_val.null_count() > 0 {
            None
        } else {
            Some(min_val)
        }
    }

    /// Returns the maximum value as a single-value [`Series`], if available.
    pub fn to_max(&self) -> Option<&Series> {
        let max_val = self.max_value.as_ref()?;
        let dtype = max_val.dtype();

        if !use_min_max(dtype) || max_val.len() != 1 {
            return None;
        }

        if max_val.null_count() > 0 {
            None
        } else {
            Some(max_val)
        }
    }
}

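/// Returns whether the [`DataType`] has meaningful minimum/maximum statistics.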
fn use_min_max(dtype: &DataType) -> bool {
    dtype.is_primitive_numeric()
        || dtype.is_temporal()
        || matches!(
            dtype,
            DataType::String | DataType::Binary | DataType::Boolean
        )
}

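/// Per-column statistics of a single batch, used by [`SkipBatchPredicate`].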
pub struct ColumnStatistics {
    pub dtype: DataType,
    pub min: AnyValue<'static>,
    pub max: AnyValue<'static>,
    pub null_count: Option<IdxSize>,
}

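/// A predicate that decides, from batch statistics alone, whether a whole
/// batch can be skipped without reading its data.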
pub trait SkipBatchPredicate: Send + Sync {
    fn schema(&self) -> &SchemaRef;

    /// Builds a single-row DataFrame with the columns `len`, and `{col}_min`,
    /// `{col}_max`, `{col}_nc` for every live column, then evaluates the
    /// skip-predicate on it. Returns `true` if the whole batch can be skipped.
    fn can_skip_batch(
        &self,
        batch_size: IdxSize,
        live_columns: &PlIndexSet<PlSmallStr>,
        mut statistics: PlIndexMap<PlSmallStr, ColumnStatistics>,
    ) -> PolarsResult<bool> {
        let mut columns = Vec::with_capacity(1 + live_columns.len() * 3);

        columns.push(Column::new_scalar(
            PlSmallStr::from_static("len"),
            Scalar::new(IDX_DTYPE, batch_size.into()),
            1,
        ));

        for col in live_columns.iter() {
            let dtype = self.schema().get(col).unwrap();
            // Missing statistics are encoded as nulls.
            let (min, max, nc) = match statistics.swap_remove(col) {
                None => (
                    Scalar::null(dtype.clone()),
                    Scalar::null(dtype.clone()),
                    Scalar::null(IDX_DTYPE),
                ),
                Some(stat) => (
                    Scalar::new(dtype.clone(), stat.min),
                    Scalar::new(dtype.clone(), stat.max),
                    Scalar::new(
                        IDX_DTYPE,
                        stat.null_count.map_or(AnyValue::Null, |nc| nc.into()),
                    ),
                ),
            };
            columns.extend([
                Column::new_scalar(format_pl_smallstr!("{col}_min"), min, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_max"), max, 1),
                Column::new_scalar(format_pl_smallstr!("{col}_nc"), nc, 1),
            ]);
        }

        // SAFETY: all columns have height 1 and unique names.
        let df = unsafe { DataFrame::new_no_checks(1, columns) };
        Ok(self.evaluate_with_stat_df(&df)?.get_bit(0))
    }

    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap>;
}

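/// Per-column predicates, keyed by column name. `is_sumwise_complete` is set
/// when these per-column predicates together cover the full predicate.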
#[derive(Clone)]
pub struct ColumnPredicates {
    pub predicates: PlHashMap<
        PlSmallStr,
        (
            Arc<dyn PhysicalIoExpr>,
            Option<SpecializedColumnPredicateExpr>,
        ),
    >,
    pub is_sumwise_complete: bool,
}

#[allow(clippy::derivable_impls)]
impl Default for ColumnPredicates {
    fn default() -> Self {
        Self {
            predicates: PlHashMap::default(),
            is_sumwise_complete: false,
        }
    }
}

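/// Wraps a predicate and injects constant columns into the `DataFrame` before
/// delegating evaluation to the wrapped predicate.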
pub struct PhysicalExprWithConstCols<T> {
    constants: Vec<(PlSmallStr, Scalar)>,
    child: T,
}

impl SkipBatchPredicate for PhysicalExprWithConstCols<Arc<dyn SkipBatchPredicate>> {
    fn schema(&self) -> &SchemaRef {
        self.child.schema()
    }

    fn evaluate_with_stat_df(&self, df: &DataFrame) -> PolarsResult<Bitmap> {
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }
        self.child.evaluate_with_stat_df(&df)
    }
}

impl PhysicalIoExpr for PhysicalExprWithConstCols<Arc<dyn PhysicalIoExpr>> {
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
        let mut df = df.clone();
        for (name, scalar) in self.constants.iter() {
            df.with_column(Column::new_scalar(
                name.clone(),
                scalar.clone(),
                df.height(),
            ))?;
        }

        self.child.evaluate_io(&df)
    }
}

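/// All predicate information a scan needs: the full predicate, the columns it
/// uses, an optional batch-skipping predicate and per-column predicates.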
#[derive(Clone)]
pub struct ScanIOPredicate {
    /// The full predicate expression.
    pub predicate: Arc<dyn PhysicalIoExpr>,

    /// Column names that are used in the predicate.
    pub live_columns: Arc<PlIndexSet<PlSmallStr>>,

    /// An optional predicate used to skip whole batches based on statistics.
    pub skip_batch_predicate: Option<Arc<dyn SkipBatchPredicate>>,

    /// Predicates that apply to individual columns.
    pub column_predicates: Arc<ColumnPredicates>,
}

impl ScanIOPredicate {
    /// Registers columns whose value is constant for the whole scan (e.g. hive
    /// partition columns). They are removed from the live and per-column
    /// predicates, and their values are injected as constant columns whenever
    /// the predicates are evaluated.
    pub fn set_external_constant_columns(&mut self, constant_columns: Vec<(PlSmallStr, Scalar)>) {
        if constant_columns.is_empty() {
            return;
        }

        let mut live_columns = self.live_columns.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            live_columns.swap_remove(c);
        }
        self.live_columns = Arc::new(live_columns);

        if let Some(skip_batch_predicate) = self.skip_batch_predicate.take() {
            // The skip-batch predicate sees the statistics columns, so provide
            // `{c}_min`, `{c}_max` and `{c}_nc` for every constant column.
            let mut sbp_constant_columns = Vec::with_capacity(constant_columns.len() * 3);
            for (c, v) in constant_columns.iter() {
                sbp_constant_columns.push((format_pl_smallstr!("{c}_min"), v.clone()));
                sbp_constant_columns.push((format_pl_smallstr!("{c}_max"), v.clone()));
                let nc = if v.is_null() {
                    AnyValue::Null
                } else {
                    (0 as IdxSize).into()
                };
                sbp_constant_columns
                    .push((format_pl_smallstr!("{c}_nc"), Scalar::new(IDX_DTYPE, nc)));
            }
            self.skip_batch_predicate = Some(Arc::new(PhysicalExprWithConstCols {
                constants: sbp_constant_columns,
                child: skip_batch_predicate,
            }));
        }

        let mut column_predicates = self.column_predicates.as_ref().clone();
        for (c, _) in constant_columns.iter() {
            column_predicates.predicates.remove(c);
        }
        self.column_predicates = Arc::new(column_predicates);

        self.predicate = Arc::new(PhysicalExprWithConstCols {
            constants: constant_columns,
            child: self.predicate.clone(),
        });
    }
}

impl fmt::Debug for ScanIOPredicate {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("scan_io_predicate")
    }
}

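/// A collection of [`ColumnStats`] for a batch with a known schema.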
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct BatchStats {
    schema: SchemaRef,
    stats: Vec<ColumnStats>,
    // This might not be known, e.g. when pruning hive partitions.
    num_rows: Option<usize>,
}

impl Default for BatchStats {
    fn default() -> Self {
        Self {
            schema: Arc::new(Schema::default()),
            stats: Vec::new(),
            num_rows: None,
        }
    }
}

impl BatchStats {
    /// Constructs a new [`BatchStats`].
    ///
    /// The `stats` should match the order of the `schema`.
    pub fn new(schema: SchemaRef, stats: Vec<ColumnStats>, num_rows: Option<usize>) -> Self {
        Self {
            schema,
            stats,
            num_rows,
        }
    }

    /// Returns the [`Schema`] of the batch.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Returns the [`ColumnStats`] of all columns.
    pub fn column_stats(&self) -> &[ColumnStats] {
        self.stats.as_ref()
    }

    /// Returns the [`ColumnStats`] for a column by name, or an error if the
    /// column is not in the schema.
    pub fn get_stats(&self, column: &str) -> PolarsResult<&ColumnStats> {
        self.schema.try_index_of(column).map(|i| &self.stats[i])
    }

    /// Returns the number of rows in the batch, if known.
    pub fn num_rows(&self) -> Option<usize> {
        self.num_rows
    }

    pub fn with_schema(&mut self, schema: SchemaRef) {
        self.schema = schema;
    }

    pub fn take_indices(&mut self, indices: &[usize]) {
        self.stats = indices.iter().map(|&i| self.stats[i].clone()).collect();
    }
}