diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 0ab2046097c4..340caaba5c4a 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -16,9 +16,9 @@ // under the License. use arrow::{array::ArrayRef, datatypes::Schema}; +use arrow_array::BooleanArray; use arrow_schema::FieldRef; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; +use datafusion_common::{Column, ScalarValue}; use parquet::file::metadata::ColumnChunkMetaData; use parquet::schema::types::SchemaDescriptor; use parquet::{ @@ -26,19 +26,13 @@ use parquet::{ bloom_filter::Sbbf, file::metadata::RowGroupMetaData, }; -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; +use std::collections::{HashMap, HashSet}; use crate::datasource::listing::FileRange; use crate::datasource::physical_plan::parquet::statistics::{ max_statistics, min_statistics, parquet_column, }; -use crate::logical_expr::Operator; -use crate::physical_expr::expressions as phys_expr; use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; -use crate::physical_plan::PhysicalExpr; use super::ParquetFileMetrics; @@ -122,182 +116,118 @@ pub(crate) async fn prune_row_groups_by_bloom_filters< predicate: &PruningPredicate, metrics: &ParquetFileMetrics, ) -> Vec { - let bf_predicates = match BloomFilterPruningPredicate::try_new(predicate.orig_expr()) - { - Ok(predicates) => predicates, - Err(_) => { - return row_groups.to_vec(); - } - }; + println!( + "prune_row_groups_by_bloom_filters with pruning predicate: {:#?}", + predicate + ); let mut filtered = Vec::with_capacity(groups.len()); for idx in row_groups { let rg_metadata = &groups[*idx]; // get all columns bloom filter - let mut column_sbbf = - HashMap::with_capacity(bf_predicates.required_columns.len()); - for column_name in bf_predicates.required_columns.iter() { - let column_idx = match rg_metadata + let literal_columns = predicate.literal_columns(); + let mut column_sbbf = HashMap::with_capacity(literal_columns.len()); + + for column_name in literal_columns { + // This is very likely incorrect as it will not work for nested columns + // should use parquet_column instead + let Some((column_idx, _)) = rg_metadata .columns() .iter() .enumerate() - .find(|(_, column)| column.column_path().string().eq(column_name)) - { - Some((column_idx, _)) => column_idx, - None => continue, + .find(|(_, column)| column.column_path().string().eq(&column_name)) + else { + continue; }; + let bf = match builder .get_row_group_column_bloom_filter(*idx, column_idx) .await { - Ok(bf) => match bf { - Some(bf) => bf, - None => { - continue; - } - }, + Ok(Some(bf)) => bf, + Ok(None) => continue, // no bloom filter for this column Err(e) => { - log::error!("Error evaluating row group predicate values when using BloomFilterPruningPredicate {e}"); + log::error!("Ignoring error reading bloom filter: {e}"); metrics.predicate_evaluation_errors.add(1); continue; } }; - column_sbbf.insert(column_name.to_owned(), bf); + column_sbbf.insert(column_name.to_string(), bf); } - if bf_predicates.prune(&column_sbbf) { + + let stats = BloomFilterStatistics { column_sbbf }; + + // Can this group be pruned? + let prune_result = predicate.prune(&stats); + println!("prune result: {:?}", prune_result); + let prune_group = match prune_result { + Ok(values) => !values[0], + Err(e) => { + log::debug!("Error evaluating row group predicate on bloom filter: {e}"); + metrics.predicate_evaluation_errors.add(1); + false + } + }; + + println!("prune group: {}", prune_group); + + if prune_group { metrics.row_groups_pruned.add(1); - continue; + } else { + filtered.push(*idx); } - filtered.push(*idx); } filtered } -struct BloomFilterPruningPredicate { - /// Actual pruning predicate - predicate_expr: Option, - /// The statistics required to evaluate this predicate - required_columns: Vec, +struct BloomFilterStatistics { + column_sbbf: HashMap, } -impl BloomFilterPruningPredicate { - fn try_new(expr: &Arc) -> Result { - let binary_expr = expr.as_any().downcast_ref::(); - match binary_expr { - Some(binary_expr) => { - let columns = Self::get_predicate_columns(expr); - Ok(Self { - predicate_expr: Some(binary_expr.clone()), - required_columns: columns.into_iter().collect(), - }) - } - None => Err(DataFusionError::Execution( - "BloomFilterPruningPredicate only support binary expr".to_string(), - )), - } - } - - fn prune(&self, column_sbbf: &HashMap) -> bool { - Self::prune_expr_with_bloom_filter(self.predicate_expr.as_ref(), column_sbbf) +impl PruningStatistics for BloomFilterStatistics { + fn num_containers(&self) -> usize { + 1 } - /// Return true if the `expr` can be proved not `true` - /// based on the bloom filter. - /// - /// We only checked `BinaryExpr` but it also support `InList`, - /// Because of the `optimizer` will convert `InList` to `BinaryExpr`. - fn prune_expr_with_bloom_filter( - expr: Option<&phys_expr::BinaryExpr>, - column_sbbf: &HashMap, - ) -> bool { - let Some(expr) = expr else { - // unsupported predicate - return false; + /// Use bloom filters to determine if we are sure this column can not contain `value` + fn contains( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + println!("Checking column {} for values {:?}", column.name, values); + let sbbf = self.column_sbbf.get(column.name.as_str())?; + println!(" have sbbf: {:?}", sbbf); + + // if true, means column probably contains value + // if false, means column definitely DOES NOT contain value + let known_not_present = values + .iter() + .map(|value| match value { + ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Boolean(Some(v)) => sbbf.check(v), + ScalarValue::Float64(Some(v)) => sbbf.check(v), + ScalarValue::Float32(Some(v)) => sbbf.check(v), + ScalarValue::Int64(Some(v)) => sbbf.check(v), + ScalarValue::Int32(Some(v)) => sbbf.check(v), + ScalarValue::Int16(Some(v)) => sbbf.check(v), + ScalarValue::Int8(Some(v)) => sbbf.check(v), + _ => true, + }) + // We know the row group doesn't contain any of the values if the checks are all + // false + .all(|v| !v); + println!("known_not_present result: {}", known_not_present); + + let contains = if known_not_present { + Some(false) + } else { + // The column might contain one of the values + None }; - match expr.op() { - Operator::And | Operator::Or => { - let left = Self::prune_expr_with_bloom_filter( - expr.left().as_any().downcast_ref::(), - column_sbbf, - ); - let right = Self::prune_expr_with_bloom_filter( - expr.right() - .as_any() - .downcast_ref::(), - column_sbbf, - ); - match expr.op() { - Operator::And => left || right, - Operator::Or => left && right, - _ => false, - } - } - Operator::Eq => { - if let Some((col, val)) = Self::check_expr_is_col_equal_const(expr) { - if let Some(sbbf) = column_sbbf.get(col.name()) { - match val { - ScalarValue::Utf8(Some(v)) => !sbbf.check(&v.as_str()), - ScalarValue::Boolean(Some(v)) => !sbbf.check(&v), - ScalarValue::Float64(Some(v)) => !sbbf.check(&v), - ScalarValue::Float32(Some(v)) => !sbbf.check(&v), - ScalarValue::Int64(Some(v)) => !sbbf.check(&v), - ScalarValue::Int32(Some(v)) => !sbbf.check(&v), - ScalarValue::Int16(Some(v)) => !sbbf.check(&v), - ScalarValue::Int8(Some(v)) => !sbbf.check(&v), - _ => false, - } - } else { - false - } - } else { - false - } - } - _ => false, - } - } - - fn get_predicate_columns(expr: &Arc) -> HashSet { - let mut columns = HashSet::new(); - expr.apply(&mut |expr| { - if let Some(binary_expr) = - expr.as_any().downcast_ref::() - { - if let Some((column, _)) = - Self::check_expr_is_col_equal_const(binary_expr) - { - columns.insert(column.name().to_string()); - } - } - Ok(VisitRecursion::Continue) - }) - // no way to fail as only Ok(VisitRecursion::Continue) is returned - .unwrap(); - - columns - } - - fn check_expr_is_col_equal_const( - exr: &phys_expr::BinaryExpr, - ) -> Option<(phys_expr::Column, ScalarValue)> { - if Operator::Eq.ne(exr.op()) { - return None; - } - let left_any = exr.left().as_any(); - let right_any = exr.right().as_any(); - if let (Some(col), Some(liter)) = ( - left_any.downcast_ref::(), - right_any.downcast_ref::(), - ) { - return Some((col.clone(), liter.value().clone())); - } - if let (Some(liter), Some(col)) = ( - left_any.downcast_ref::(), - right_any.downcast_ref::(), - ) { - return Some((col.clone(), liter.value().clone())); - } - None + let result = Some(BooleanArray::from(vec![contains])); + println!("result: {:?}", result); + result } } @@ -350,6 +280,7 @@ mod tests { use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; use datafusion_common::{config::ConfigOptions, TableReference, ToDFSchema}; + use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ builder::LogicalTableSource, cast, col, lit, AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF, diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index de508327fade..7cca1947bb31 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -35,12 +35,13 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::{downcast_value, plan_datafusion_err, ScalarValue}; +use arrow_array::cast::AsArray; use datafusion_common::{ internal_err, plan_err, tree_node::{Transformed, TreeNode}, }; -use datafusion_physical_expr::utils::collect_columns; +use datafusion_common::{plan_datafusion_err, ScalarValue}; +use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; @@ -68,11 +69,15 @@ use log::trace; pub trait PruningStatistics { /// return the minimum values for the named column, if known. /// Note: the returned array must contain `num_containers()` rows - fn min_values(&self, column: &Column) -> Option; + fn min_values(&self, _column: &Column) -> Option { + None + } /// return the maximum values for the named column, if known. /// Note: the returned array must contain `num_containers()` rows. - fn max_values(&self, column: &Column) -> Option; + fn max_values(&self, _column: &Column) -> Option { + None + } /// return the number of containers (e.g. row groups) being /// pruned with these statistics @@ -82,7 +87,32 @@ pub trait PruningStatistics { /// `Option`. /// /// Note: the returned array must contain `num_containers()` rows. - fn null_counts(&self, column: &Column) -> Option; + fn null_counts(&self, _column: &Column) -> Option { + None + } + + /// Returns an array where each element represents if the value of the + /// column CERTAINLY DOES NOT contain any of the specified `values`. + /// + /// This can be used to prune containers based on structures such as Bloom + /// Filters which can test set membership quickly. + /// + /// The returned array has one row for each container, with the following: + /// * `true` if the value of column CERTAINLY IS one of `values` + /// * `false` if the value of column CERTAINLY IS NOT one of `values` + /// * `null` if the value of column may or may not be in values + /// + /// If these statistics can not determine column membership for any + /// container, return `None` (the default). + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn contains( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } } /// Evaluates filter expressions on statistics, rather than the actual data. If @@ -129,10 +159,12 @@ pub struct PruningPredicate { schema: SchemaRef, /// Actual pruning predicate (rewritten in terms of column min/max statistics) predicate_expr: Arc, - /// The statistics required to evaluate this predicate - required_columns: RequiredStatColumns, + /// The statistics required to evaluate `predicate_expr` + required_columns: RequiredColumns, /// Original physical predicate from which this predicate expr is derived (required for serialization) orig_expr: Arc, + /// Any col = literal expressions + literal_guarantees: Vec, } impl PruningPredicate { @@ -157,14 +189,18 @@ impl PruningPredicate { /// `(column_min / 2) <= 4 && 4 <= (column_max / 2))` pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { // build predicate expression once - let mut required_columns = RequiredStatColumns::new(); + let mut required_columns = RequiredColumns::new(); let predicate_expr = build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); + + let literal_guarantees = LiteralGuarantee::analyze(&expr); + Ok(Self { schema, predicate_expr, required_columns, orig_expr: expr, + literal_guarantees, }) } @@ -183,40 +219,36 @@ impl PruningPredicate { /// /// [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier pub fn prune(&self, statistics: &S) -> Result> { + let mut builder = BoolVecBuilder::new(statistics.num_containers()); + + // First, check any expr_op_literals + for literal_guarantee in &self.literal_guarantees { + let LiteralGuarantee { + column, + guarantee, + literals, + } = literal_guarantee; + // Can the statistics tell us anything about this column? + if let Some(results) = statistics.contains(column, literals) { + match guarantee { + Guarantee::In => builder.append_array(&results), + Guarantee::NotIn => { + let results = arrow::compute::not(&results)?; + builder.append_array(&results) + } + } + } + } + // build a RecordBatch that contains the min/max values in the - // appropriate statistics columns + // appropriate statistics columns for the min/max predicate let statistics_batch = build_statistics_record_batch(statistics, &self.required_columns)?; - // Evaluate the pruning predicate on that record batch. - // - // Use true when the result of evaluating a predicate - // expression on a row group is null (aka `None`). Null can - // arise when the statistics are unknown or some calculation - // in the predicate means we don't know for sure if the row - // group can be filtered out or not. To maintain correctness - // the row group must be kept and thus `true` is returned. - match self.predicate_expr.evaluate(&statistics_batch)? { - ColumnarValue::Array(array) => { - let predicate_array = downcast_value!(array, BooleanArray); + // Evaluate the pruning predicate on that record batch and append any results to the builder + builder.append_value(self.predicate_expr.evaluate(&statistics_batch)?); - Ok(predicate_array - .into_iter() - .map(|x| x.unwrap_or(true)) // None -> true per comments above - .collect::>()) - } - // result was a column - ColumnarValue::Scalar(ScalarValue::Boolean(v)) => { - let v = v.unwrap_or(true); // None -> true per comments above - Ok(vec![v; statistics.num_containers()]) - } - other => { - internal_err!( - "Unexpected result of pruning predicate evaluation. Expected Boolean array \ - or scalar but got {other:?}" - ) - } - } + Ok(builder.build()) } /// Return a reference to the input schema @@ -239,9 +271,79 @@ impl PruningPredicate { is_always_true(&self.predicate_expr) } - pub(crate) fn required_columns(&self) -> &RequiredStatColumns { + pub(crate) fn required_columns(&self) -> &RequiredColumns { &self.required_columns } + + /// returns the names of the columns that are known to be a constant (and + /// that may be used as part of a Contains query + pub fn literal_columns(&self) -> Vec { + let mut seen = HashSet::new(); + self.literal_guarantees + .iter() + .map(|e| &e.column.name) + // avoid duplicates + .filter(|name| seen.insert(*name)) + .map(|s| s.to_string()) + .collect() + } +} + +/// Builds a Vec that is true if container CERTAINLY DOES NOT pass the +/// predicate, and false if it MAY pass the predicate +/// +/// Use true when the result of evaluating a predicate +/// expression on a row group is null (aka `None`). Null can +/// arise when the statistics are unknown or some calculation +/// in the predicate means we don't know for sure if the row +/// group can be filtered out or not. To maintain correctness +/// the row group must be kept and thus `true` is returned. +#[derive(Debug)] +struct BoolVecBuilder { + // true if the container may pass the predicate, false if we know for sure + // it did not pass the predicate + inner: Vec, +} + +impl BoolVecBuilder { + fn new(num_containers: usize) -> Self { + Self { + inner: vec![true; num_containers], + } + } + + /// Combines the results in an array to the currently in progress array + fn append_array(&mut self, array: &BooleanArray) { + assert_eq!(array.len(), self.inner.len()); + // set any locations to false if we know for sure they did not pass the predicate + for (cur, new) in self.inner.iter_mut().zip(array.iter()) { + if let Some(false) = new { + *cur = false; + } + } + } + + /// Combines the results in the [`ColumnarValue`] to the currently in progress array + fn append_value(&mut self, value: ColumnarValue) { + match value { + ColumnarValue::Array(array) => { + self.append_array(array.as_boolean()); + } + ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))) => { + // False means all containers can not pass the predicate + self.inner = vec![false; self.inner.len()]; + } + _ => { + // Null or true means we don't know if the container can pass the predicate + // so we must keep it + } + } + } + + /// Convert this builder into a Vec of bools + fn build(self) -> Vec { + self.inner + } } fn is_always_true(expr: &Arc) -> bool { @@ -257,21 +359,21 @@ fn is_always_true(expr: &Arc) -> bool { /// Handles creating references to the min/max statistics /// for columns as well as recording which statistics are needed #[derive(Debug, Default, Clone)] -pub(crate) struct RequiredStatColumns { +pub(crate) struct RequiredColumns { /// The statistics required to evaluate this predicate: /// * The unqualified column in the input schema /// * Statistics type (e.g. Min or Max or Null_Count) /// * The field the statistics value should be placed in for - /// pruning predicate evaluation + /// pruning predicate evaluation (e.g. `min_value` or `max_value`) columns: Vec<(phys_expr::Column, StatisticsType, Field)>, } -impl RequiredStatColumns { +impl RequiredColumns { fn new() -> Self { Self::default() } - /// Returns number of unique columns. + /// Returns number of unique columns pub(crate) fn n_columns(&self) -> usize { self.iter() .map(|(c, _s, _f)| c) @@ -325,11 +427,10 @@ impl RequiredStatColumns { // only add statistics column if not previously added if need_to_insert { - let stat_field = Field::new( - stat_column.name(), - field.data_type().clone(), - field.is_nullable(), - ); + // may be null if statistics are not present + let nullable = true; + let stat_field = + Field::new(stat_column.name(), field.data_type().clone(), nullable); self.columns.push((column.clone(), stat_type, stat_field)); } rewrite_column_expr(column_expr.clone(), column, &stat_column) @@ -372,7 +473,7 @@ impl RequiredStatColumns { } } -impl From> for RequiredStatColumns { +impl From> for RequiredColumns { fn from(columns: Vec<(phys_expr::Column, StatisticsType, Field)>) -> Self { Self { columns } } @@ -405,7 +506,7 @@ impl From> for RequiredStatColum /// ``` fn build_statistics_record_batch( statistics: &S, - required_columns: &RequiredStatColumns, + required_columns: &RequiredColumns, ) -> Result { let mut fields = Vec::::new(); let mut arrays = Vec::::new(); @@ -461,7 +562,7 @@ struct PruningExpressionBuilder<'a> { op: Operator, scalar_expr: Arc, field: &'a Field, - required_columns: &'a mut RequiredStatColumns, + required_columns: &'a mut RequiredColumns, } impl<'a> PruningExpressionBuilder<'a> { @@ -470,7 +571,7 @@ impl<'a> PruningExpressionBuilder<'a> { right: &'a Arc, op: Operator, schema: &'a Schema, - required_columns: &'a mut RequiredStatColumns, + required_columns: &'a mut RequiredColumns, ) -> Result { // find column name; input could be a more complicated expression let left_columns = collect_columns(left); @@ -685,7 +786,7 @@ fn reverse_operator(op: Operator) -> Result { fn build_single_column_expr( column: &phys_expr::Column, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, is_not: bool, // if true, treat as !col ) -> Option> { let field = schema.field_with_name(column.name()).ok()?; @@ -726,7 +827,7 @@ fn build_single_column_expr( fn build_is_null_column_expr( expr: &Arc, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Option> { if let Some(col) = expr.as_any().downcast_ref::() { let field = schema.field_with_name(col.name()).ok()?; @@ -756,7 +857,7 @@ fn build_is_null_column_expr( fn build_predicate_expression( expr: &Arc, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Arc { // Returned for unsupported expressions. Such expressions are // converted to TRUE. @@ -1184,7 +1285,7 @@ mod tests { #[test] fn test_build_statistics_record_batch() { // Request a record batch with of s1_min, s2_max, s3_max, s3_min - let required_columns = RequiredStatColumns::from(vec![ + let required_columns = RequiredColumns::from(vec![ // min of original column s1, named s1_min ( phys_expr::Column::new("s1", 1), @@ -1256,7 +1357,7 @@ mod tests { // which is what Parquet does // Request a record batch with of s1_min as a timestamp - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new( @@ -1288,7 +1389,7 @@ mod tests { #[test] fn test_build_statistics_no_required_stats() { - let required_columns = RequiredStatColumns::new(); + let required_columns = RequiredColumns::new(); let statistics = OneContainerStats { min_values: Some(Arc::new(Int64Array::from(vec![Some(10)]))), @@ -1306,7 +1407,7 @@ mod tests { // Test requesting a Utf8 column when the stats return some other type // Request a record batch with of s1_min as a timestamp - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new("s1_min", DataType::Utf8, true), @@ -1335,7 +1436,7 @@ mod tests { #[test] fn test_build_statistics_inconsistent_length() { // return an inconsistent length to the actual statistics arrays - let required_columns = RequiredStatColumns::from(vec![( + let required_columns = RequiredColumns::from(vec![( phys_expr::Column::new("s1", 3), StatisticsType::Min, Field::new("s1_min", DataType::Int64, true), @@ -1366,20 +1467,14 @@ mod tests { // test column on the left let expr = col("c1").eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1392,20 +1487,14 @@ mod tests { // test column on the left let expr = col("c1").not_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).not_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1418,20 +1507,14 @@ mod tests { // test column on the left let expr = col("c1").gt(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1444,19 +1527,13 @@ mod tests { // test column on the left let expr = col("c1").gt_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1469,20 +1546,14 @@ mod tests { // test column on the left let expr = col("c1").lt(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1495,19 +1566,13 @@ mod tests { // test column on the left let expr = col("c1").lt_eq(lit(1)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt_eq(col("c1")); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1523,11 +1588,8 @@ mod tests { // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); let expected_expr = "c1_min@0 < 1"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1542,11 +1604,8 @@ mod tests { // test OR operator joining supported c1 < 1 expression and unsupported c2 % 2 = 0 expression let expr = col("c1").lt(lit(1)).or(col("c2").rem(lit(2)).eq(lit(0))); let expected_expr = "true"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1558,11 +1617,8 @@ mod tests { let expected_expr = "true"; let expr = col("c1").not(); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1574,11 +1630,8 @@ mod tests { let expected_expr = "NOT c1_min@0 AND c1_max@1"; let expr = col("c1").not(); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1590,11 +1643,8 @@ mod tests { let expected_expr = "c1_min@0 OR c1_max@1"; let expr = col("c1"); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1608,11 +1658,8 @@ mod tests { // DF doesn't support arithmetic on boolean columns so // this predicate will error when evaluated let expr = col("c1").lt(lit(true)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1624,7 +1671,7 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), ]); - let mut required_columns = RequiredStatColumns::new(); + let mut required_columns = RequiredColumns::new(); // c1 < 1 and (c2 = 2 or c2 = 3) let expr = col("c1") .lt(lit(1)) @@ -1640,7 +1687,7 @@ mod tests { ( phys_expr::Column::new("c1", 0), StatisticsType::Min, - c1_min_field + c1_min_field.with_nullable(true) // could be nullable if stats are not present ) ); // c2 = 2 should add c2_min and c2_max @@ -1650,7 +1697,7 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Min, - c2_min_field + c2_min_field.with_nullable(true) // could be nullable if stats are not present ) ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); @@ -1659,7 +1706,7 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Max, - c2_max_field + c2_max_field.with_nullable(true) // could be nullable if stats are not present ) ); // c2 = 3 shouldn't add any new statistics fields @@ -1681,11 +1728,8 @@ mod tests { false, )); let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_min@0 <= 3 AND 3 <= c1_max@1"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1700,11 +1744,8 @@ mod tests { // test c1 in() let expr = Expr::InList(InList::new(Box::new(col("c1")), vec![], false)); let expected_expr = "true"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1725,11 +1766,8 @@ mod tests { let expected_expr = "(c1_min@0 != 1 OR 1 != c1_max@1) \ AND (c1_min@0 != 2 OR 2 != c1_max@1) \ AND (c1_min@0 != 3 OR 3 != c1_max@1)"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1743,20 +1781,14 @@ mod tests { // test column on the left let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"), DataType::Int64)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); let expected_expr = "TRY_CAST(c1_max@0 AS Int64) > 1"; @@ -1764,21 +1796,15 @@ mod tests { // test column on the left let expr = try_cast(col("c1"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"), DataType::Int64)); - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -1798,11 +1824,8 @@ mod tests { false, )); let expected_expr = "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); let expr = Expr::InList(InList::new( @@ -1818,11 +1841,8 @@ mod tests { "(CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) \ AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) \ AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))"; - let predicate_expr = test_build_predicate_expression( - &expr, - &schema, - &mut RequiredStatColumns::new(), - ); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) @@ -2468,7 +2488,7 @@ mod tests { fn test_build_predicate_expression( expr: &Expr, schema: &Schema, - required_columns: &mut RequiredStatColumns, + required_columns: &mut RequiredColumns, ) -> Arc { let expr = logical2physical(expr, schema); build_predicate_expression(&expr, schema, required_columns) diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs new file mode 100644 index 000000000000..f7b371b6a5f4 --- /dev/null +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -0,0 +1,492 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`LiteralGuarantee`] to analyze predicates and determine if a column is a +//constant. + +use crate::utils::split_disjunction; +use crate::{split_conjunction, PhysicalExpr}; +use datafusion_common::{Column, ScalarValue}; +use datafusion_expr::Operator; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +/// Represents a predicate where it is known that a column is either: +/// +/// 1. One of particular set of values. For example, `(a = 1)`, `(a = 1 OR a = +/// 2) or `a IN (1, 2, 3)` +/// +/// 2. Not one of a particular set of values. For example, `(a != 1)`, `(a != 1 +/// AND a != 2)` or `a NOT IN (1, 2, 3)` +#[derive(Debug, Clone, PartialEq)] +pub struct LiteralGuarantee { + pub column: Column, + pub guarantee: Guarantee, + pub literals: HashSet, +} + +/// What can be guaranteed about the values? +#[derive(Debug, Clone, PartialEq)] +pub enum Guarantee { + /// `column` is one of a set of constant values + In, + /// `column` is NOT one of a set of constant values + NotIn, +} + +impl LiteralGuarantee { + /// Create a new instance of the guarantee if the provided operator is supported + fn try_new<'a>( + column_name: impl Into, + op: &Operator, + literals: impl IntoIterator, + ) -> Option { + let guarantee = match op { + Operator::Eq => Guarantee::In, + Operator::NotEq => Guarantee::NotIn, + _ => return None, + }; + + let literals: HashSet<_> = literals.into_iter().cloned().collect(); + + Some(Self { + column: Column::from_name(column_name), + guarantee, + literals, + }) + } + + /// return a list of `LiteralGuarantees` that can be deduced for this + /// expression. + /// + /// `expr` should be a boolean expression, for example a filter expression + /// + /// Notes: this API assumes the expression has already been simplified and + /// returns duplicate guarantees for expressions like `a = 1 AND a = 1`. + pub fn analyze(expr: &Arc) -> Vec { + split_conjunction(expr) + .into_iter() + .fold(GuaranteeBuilder::new(), |builder, expr| { + if let Some(cel) = ColOpLit::try_new(expr) { + return builder.aggregate_conjunct(cel); + } else { + // look for pattern like + // (col literal) OR (col literal) ... + let disjunctions = split_disjunction(expr); + + let terms = disjunctions + .iter() + .filter_map(|expr| ColOpLit::try_new(expr)) + .collect::>(); + + if terms.is_empty() { + return builder; + } + + if terms.len() != disjunctions.len() { + // not all terms are of the form (col literal) + return builder; + } + + // if all terms are 'col literal' then we can say something about the column + let first_term = &terms[0]; + if terms.iter().all(|term| { + term.col.name() == first_term.col.name() + && term.op == first_term.op + }) { + builder.aggregate_multi_conjunct( + first_term.col, + first_term.op, + terms.iter().map(|term| term.lit.value()), + ) + } else { + // ignore it + builder + } + } + }) + .build() + } +} + +/// Combines conjuncts together into guarantees, preserving insert order +struct GuaranteeBuilder<'a> { + /// List of guarantees that have been created so far + /// if we have determined a subsequent conjunct invalidates a guarantee + /// e.g. `a = foo AND a = bar` then the relevant guarantee will be None + guarantees: Vec>, + + // Key is the column name, type and value is the index into `guarantees` + map: HashMap<(&'a crate::expressions::Column, &'a Operator), usize>, +} + +impl<'a> GuaranteeBuilder<'a> { + fn new() -> Self { + Self { + guarantees: vec![], + map: HashMap::new(), + } + } + + /// Aggregate a new single guarantee to this builder combining with existing guarantees + /// if possible + fn aggregate_conjunct(self, col_op_lit: ColOpLit<'a>) -> Self { + self.aggregate_multi_conjunct( + col_op_lit.col, + col_op_lit.op, + [col_op_lit.lit.value()], + ) + } + + /// Aggreates a new single new guarantee with multiple literals `a IN (1,2,3)` or `a NOT IN (1,2,3)`. So the new values are combined with OR + fn aggregate_multi_conjunct( + mut self, + col: &'a crate::expressions::Column, + op: &'a Operator, + new_values: impl IntoIterator, + ) -> Self { + let key = (col, op); + if let Some(index) = self.map.get(&key) { + // already have a guarantee for this column + let entry = &mut self.guarantees[*index]; + + let Some(existing) = entry else { + // guarantee has been previously invalidated, nothing to do + return self; + }; + + // can only combine conjuncts if we have `a != foo AND a != bar`. + // `a = foo AND a = bar` is not correct. Also, can't extend with more than one value. + match existing.guarantee { + Guarantee::NotIn => { + // can extend if only single literal, otherwise invalidate + let new_values: HashSet<_> = new_values.into_iter().collect(); + if new_values.len() == 1 { + existing.literals.extend(new_values.into_iter().cloned()) + } else { + // this is like (a != foo AND (a != bar OR a != baz)). + // We can't combine the (a!=bar OR a!=baz) part, but it + // also doesn't invalidate a != foo guarantee. + } + } + Guarantee::In => { + // for an IN guarantee, it is ok if the value is the same + // e.g. `a = foo AND a = foo` but not if the value is different + // e.g. `a = foo AND a = bar` + if new_values + .into_iter() + .all(|new_value| existing.literals.contains(new_value)) + { + // all values are already in the set + } else { + // at least one was not, so invalidate the guarantee + *entry = None; + } + } + } + } else { + // This is a new guarantee + let new_values: HashSet<_> = new_values.into_iter().collect(); + + // new_values are combined with OR, so we can only create a + // multi-column guarantee for `=` (or a single value). + // (e.g. ignore `a != foo OR a != bar`) + if op == &Operator::Eq || new_values.len() == 1 { + if let Some(guarantee) = + LiteralGuarantee::try_new(col.name(), op, new_values) + { + // add it to the list of guarantees + self.guarantees.push(Some(guarantee)); + self.map.insert(key, self.guarantees.len() - 1); + } + } + } + + self + } + + /// Return all guarantees that have been created so far + fn build(self) -> Vec { + // filter out any guarantees that have been invalidated + self.guarantees.into_iter().flatten().collect() + } +} + +/// Represents a single `col literal` expression +struct ColOpLit<'a> { + col: &'a crate::expressions::Column, + op: &'a Operator, + lit: &'a crate::expressions::Literal, +} + +impl<'a> ColOpLit<'a> { + /// Returns Some(ColEqLit) if the expression is either: + /// 1. `col literal` + /// 2. `literal col` + /// + /// Returns None otherwise + fn try_new(expr: &'a Arc) -> Option { + let binary_expr = expr + .as_any() + .downcast_ref::()?; + + let (left, op, right) = ( + binary_expr.left().as_any(), + binary_expr.op(), + binary_expr.right().as_any(), + ); + + if let (Some(col), Some(lit)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + Some(Self { col, op, lit }) + } + // literal col + else if let (Some(lit), Some(col)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + Some(Self { col, op, lit }) + } else { + None + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::create_physical_expr; + use crate::execution_props::ExecutionProps; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::ToDFSchema; + use datafusion_expr::expr_fn::*; + use datafusion_expr::{lit, Expr}; + use std::sync::OnceLock; + + #[test] + fn test_literal() { + // a single literal offers no guarantee + test_analyze(lit(true), vec![]) + } + + #[test] + fn test_single() { + // a = "foo" + test_analyze(col("a").eq(lit("foo")), vec![in_guarantee("a", ["foo"])]); + // "foo" = a + test_analyze(lit("foo").eq(col("a")), vec![in_guarantee("a", ["foo"])]); + // a != "foo" + test_analyze( + col("a").not_eq(lit("foo")), + vec![not_in_guarantee("a", ["foo"])], + ); + // a != "foo" + test_analyze( + lit("foo").not_eq(col("a")), + vec![not_in_guarantee("a", ["foo"])], + ); + } + + #[test] + fn test_conjunction() { + // a = "foo" AND b = 1 + test_analyze( + col("a").eq(lit("foo")).and(col("b").eq(lit(1))), + vec![ + // should find both column guarantees + in_guarantee("a", ["foo"]), + in_guarantee("b", [1]), + ], + ); + // a != "foo" AND b != 1 + test_analyze( + col("a").not_eq(lit("foo")).and(col("b").not_eq(lit(1))), + // should find both column guarantees + vec![not_in_guarantee("a", ["foo"]), not_in_guarantee("b", [1])], + ); + // a = "foo" AND a = "bar" + test_analyze( + col("a").eq(lit("foo")).and(col("a").eq(lit("bar"))), + // this predicate is impossible ( can't be both foo and bar), + vec![], + ); + // a = "foo" AND b != "bar" + test_analyze( + col("a").eq(lit("foo")).and(col("a").not_eq(lit("bar"))), + vec![in_guarantee("a", ["foo"]), not_in_guarantee("a", ["bar"])], + ); + // a != "foo" AND a != "bar" + test_analyze( + col("a").not_eq(lit("foo")).and(col("a").not_eq(lit("bar"))), + // know it isn't "foo" or "bar" + vec![not_in_guarantee("a", ["foo", "bar"])], + ); + // a != "foo" AND a != "bar" and a != "baz" + test_analyze( + col("a") + .not_eq(lit("foo")) + .and(col("a").not_eq(lit("bar"))) + .and(col("a").not_eq(lit("baz"))), + // know it isn't "foo" or "bar" or "baz" + vec![not_in_guarantee("a", ["foo", "bar", "baz"])], + ); + // a = "foo" AND a = "foo" + let expr = col("a").eq(lit("foo")); + test_analyze(expr.clone().and(expr), vec![in_guarantee("a", ["foo"])]); + // b > 5 AND b = 10 (should get an b = 10 guarantee) + test_analyze( + col("b").gt(lit(5)).and(col("b").eq(lit(10))), + vec![in_guarantee("b", [10])], + ); + + // a != "foo" and (a != "bar" OR a != "baz") + test_analyze( + col("a") + .not_eq(lit("foo")) + .and(col("a").not_eq(lit("bar")).or(col("a").not_eq(lit("baz")))), + // a is not foo (we can't represent other knowledge about a) + vec![not_in_guarantee("a", ["foo"])], + ); + } + + #[test] + fn test_disjunction() { + // a = "foo" OR b = 1 + test_analyze( + col("a").eq(lit("foo")).or(col("b").eq(lit(1))), + // no can't have a single column guarantee (if a = "foo" then b != 1) etc + vec![], + ); + // a != "foo" OR b != 1 + test_analyze( + col("a").not_eq(lit("foo")).or(col("b").not_eq(lit(1))), + // No single column guarantee + vec![], + ); + // a = "foo" OR a = "bar" + test_analyze( + col("a").eq(lit("foo")).or(col("a").eq(lit("bar"))), + vec![in_guarantee("a", ["foo", "bar"])], + ); + // a = "foo" OR a = "foo" + test_analyze( + col("a").eq(lit("foo")).or(col("a").eq(lit("foo"))), + vec![in_guarantee("a", ["foo"])], + ); + // a != "foo" OR a != "bar" + test_analyze( + col("a").not_eq(lit("foo")).or(col("a").not_eq(lit("bar"))), + // can't represent knowledge about a in this case + vec![], + ); + // a = "foo" OR a = "bar" OR a = "baz" + test_analyze( + col("a") + .eq(lit("foo")) + .or(col("a").eq(lit("bar"))) + .or(col("a").eq(lit("baz"))), + vec![in_guarantee("a", ["foo", "bar", "baz"])], + ); + // (a = "foo" OR a = "bar") AND (a = "baz)" + test_analyze( + (col("a").eq(lit("foo")).or(col("a").eq(lit("bar")))) + .and(col("a").eq(lit("baz"))), + // this could potentially be represented as 2 constraints with a more + // sophisticated analysis + vec![], + ); + // (a = "foo" OR a = "bar") AND (b = 1) + test_analyze( + (col("a").eq(lit("foo")).or(col("a").eq(lit("bar")))) + .and(col("b").eq(lit(1))), + vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [1])], + ); + // (a = "foo" OR a = "bar") OR (b = 1) + test_analyze( + col("a") + .eq(lit("foo")) + .or(col("a").eq(lit("bar"))) + .or(col("b").eq(lit(1))), + // can't represent knowledge about a or b in this case + vec![], + ); + } + + // TODO file ticket to add tests for : + // a IN (...) + // b NOT IN (...) + + /// Tests that analyzing expr results in the expected guarantees + fn test_analyze(expr: Expr, expected: Vec) { + println!("Begin analyze of {expr}"); + let schema = schema(); + let physical_expr = logical2physical(&expr, &schema); + + let actual = LiteralGuarantee::analyze(&physical_expr); + assert_eq!( + expected, actual, + "expr: {expr}\ + \n\nexpected: {expected:#?}\ + \n\nactual: {actual:#?}\ + \n\nexpr: {expr:#?}\ + \n\nphysical_expr: {physical_expr:#?}" + ); + } + + /// Guarantee that column is a specified value + fn in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee + where + I: IntoIterator, + S: Into + 'a, + { + let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); + LiteralGuarantee::try_new(column, &Operator::Eq, literals.iter()).unwrap() + } + + /// Guarantee that column is NOT a specified value + fn not_in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee + where + I: IntoIterator, + S: Into + 'a, + { + let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); + LiteralGuarantee::try_new(column, &Operator::NotEq, literals.iter()).unwrap() + } + + /// Convert a logical expression to a physical expression (without any simplification, etc) + fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + } + + // Schema for testing + fn schema() -> SchemaRef { + SCHEMA + .get_or_init(|| { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])) + }) + .clone() + } + + static SCHEMA: OnceLock = OnceLock::new(); +} diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils/mod.rs similarity index 96% rename from datafusion/physical-expr/src/utils.rs rename to datafusion/physical-expr/src/utils/mod.rs index 71a7ff5fb778..87ef36558b96 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +mod guarantee; +pub use guarantee::{Guarantee, LiteralGuarantee}; + use std::borrow::Borrow; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -41,25 +44,29 @@ use petgraph::stable_graph::StableGraph; pub fn split_conjunction( predicate: &Arc, ) -> Vec<&Arc> { - split_conjunction_impl(predicate, vec![]) + split_impl(Operator::And, predicate, vec![]) } -fn split_conjunction_impl<'a>( +/// Assume the predicate is in the form of DNF, split the predicate to a Vec of PhysicalExprs. +/// +/// For example, split "a1 = a2 OR b1 <= b2 OR c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"] +pub fn split_disjunction( + predicate: &Arc, +) -> Vec<&Arc> { + split_impl(Operator::Or, predicate, vec![]) +} + +fn split_impl<'a>( + operator: Operator, predicate: &'a Arc, mut exprs: Vec<&'a Arc>, ) -> Vec<&'a Arc> { match predicate.as_any().downcast_ref::() { - Some(binary) => match binary.op() { - Operator::And => { - let exprs = split_conjunction_impl(binary.left(), exprs); - split_conjunction_impl(binary.right(), exprs) - } - _ => { - exprs.push(predicate); - exprs - } - }, - None => { + Some(binary) if binary.op() == &operator => { + let exprs = split_impl(operator, binary.left(), exprs); + split_impl(operator, binary.right(), exprs) + } + Some(_) | None => { exprs.push(predicate); exprs }