diff --git a/rust/src/delta_datafusion/expr.rs b/rust/src/delta_datafusion/expr.rs
index d60fe6666c..815b01831f 100644
--- a/rust/src/delta_datafusion/expr.rs
+++ b/rust/src/delta_datafusion/expr.rs
@@ -21,16 +21,82 @@
 //! Utility functions for Datafusion's Expressions
 
-use std::fmt::{self, Display, Formatter, Write};
+use std::{
+    fmt::{self, Display, Formatter, Write},
+    sync::Arc,
+};
 
-use datafusion_common::ScalarValue;
+use arrow_schema::DataType;
+use datafusion::execution::context::SessionState;
+use datafusion_common::Result as DFResult;
+use datafusion_common::{config::ConfigOptions, DFSchema, Result, ScalarValue, TableReference};
 use datafusion_expr::{
     expr::{InList, ScalarUDF},
-    Between, BinaryExpr, Expr, Like,
+    AggregateUDF, Between, BinaryExpr, Cast, Expr, Like, TableSource,
 };
+use datafusion_sql::planner::{ContextProvider, SqlToRel};
 use sqlparser::ast::escape_quoted_string;
+use sqlparser::dialect::GenericDialect;
+use sqlparser::parser::Parser;
+use sqlparser::tokenizer::Tokenizer;
+
+use crate::{DeltaResult, DeltaTableError};
+
+pub(crate) struct DeltaContextProvider<'a> {
+    state: &'a SessionState,
+}
+
+impl<'a> ContextProvider for DeltaContextProvider<'a> {
+    fn get_table_provider(&self, _name: TableReference) -> DFResult<Arc<dyn TableSource>> {
+        unimplemented!()
+    }
+
+    fn get_function_meta(&self, name: &str) -> Option<Arc<datafusion_expr::ScalarUDF>> {
+        self.state.scalar_functions().get(name).cloned()
+    }
+
+    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
+        self.state.aggregate_functions().get(name).cloned()
+    }
-
-use crate::DeltaTableError;
+
+    fn get_variable_type(&self, _var: &[String]) -> Option<DataType> {
+        unimplemented!()
+    }
+
+    fn options(&self) -> &ConfigOptions {
+        self.state.config_options()
+    }
+
+    fn get_window_meta(&self, name: &str) -> Option<Arc<datafusion_expr::WindowUDF>> {
+        self.state.window_functions().get(name).cloned()
+    }
+}
+
+/// Parse a string predicate into an `Expr`
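+///
+/// A hypothetical usage sketch; the `DFSchema` and `SessionState` must come
+/// from the caller:
+///
+/// ```rust ignore
+/// let expr = parse_predicate_expression(&schema, "value > 100", &session_state)?;
+/// ```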
+pub(crate) fn parse_predicate_expression(
+    schema: &DFSchema,
+    expr: impl AsRef<str>,
+    df_state: &SessionState,
+) -> DeltaResult<Expr> {
+    let dialect = &GenericDialect {};
+    let mut tokenizer = Tokenizer::new(dialect, expr.as_ref());
+    let tokens = tokenizer
+        .tokenize()
+        .map_err(|err| DeltaTableError::GenericError {
+            source: Box::new(err),
+        })?;
+    let sql = Parser::new(dialect)
+        .with_tokens(tokens)
+        .parse_expr()
+        .map_err(|err| DeltaTableError::GenericError {
+            source: Box::new(err),
+        })?;
+
+    let context_provider = DeltaContextProvider { state: df_state };
+    let sql_to_rel = SqlToRel::new(&context_provider);
+
+    Ok(sql_to_rel.sql_to_expr(sql, schema, &mut Default::default())?)
+}
 
 struct SqlFormat<'a> {
     expr: &'a Expr,
@@ -115,6 +181,9 @@ impl<'a> Display for SqlFormat<'a> {
             Expr::BinaryExpr(expr) => write!(f, "{}", BinaryExprFormat { expr }),
             Expr::ScalarFunction(func) => fmt_function(f, &func.fun.to_string(), false, &func.args),
             Expr::ScalarUDF(ScalarUDF { fun, args }) => fmt_function(f, &fun.name, false, args),
+            Expr::Cast(Cast { expr, data_type }) => {
+                write!(f, "arrow_cast({}, '{}')", SqlFormat { expr }, data_type)
+            }
             Expr::Between(Between {
                 expr,
                 negated,
@@ -271,9 +340,10 @@ impl<'a> fmt::Display for ScalarValueFormat<'a> {
 mod test {
     use std::collections::HashMap;
 
+    use arrow_schema::DataType;
     use datafusion::prelude::SessionContext;
     use datafusion_common::{DFSchema, ScalarValue};
-    use datafusion_expr::{col, decode, lit, substring, Expr, ExprSchemable};
+    use datafusion_expr::{col, decode, lit, substring, Cast, Expr, ExprSchemable};
 
     use crate::{DeltaOps, DeltaTable, Schema, SchemaDataType, SchemaField};
 
@@ -368,6 +438,13 @@ mod test {
         // String expression that we output must be parsable for conflict resolution.
         let tests = vec![
+            simple!(
+                Expr::Cast(Cast {
+                    expr: Box::new(lit(1_i64)),
+                    data_type: DataType::Int32
+                }),
+                "arrow_cast(1, 'Int32')".to_string()
+            ),
             simple!(col("value").eq(lit(3_i64)), "value = 3".to_string()),
             simple!(col("active").is_true(), "active IS TRUE".to_string()),
             simple!(col("active"), "active".to_string()),
@@ -443,6 +520,23 @@ mod test {
             simple!(
                 substring(col("modified"), lit(0_i64), lit(4_i64)).eq(lit("2021")),
                 "substr(modified, 0, 4) = '2021'".to_string()
             ),
+            simple!(
+                col("value")
+                    .cast_to::<DFSchema>(
+                        &arrow_schema::DataType::Utf8,
+                        &table
+                            .state
+                            .input_schema()
+                            .unwrap()
+                            .as_ref()
+                            .to_owned()
+                            .try_into()
+                            .unwrap()
+                    )
+                    .unwrap()
+                    .eq(lit("1")),
+                "arrow_cast(value, 'Utf8') = '1'".to_string()
+            ),
         ];
 
         let session = SessionContext::new();
@@ -479,23 +573,6 @@ mod test {
                 ))),
                 "".to_string()
             ),
-            simple!(
-                col("value")
-                    .cast_to::<DFSchema>(
-                        &arrow_schema::DataType::Utf8,
-                        &table
-                            .state
-                            .input_schema()
-                            .unwrap()
-                            .as_ref()
-                            .to_owned()
-                            .try_into()
-                            .unwrap()
-                    )
-                    .unwrap()
-                    .eq(lit("1")),
-                "CAST(value as STRING) = '1'".to_string()
-            ),
         ];
 
         for test in unsupported_types {
diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs
index d52dd26819..fa6f586ad0 100644
--- a/rust/src/operations/merge.rs
+++ b/rust/src/operations/merge.rs
@@ -16,8 +16,9 @@
 //! ```rust ignore
 //! let table = open_table("../path/to/table")?;
 //! let (table, metrics) = DeltaOps(table)
-//!     .merge(source, col("id").eq(col("source.id")))
+//!     .merge(source, col("target.id").eq(col("source.id")))
 //!     .with_source_alias("source")
+//!     .with_target_alias("target")
 //!     .when_matched_update(|update| {
 //!         update
.update("value", col("source.value") + lit(1)) @@ -38,13 +39,14 @@ use std::time::{Instant, SystemTime, UNIX_EPOCH}; use arrow_schema::SchemaRef; use datafusion::error::Result as DataFusionResult; +use datafusion::logical_expr::build_join_schema; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::{ execution::context::SessionState, physical_plan::{ filter::FilterExec, joins::{ - utils::{build_join_schema, JoinFilter}, + utils::{build_join_schema as physical_build_join_schema, JoinFilter}, NestedLoopJoinExec, }, metrics::{MetricBuilder, MetricsSet}, @@ -53,7 +55,7 @@ use datafusion::{ }, prelude::{DataFrame, SessionContext}, }; -use datafusion_common::{Column, DFSchema, ScalarValue}; +use datafusion_common::{Column, DFField, DFSchema, ScalarValue, TableReference}; use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType}; use datafusion_physical_expr::{create_physical_expr, expressions, PhysicalExpr}; use futures::future::BoxFuture; @@ -62,7 +64,7 @@ use serde_json::{Map, Value}; use super::datafusion_utils::{into_expr, maybe_into_expr, Expression}; use super::transaction::commit; -use crate::delta_datafusion::expr::fmt_expr_to_sql; +use crate::delta_datafusion::expr::{fmt_expr_to_sql, parse_predicate_expression}; use crate::delta_datafusion::{parquet_scan_from_actions, register_store}; use crate::operations::datafusion_utils::MetricObserverExec; use crate::operations::write::write_execution_plan; @@ -89,13 +91,15 @@ pub struct MergeBuilder { /// The join predicate predicate: Expression, /// Operations to perform when a source record and target record match - match_operations: Vec, + match_operations: Vec, /// Operations to perform on source records when they do not pair with a target record - not_match_operations: Vec, + not_match_operations: Vec, /// Operations to perform on target records when they do not pair with a source record - not_match_source_operations: Vec, + not_match_source_operations: Vec, ///Prefix the source columns with a user provided prefix source_alias: Option, + ///Prefix target columns with a user provided prefix + target_alias: Option, /// A snapshot of the table's state. 
     /// A snapshot of the table's state. AKA the target table in the operation
     snapshot: DeltaTableState,
     /// The source data
@@ -127,6 +131,7 @@
             snapshot,
             object_store,
             source_alias: None,
+            target_alias: None,
             state: None,
             app_metadata: None,
             writer_properties: None,
@@ -150,8 +155,9 @@
     /// ```rust ignore
     /// let table = open_table("../path/to/table")?;
     /// let (table, metrics) = DeltaOps(table)
-    ///     .merge(source, col("id").eq(col("source.id")))
+    ///     .merge(source, col("target.id").eq(col("source.id")))
     ///     .with_source_alias("source")
+    ///     .with_target_alias("target")
     ///     .when_matched_update(|update| {
     ///         update
     ///             .predicate(col("source.value").lt(lit(0)))
@@ -170,13 +176,8 @@
         F: FnOnce(UpdateBuilder) -> UpdateBuilder,
     {
         let builder = builder(UpdateBuilder::default());
-        let op = MergeOperation::try_new(
-            &self.snapshot,
-            &self.state.as_ref(),
-            builder.predicate,
-            builder.updates,
-            OperationType::Update,
-        )?;
+        let op =
+            MergeOperationConfig::new(builder.predicate, builder.updates, OperationType::Update)?;
         self.match_operations.push(op);
         Ok(self)
     }
@@ -192,8 +193,9 @@
     /// ```rust ignore
     /// let table = open_table("../path/to/table")?;
     /// let (table, metrics) = DeltaOps(table)
-    ///     .merge(source, col("id").eq(col("source.id")))
+    ///     .merge(source, col("target.id").eq(col("source.id")))
     ///     .with_source_alias("source")
+    ///     .with_target_alias("target")
     ///     .when_matched_delete(|delete| {
     ///         delete.predicate(col("source.delete"))
     ///     })?
@@ -204,9 +206,7 @@
         F: FnOnce(DeleteBuilder) -> DeleteBuilder,
     {
         let builder = builder(DeleteBuilder::default());
-        let op = MergeOperation::try_new(
-            &self.snapshot,
-            &self.state.as_ref(),
+        let op = MergeOperationConfig::new(
             builder.predicate,
             HashMap::default(),
             OperationType::Delete,
@@ -226,8 +226,9 @@
     /// ```rust ignore
     /// let table = open_table("../path/to/table")?;
     /// let (table, metrics) = DeltaOps(table)
-    ///     .merge(source, col("id").eq(col("source.id")))
+    ///     .merge(source, col("target.id").eq(col("source.id")))
     ///     .with_source_alias("source")
+    ///     .with_target_alias("target")
     ///     .when_not_matched_insert(|insert| {
     ///         insert
     ///             .set("id", col("source.id"))
@@ -241,13 +242,7 @@
         F: FnOnce(InsertBuilder) -> InsertBuilder,
     {
         let builder = builder(InsertBuilder::default());
-        let op = MergeOperation::try_new(
-            &self.snapshot,
-            &self.state.as_ref(),
-            builder.predicate,
-            builder.set,
-            OperationType::Insert,
-        )?;
+        let op = MergeOperationConfig::new(builder.predicate, builder.set, OperationType::Insert)?;
         self.not_match_operations.push(op);
         Ok(self)
     }
@@ -266,8 +261,9 @@
     /// ```rust ignore
     /// let table = open_table("../path/to/table")?;
     /// let (table, metrics) = DeltaOps(table)
-    ///     .merge(source, col("id").eq(col("source.id")))
+    ///     .merge(source, col("target.id").eq(col("source.id")))
     ///     .with_source_alias("source")
+    ///     .with_target_alias("target")
     ///     .when_not_matched_by_source_update(|update| {
     ///         update
     ///             .update("active", lit(false))
@@ -280,13 +276,8 @@
         F: FnOnce(UpdateBuilder) -> UpdateBuilder,
     {
         let builder = builder(UpdateBuilder::default());
-        let op = MergeOperation::try_new(
-            &self.snapshot,
-            &self.state.as_ref(),
-            builder.predicate,
-            builder.updates,
-            OperationType::Update,
-        )?;
+        let op =
+            MergeOperationConfig::new(builder.predicate, builder.updates, OperationType::Update)?;
        self.not_match_source_operations.push(op);
        Ok(self)
    }
@@ -302,8 +293,9 @@ impl MergeBuilder {
     /// ```rust ignore
     /// let table = open_table("../path/to/table")?;
     /// let (table, metrics) = DeltaOps(table)
-    ///     .merge(source, col("id").eq(col("source.id")))
+    ///     .merge(source, col("target.id").eq(col("source.id")))
     ///     .with_source_alias("source")
+    ///     .with_target_alias("target")
     ///     .when_not_matched_by_source_delete(|delete| {
     ///         delete
     ///     })?
@@ -314,9 +306,7 @@
         F: FnOnce(DeleteBuilder) -> DeleteBuilder,
     {
         let builder = builder(DeleteBuilder::default());
-        let op = MergeOperation::try_new(
-            &self.snapshot,
-            &self.state.as_ref(),
+        let op = MergeOperationConfig::new(
             builder.predicate,
             HashMap::default(),
             OperationType::Delete,
@@ -331,6 +321,12 @@
         self
     }
 
+    /// Rename columns in the target dataset to have a prefix of `alias`.`original column name`
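+    ///
+    /// e.g. after `.with_target_alias("target")`, join predicates and update
+    /// expressions can refer to target columns as `target.<column>`.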
+    pub fn with_target_alias<S: ToString>(mut self, alias: S) -> Self {
+        self.target_alias = Some(alias.to_string());
+        self
+    }
+
     /// The Datafusion session state to use
     pub fn with_session_state(mut self, state: SessionState) -> Self {
         self.state = Some(state);
@@ -443,6 +439,15 @@ enum OperationType {
     Copy,
 }
 
+// Encapsulate the user's merge configuration for later processing
+struct MergeOperationConfig {
+    /// Which records to update
+    predicate: Option<Expression>,
+    /// How to update columns in a record that match the predicate
+    operations: HashMap<Column, Expression>,
+    r#type: OperationType,
+}
+
 struct MergeOperation {
     /// Which records to update
     predicate: Option<Expr>,
@@ -452,28 +457,71 @@ struct MergeOperation {
 }
 
 impl MergeOperation {
-    pub fn try_new(
-        snapshot: &DeltaTableState,
-        state: &Option<&SessionState>,
+    fn try_from(
+        config: MergeOperationConfig,
+        schema: &DFSchema,
+        state: &SessionState,
+        target_alias: &Option<String>,
+    ) -> DeltaResult<MergeOperation> {
+        let mut ops = HashMap::with_capacity(config.operations.capacity());
+
+        for (column, expression) in config.operations.into_iter() {
+            // Normalize the column name to contain the target alias. If a table reference was provided ensure it's the target.
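+            // e.g. with a target alias of "t": `value` -> `t.value`, `t.value` is
+            // kept as-is, and any other qualifier such as `s.value` is an error.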
+            let column = match target_alias {
+                Some(alias) => {
+                    let r = TableReference::bare(alias.to_owned());
+                    match column {
+                        Column {
+                            relation: None,
+                            name,
+                        } => Column {
+                            relation: Some(r),
+                            name,
+                        },
+                        Column {
+                            relation: Some(TableReference::Bare { table }),
+                            name,
+                        } => {
+                            if table.eq(alias) {
+                                Column {
+                                    relation: Some(r),
+                                    name,
+                                }
+                            } else {
+                                return Err(DeltaTableError::Generic(
+                                    format!("Table alias '{table}' in column reference '{table}.{name}' unknown. Hint: You must reference the Delta Table with alias '{alias}'.")
+                                ));
+                            }
+                        }
+                        _ => {
+                            return Err(DeltaTableError::Generic(
+                                "Column must reference column in Delta table".into(),
+                            ))
+                        }
+                    }
+                }
+                None => column,
+            };
+            ops.insert(column, into_expr(expression, schema, state)?);
+        }
+
+        Ok(MergeOperation {
+            predicate: maybe_into_expr(config.predicate, schema, state)?,
+            operations: ops,
+            r#type: config.r#type,
+        })
+    }
+}
+
+impl MergeOperationConfig {
+    pub fn new(
         predicate: Option<Expression>,
         operations: HashMap<Column, Expression>,
         r#type: OperationType,
     ) -> DeltaResult<Self> {
-        let context = SessionContext::new();
-        let mut s = &context.state();
-        if let Some(df_state) = state {
-            s = df_state;
-        }
-        let predicate = maybe_into_expr(predicate, snapshot, s)?;
-        let mut _operations = HashMap::new();
-
-        for (column, expr) in operations {
-            _operations.insert(column, into_expr(expr, snapshot, s)?);
-        }
-
-        Ok(MergeOperation {
+        Ok(MergeOperationConfig {
             predicate,
-            operations: _operations,
+            operations,
             r#type,
         })
     }
 }
@@ -517,9 +565,10 @@ async fn execute(
     app_metadata: Option<Map<String, Value>>,
     safe_cast: bool,
     source_alias: Option<String>,
-    match_operations: Vec<MergeOperation>,
-    not_match_target_operations: Vec<MergeOperation>,
-    not_match_source_operations: Vec<MergeOperation>,
+    target_alias: Option<String>,
+    match_operations: Vec<MergeOperationConfig>,
+    not_match_target_operations: Vec<MergeOperationConfig>,
+    not_match_source_operations: Vec<MergeOperationConfig>,
 ) -> DeltaResult<((Vec<Action>, i64), MergeMetrics)> {
     let mut metrics = MergeMetrics::default();
     let exec_start = Instant::now();
@@ -528,11 +577,6 @@
         .current_metadata()
         .ok_or(DeltaTableError::NoMetadata)?;
 
-    let predicate = match predicate {
-        Expression::DataFusion(expr) => expr,
-        Expression::String(s) => snapshot.parse_predicate_expression(s, &state)?,
-    };
-
     let schema = snapshot.input_schema()?;
 
     // TODO: Given the join predicate, remove any expression that involve the
@@ -567,16 +611,10 @@
     let mut expressions: Vec<(Arc<dyn PhysicalExpr>, String)> = Vec::new();
     let source_schema = source_count.schema();
 
-    let source_prefix = source_alias
-        .map(|mut s| {
-            s.push('.');
-            s
-        })
-        .unwrap_or_default();
     for (i, field) in source_schema.fields().into_iter().enumerate() {
         expressions.push((
             Arc::new(expressions::Column::new(field.name(), i)),
-            source_prefix.clone() + field.name(),
+            field.name().clone(),
         ));
     }
     expressions.push((
@@ -607,15 +645,54 @@
     let target = Arc::new(CoalescePartitionsExec::new(target));
     let source = Arc::new(CoalescePartitionsExec::new(source));
 
-    let join_schema = build_join_schema(&source.schema(), &target.schema(), &JoinType::Full);
+    let source_schema = match &source_alias {
+        Some(alias) => {
+            DFSchema::try_from_qualified_schema(TableReference::bare(alias), &source.schema())?
+        }
+        None => DFSchema::try_from(source.schema().as_ref().to_owned())?,
+    };
+
+    let target_schema = match &target_alias {
+        Some(alias) => {
+            DFSchema::try_from_qualified_schema(TableReference::bare(alias), &target.schema())?
+        }
+        None => DFSchema::try_from(target.schema().as_ref().to_owned())?,
+    };
+
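+    // When aliases are set, both logical schemas carry qualified field names
+    // (e.g. `source.id` vs. `target.id`), which is what lets string predicates
+    // such as "target.id = source.id" resolve unambiguously.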
+    let join_schema_df = build_join_schema(&source_schema, &target_schema, &JoinType::Full)?;
+
+    let join_schema =
+        physical_build_join_schema(&source.schema(), &target.schema(), &JoinType::Full);
+    let (join_schema, join_order) = (join_schema.0, join_schema.1);
+
+    let predicate = match predicate {
+        Expression::DataFusion(expr) => expr,
+        Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?,
+    };
+
+    let match_operations: Vec<MergeOperation> = match_operations
+        .into_iter()
+        .map(|op| MergeOperation::try_from(op, &join_schema_df, &state, &target_alias))
+        .collect::<Result<Vec<MergeOperation>, DeltaTableError>>()?;
+
+    let not_match_target_operations: Vec<MergeOperation> = not_match_target_operations
+        .into_iter()
+        .map(|op| MergeOperation::try_from(op, &join_schema_df, &state, &target_alias))
+        .collect::<Result<Vec<MergeOperation>, DeltaTableError>>()?;
+
+    let not_match_source_operations: Vec<MergeOperation> = not_match_source_operations
+        .into_iter()
+        .map(|op| MergeOperation::try_from(op, &join_schema_df, &state, &target_alias))
+        .collect::<Result<Vec<MergeOperation>, DeltaTableError>>()?;
+
     let predicate_expr = create_physical_expr(
         &predicate,
-        &join_schema.0.clone().try_into()?,
-        &join_schema.0,
+        &join_schema_df,
+        &join_schema,
         state.execution_props(),
     )?;
 
-    let join_filter = JoinFilter::new(predicate_expr, join_schema.1, join_schema.0);
+    let join_filter = JoinFilter::new(predicate_expr, join_order, join_schema);
     let join: Arc<dyn ExecutionPlan> = Arc::new(NestedLoopJoinExec::try_new(
         source.clone(),
         target.clone(),
@@ -740,13 +817,21 @@
 
     let case = create_physical_expr(
         &case,
-        &join.schema().as_ref().to_owned().try_into()?,
+        &join_schema_df,
         &join.schema(),
         state.execution_props(),
     )?;
     expressions.push((case, OPERATION_COLUMN.to_owned()));
     let projection = Arc::new(ProjectionExec::try_new(expressions, join.clone())?);
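+    // Track the logical schema of the projection by hand: deriving it from the
+    // physical plan (as the previous `try_into` calls did) would drop the
+    // source/target qualifiers that later expression resolution depends on.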
+    let mut f = join_schema_df.fields().to_owned();
+    f.push(DFField::new_unqualified(
+        OPERATION_COLUMN,
+        arrow_schema::DataType::Int64,
+        false,
+    ));
+    let project_schema_df = DFSchema::new_with_metadata(f, HashMap::new())?;
+
     // Project again and include the original table schema plus a column to mark if row needs to be filtered before write
     let mut expressions: Vec<(Arc<dyn PhysicalExpr>, String)> = Vec::new();
     let schema = projection.schema();
@@ -758,15 +843,27 @@
     }
 
     let mut projection_map = HashMap::new();
-    for field in snapshot.schema().unwrap().get_fields() {
+    let mut f = project_schema_df.fields().clone();
+
+    for delta_field in snapshot.schema().unwrap().get_fields() {
         let mut when_expr = Vec::with_capacity(operations_size);
         let mut then_expr = Vec::with_capacity(operations_size);
 
+        let qualifier = match &target_alias {
+            Some(alias) => Some(TableReference::Bare {
+                table: alias.to_owned().into(),
+            }),
+            None => TableReference::none(),
+        };
+        let name = delta_field.get_name();
+        let column = Column::new(qualifier.clone(), name);
+        let field = project_schema_df.field_with_name(qualifier.as_ref(), name)?;
+
         for (idx, (operations, _)) in ops.iter().enumerate() {
             let op = operations
-                .get(&field.get_name().to_owned().into())
+                .get(&column)
                 .map(|expr| expr.to_owned())
-                .unwrap_or(col(field.get_name()));
+                .unwrap_or_else(|| col(column.clone()));
 
             when_expr.push(lit(idx as i32));
             then_expr.push(op);
@@ -782,13 +879,20 @@
 
         let case = create_physical_expr(
             &case,
-            &projection.schema().as_ref().to_owned().try_into()?,
+            &project_schema_df,
             &projection.schema(),
             state.execution_props(),
         )?;
 
-        projection_map.insert(field.get_name(), expressions.len());
-        expressions.push((case, "__delta_rs_c_".to_owned() + field.get_name()));
+        projection_map.insert(delta_field.get_name(), expressions.len());
+        let name = "__delta_rs_c_".to_owned() + delta_field.get_name();
+
+        f.push(DFField::new_unqualified(
+            &name,
+            field.data_type().clone(),
+            true,
+        ));
+        expressions.push((case, name));
     }
 
     let mut insert_when = Vec::with_capacity(ops.len());
@@ -873,7 +977,7 @@
     }
 
     let schema = projection.schema();
-    let input_dfschema = schema.as_ref().to_owned().try_into()?;
+    let input_dfschema = project_schema_df;
     expressions.push((
         build_case(
             delete_when,
@@ -884,6 +988,11 @@
         )?,
         DELETE_COLUMN.to_owned(),
     ));
+    f.push(DFField::new_unqualified(
+        DELETE_COLUMN,
+        arrow_schema::DataType::Boolean,
+        true,
+    ));
 
     expressions.push((
         build_case(
@@ -895,6 +1004,12 @@
         )?,
         TARGET_INSERT_COLUMN.to_owned(),
     ));
+    f.push(DFField::new_unqualified(
+        TARGET_INSERT_COLUMN,
+        arrow_schema::DataType::Boolean,
+        true,
+    ));
+
     expressions.push((
         build_case(
             update_when,
@@ -905,6 +1020,12 @@
         )?,
         TARGET_UPDATE_COLUMN.to_owned(),
     ));
+    f.push(DFField::new_unqualified(
+        TARGET_UPDATE_COLUMN,
+        arrow_schema::DataType::Boolean,
+        true,
+    ));
+
     expressions.push((
         build_case(
             target_delete_when,
@@ -915,6 +1036,12 @@
         )?,
         TARGET_DELETE_COLUMN.to_owned(),
     ));
+    f.push(DFField::new_unqualified(
+        TARGET_DELETE_COLUMN,
+        arrow_schema::DataType::Boolean,
+        true,
+    ));
+
     expressions.push((
         build_case(
             copy_when,
@@ -925,6 +1052,11 @@
         )?,
         TARGET_COPY_COLUMN.to_owned(),
     ));
+    f.push(DFField::new_unqualified(
+        TARGET_COPY_COLUMN,
+        arrow_schema::DataType::Boolean,
+        true,
+    ));
 
     let projection = Arc::new(ProjectionExec::try_new(expressions, projection.clone())?);
 
@@ -963,9 +1095,11 @@
         );
     }));
 
+    let write_schema_df = DFSchema::new_with_metadata(f, HashMap::new())?;
+
     let write_predicate = create_physical_expr(
         &(col(DELETE_COLUMN).is_false()),
-        &target_count_plan.schema().as_ref().to_owned().try_into()?,
+        &write_schema_df,
         &target_count_plan.schema(),
         state.execution_props(),
     )?;
@@ -1095,6 +1229,7 @@
                 this.app_metadata,
                 this.safe_cast,
                 this.source_alias,
+                this.target_alias,
                 this.match_operations,
                 this.not_match_operations,
                 this.not_match_source_operations,
@@ -1122,12 +1257,15 @@ mod tests {
     use arrow::datatypes::Schema as ArrowSchema;
     use arrow::record_batch::RecordBatch;
     use datafusion::assert_batches_sorted_eq;
+    use datafusion::prelude::DataFrame;
     use datafusion::prelude::SessionContext;
     use datafusion_expr::col;
     use datafusion_expr::lit;
     use serde_json::json;
     use std::sync::Arc;
 
+    use super::MergeMetrics;
+
     async fn setup_table(partitions: Option<Vec<String>>) -> DeltaTable {
         let table_schema = get_delta_schema();
 
@@ -1164,8 +1302,7 @@
         .unwrap()
     }
 
-    #[tokio::test]
-    async fn test_merge() {
+    async fn setup() -> (DeltaTable, DataFrame) {
         let schema = get_arrow_schema(&None);
         let table = setup_table(None).await;
 
@@ -1188,10 +1325,44 @@
         )
         .unwrap();
         let source = ctx.read_batch(batch).unwrap();
+        (table, source)
+    }
+
+    async fn assert_merge(table: DeltaTable, metrics: MergeMetrics) {
+        assert_eq!(table.version(), 2);
+        assert_eq!(table.get_file_uris().count(), 1);
+        assert_eq!(metrics.num_target_files_added, 1);
+        assert_eq!(metrics.num_target_files_removed, 1);
+        assert_eq!(metrics.num_target_rows_copied, 1);
+        assert_eq!(metrics.num_target_rows_updated, 3);
+        assert_eq!(metrics.num_target_rows_inserted, 1);
+        assert_eq!(metrics.num_target_rows_deleted, 0);
+        assert_eq!(metrics.num_output_rows, 5);
+        assert_eq!(metrics.num_source_rows, 3);
+
+        let expected = vec![
+            "+----+-------+------------+",
+            "| id | value | modified   |",
+            "+----+-------+------------+",
+            "| A  | 2     | 2021-02-01 |",
+            "| B  | 10    | 2021-02-02 |",
+            "| C  | 20    | 2023-07-04 |",
+            "| D  | 100   | 2021-02-02 |",
+            "| X  | 30    | 2023-07-04 |",
+            "+----+-------+------------+",
+        ];
+        let actual = get_data(&table).await;
+        assert_batches_sorted_eq!(&expected, &actual);
+    }
+
+    #[tokio::test]
+    async fn test_merge() {
+        let (table, source) = setup().await;
 
         let (mut table, metrics) = DeltaOps(table)
-            .merge(source, col("id").eq(col("source.id")))
+            .merge(source, col("target.id").eq(col("source.id")))
             .with_source_alias("source")
+            .with_target_alias("target")
             .when_matched_update(|update| {
                 update
                     .update("value", col("source.value"))
@@ -1200,8 +1371,8 @@
             .unwrap()
             .when_not_matched_by_source_update(|update| {
                 update
-                    .predicate(col("value").eq(lit(1)))
-                    .update("value", col("value") + lit(1))
+                    .predicate(col("target.value").eq(lit(1)))
+                    .update("value", col("target.value") + lit(1))
             })
             .unwrap()
             .when_not_matched_insert(|insert| {
@@ -1214,21 +1385,62 @@
             .await
             .unwrap();
 
-        assert_eq!(table.version(), 2);
-        assert_eq!(table.get_file_uris().count(), 1);
-        assert_eq!(metrics.num_target_files_added, 1);
-        assert_eq!(metrics.num_target_files_removed, 1);
-        assert_eq!(metrics.num_target_rows_copied, 1);
-        assert_eq!(metrics.num_target_rows_updated, 3);
-        assert_eq!(metrics.num_target_rows_inserted, 1);
-        assert_eq!(metrics.num_target_rows_deleted, 0);
-        assert_eq!(metrics.num_output_rows, 5);
-        assert_eq!(metrics.num_source_rows, 3);
+        let commit_info = table.history(None).await.unwrap();
+        let last_commit = &commit_info[commit_info.len() - 1];
+        let parameters = last_commit.operation_parameters.clone().unwrap();
+        assert_eq!(parameters["predicate"], json!("target.id = source.id"));
+        assert_eq!(
+            parameters["matchedPredicates"],
+            json!(r#"[{"actionType":"update"}]"#)
+        );
+        assert_eq!(
+            parameters["notMatchedPredicates"],
+            json!(r#"[{"actionType":"insert"}]"#)
+        );
+        assert_eq!(
+            parameters["notMatchedBySourcePredicates"],
+            json!(r#"[{"actionType":"update","predicate":"target.value = 1"}]"#)
+        );
+
+        assert_merge(table, metrics).await;
+    }
+
+    #[tokio::test]
+    async fn test_merge_str() {
+        // Validate that users can use string predicates
+        // Also validates that update and set operations can contain the target alias
+        let (table, source) = setup().await;
+
+        let (mut table, metrics) = DeltaOps(table)
+            .merge(source, "target.id = source.id")
+            .with_source_alias("source")
+            .with_target_alias("target")
+            .when_matched_update(|update| {
+                update
+                    .update("target.value", "source.value")
+                    .update("modified", "source.modified")
+            })
+            .unwrap()
+            .when_not_matched_by_source_update(|update| {
+                update
+                    .predicate("target.value = arrow_cast(1, 'Int32')")
+                    .update("value", "target.value + cast(1 as int)")
+            })
+            .unwrap()
+            .when_not_matched_insert(|insert| {
+                insert
+                    .set("target.id", "source.id")
+                    .set("value", "source.value")
+                    .set("modified", "source.modified")
+            })
+            .unwrap()
+            .await
+            .unwrap();
 
         let commit_info = table.history(None).await.unwrap();
         let last_commit = &commit_info[commit_info.len() - 1];
         let parameters = last_commit.operation_parameters.clone().unwrap();
-        assert_eq!(parameters["predicate"], json!("id = source.id"));
assert_eq!(parameters["predicate"], json!("target.id = source.id")); assert_eq!( parameters["matchedPredicates"], json!(r#"[{"actionType":"update"}]"#) @@ -1239,22 +1451,127 @@ mod tests { ); assert_eq!( parameters["notMatchedBySourcePredicates"], - json!(r#"[{"actionType":"update","predicate":"value = 1"}]"#) + json!( + r#"[{"actionType":"update","predicate":"target.value = arrow_cast(1, 'Int32')"}]"# + ) ); - let expected = vec![ - "+----+-------+------------+", - "| id | value | modified |", - "+----+-------+------------+", - "| A | 2 | 2021-02-01 |", - "| B | 10 | 2021-02-02 |", - "| C | 20 | 2023-07-04 |", - "| D | 100 | 2021-02-02 |", - "| X | 30 | 2023-07-04 |", - "+----+-------+------------+", - ]; - let actual = get_data(&table).await; - assert_batches_sorted_eq!(&expected, &actual); + assert_merge(table, metrics).await; + } + + #[tokio::test] + async fn test_merge_no_alias() { + // Validate merge can be used without specifying an alias + let (table, source) = setup().await; + + let source = source + .with_column_renamed("id", "source_id") + .unwrap() + .with_column_renamed("value", "source_value") + .unwrap() + .with_column_renamed("modified", "source_modified") + .unwrap(); + + let (table, metrics) = DeltaOps(table) + .merge(source, "id = source_id") + .when_matched_update(|update| { + update + .update("value", "source_value") + .update("modified", "source_modified") + }) + .unwrap() + .when_not_matched_by_source_update(|update| { + update + .predicate("value = arrow_cast(1, 'Int32')") + .update("value", "value + cast(1 as int)") + }) + .unwrap() + .when_not_matched_insert(|insert| { + insert + .set("id", "source_id") + .set("value", "source_value") + .set("modified", "source_modified") + }) + .unwrap() + .await + .unwrap(); + + assert_merge(table, metrics).await; + } + + #[tokio::test] + async fn test_merge_with_alias_mix() { + // Validate merge can be used with an alias and unambiguous column references + // I.E users should be able to specify an alias and still reference columns without using that alias when there is no ambiguity + let (table, source) = setup().await; + + let source = source + .with_column_renamed("id", "source_id") + .unwrap() + .with_column_renamed("value", "source_value") + .unwrap() + .with_column_renamed("modified", "source_modified") + .unwrap(); + + let (table, metrics) = DeltaOps(table) + .merge(source, "id = source_id") + .with_target_alias("target") + .when_matched_update(|update| { + update + .update("value", "source_value") + .update("modified", "source_modified") + }) + .unwrap() + .when_not_matched_by_source_update(|update| { + update + .predicate("value = arrow_cast(1, 'Int32')") + .update("value", "target.value + cast(1 as int)") + }) + .unwrap() + .when_not_matched_insert(|insert| { + insert + .set("id", "source_id") + .set("target.value", "source_value") + .set("modified", "source_modified") + }) + .unwrap() + .await + .unwrap(); + + assert_merge(table, metrics).await; + } + + #[tokio::test] + async fn test_merge_failures() { + // Validate target columns cannot be from the source + let (table, source) = setup().await; + let res = DeltaOps(table) + .merge(source, col("target.id").eq(col("source.id"))) + .with_source_alias("source") + .with_target_alias("target") + .when_matched_update(|update| { + update + .update("source.value", "source.value") + .update("modified", "source.modified") + }) + .unwrap() + .await; + assert!(res.is_err()); + + // Validate failure when aliases are the same + let (table, source) = setup().await; + let res = 
+            .merge(source, col("target.id").eq(col("source.id")))
+            .with_source_alias("source")
+            .with_target_alias("target")
+            .when_matched_update(|update| {
+                update
+                    .update("source.value", "source.value")
+                    .update("modified", "source.modified")
+            })
+            .unwrap()
+            .await;
+        assert!(res.is_err());
+
+        // Validate failure when aliases are the same
+        let (table, source) = setup().await;
+        let res = DeltaOps(table)
+            .merge(source, col("target.id").eq(col("source.id")))
+            .with_source_alias("source")
+            .with_target_alias("source")
+            .when_matched_update(|update| {
+                update
+                    .update("target.value", "source.value")
+                    .update("modified", "source.modified")
+            })
+            .unwrap()
+            .await;
+        assert!(res.is_err())
     }
 
     #[tokio::test]
@@ -1286,11 +1603,12 @@
         let (table, metrics) = DeltaOps(table)
             .merge(
                 source,
-                col("id")
+                col("target.id")
                     .eq(col("source.id"))
-                    .and(col("modified").eq(lit("2021-02-02"))),
+                    .and(col("target.modified").eq(lit("2021-02-02"))),
             )
             .with_source_alias("source")
+            .with_target_alias("target")
             .when_matched_update(|update| {
                 update
                     .update("value", col("source.value"))
@@ -1299,14 +1617,14 @@
             .unwrap()
             .when_not_matched_by_source_update(|update| {
                 update
-                    .predicate(col("value").eq(lit(1)))
-                    .update("value", col("value") + lit(1))
+                    .predicate(col("target.value").eq(lit(1)))
+                    .update("value", col("target.value") + lit(1))
             })
             .unwrap()
             .when_not_matched_by_source_update(|update| {
                 update
-                    .predicate(col("modified").eq(lit("2021-02-01")))
-                    .update("value", col("value") - lit(1))
+                    .predicate(col("target.modified").eq(lit("2021-02-01")))
+                    .update("value", col("target.value") - lit(1))
             })
             .unwrap()
             .when_not_matched_insert(|insert| {
@@ -1374,8 +1692,9 @@
         let source = ctx.read_batch(batch).unwrap();
 
         let (mut table, metrics) = DeltaOps(table)
-            .merge(source, col("id").eq(col("source.id")))
+            .merge(source, col("target.id").eq(col("source.id")))
             .with_source_alias("source")
+            .with_target_alias("target")
             .when_matched_delete(|delete| delete)
             .unwrap()
             .await
@@ -1395,7 +1714,7 @@
         let commit_info = table.history(None).await.unwrap();
         let last_commit = &commit_info[commit_info.len() - 1];
         let parameters = last_commit.operation_parameters.clone().unwrap();
-        assert_eq!(parameters["predicate"], json!("id = source.id"));
+        assert_eq!(parameters["predicate"], json!("target.id = source.id"));
         assert_eq!(
             parameters["matchedPredicates"],
             json!(r#"[{"actionType":"delete"}]"#)
@@ -1437,8 +1756,9 @@
         let source = ctx.read_batch(batch).unwrap();
 
         let (mut table, metrics) = DeltaOps(table)
-            .merge(source, col("id").eq(col("source.id")))
+            .merge(source, col("target.id").eq(col("source.id")))
             .with_source_alias("source")
+            .with_target_alias("target")
             .when_matched_delete(|delete| delete.predicate(col("source.value").lt_eq(lit(10))))
             .unwrap()
             .await
@@ -1458,7 +1778,7 @@
         let commit_info = table.history(None).await.unwrap();
         let last_commit = &commit_info[commit_info.len() - 1];
         let parameters = last_commit.operation_parameters.clone().unwrap();
-        assert_eq!(parameters["predicate"], json!("id = source.id"));
+        assert_eq!(parameters["predicate"], json!("target.id = source.id"));
         assert_eq!(
             parameters["matchedPredicates"],
             json!(r#"[{"actionType":"delete","predicate":"source.value <= 10"}]"#)
@@ -1505,8 +1825,9 @@
         let source = ctx.read_batch(batch).unwrap();
 
         let (mut table, metrics) = DeltaOps(table)
-            .merge(source, col("id").eq(col("source.id")))
+            .merge(source, col("target.id").eq(col("source.id")))
             .with_source_alias("source")
+            .with_target_alias("target")
             .when_not_matched_by_source_delete(|delete| delete)
             .unwrap()
             .await
@@ -1526,7 +1847,7 @@
         let commit_info = table.history(None).await.unwrap();
        let last_commit = &commit_info[commit_info.len() - 1];
        let parameters = last_commit.operation_parameters.clone().unwrap();
-        assert_eq!(parameters["predicate"], json!("id = source.id"));
+ assert_eq!(parameters["predicate"], json!("target.id = source.id")); assert_eq!( parameters["notMatchedBySourcePredicates"], json!(r#"[{"actionType":"delete"}]"#) @@ -1567,10 +1888,11 @@ mod tests { let source = ctx.read_batch(batch).unwrap(); let (mut table, metrics) = DeltaOps(table) - .merge(source, col("id").eq(col("source.id"))) + .merge(source, col("target.id").eq(col("source.id"))) .with_source_alias("source") + .with_target_alias("target") .when_not_matched_by_source_delete(|delete| { - delete.predicate(col("modified").gt(lit("2021-02-01"))) + delete.predicate(col("target.modified").gt(lit("2021-02-01"))) }) .unwrap() .await @@ -1590,10 +1912,10 @@ mod tests { let commit_info = table.history(None).await.unwrap(); let last_commit = &commit_info[commit_info.len() - 1]; let parameters = last_commit.operation_parameters.clone().unwrap(); - assert_eq!(parameters["predicate"], json!("id = source.id")); + assert_eq!(parameters["predicate"], json!("target.id = source.id")); assert_eq!( parameters["notMatchedBySourcePredicates"], - json!(r#"[{"actionType":"delete","predicate":"modified > '2021-02-01'"}]"#) + json!(r#"[{"actionType":"delete","predicate":"target.modified > '2021-02-01'"}]"#) ); let expected = vec![ diff --git a/rust/src/operations/mod.rs b/rust/src/operations/mod.rs index c07b81438b..c15bb8052e 100644 --- a/rust/src/operations/mod.rs +++ b/rust/src/operations/mod.rs @@ -211,10 +211,11 @@ mod datafusion_utils { metrics::{ExecutionPlanMetricsSet, MetricsSet}, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, }; + use datafusion_common::DFSchema; use datafusion_expr::Expr; use futures::{Stream, StreamExt}; - use crate::{table::state::DeltaTableState, DeltaResult}; + use crate::{delta_datafusion::expr::parse_predicate_expression, DeltaResult}; /// Used to represent user input of either a Datafusion expression or string expression pub enum Expression { @@ -243,22 +244,22 @@ mod datafusion_utils { pub(crate) fn into_expr( expr: Expression, - snapshot: &DeltaTableState, + schema: &DFSchema, df_state: &SessionState, ) -> DeltaResult { match expr { Expression::DataFusion(expr) => Ok(expr), - Expression::String(s) => snapshot.parse_predicate_expression(s, df_state), + Expression::String(s) => parse_predicate_expression(schema, s, df_state), } } pub(crate) fn maybe_into_expr( expr: Option, - snapshot: &DeltaTableState, + schema: &DFSchema, df_state: &SessionState, ) -> DeltaResult> { Ok(match expr { - Some(predicate) => Some(into_expr(predicate, snapshot, df_state)?), + Some(predicate) => Some(into_expr(predicate, schema, df_state)?), None => None, }) } diff --git a/rust/src/operations/transaction/state.rs b/rust/src/operations/transaction/state.rs index 5924609fb7..32c386cbdc 100644 --- a/rust/src/operations/transaction/state.rs +++ b/rust/src/operations/transaction/state.rs @@ -8,18 +8,14 @@ use datafusion::datasource::physical_plan::wrap_partition_type_in_dict; use datafusion::execution::context::SessionState; use datafusion::optimizer::utils::conjunction; use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; -use datafusion_common::config::ConfigOptions; use datafusion_common::scalar::ScalarValue; -use datafusion_common::{Column, DFSchema, Result as DFResult, TableReference}; -use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource}; -use datafusion_sql::planner::{ContextProvider, SqlToRel}; +use datafusion_common::{Column, DFSchema}; +use datafusion_expr::Expr; use itertools::Either; use object_store::ObjectStore; use 
 use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStreamBuilder};
-use sqlparser::dialect::GenericDialect;
-use sqlparser::parser::Parser;
-use sqlparser::tokenizer::Tokenizer;
 
+use crate::delta_datafusion::expr::parse_predicate_expression;
 use crate::delta_datafusion::{
     get_null_of_arrow_type, logical_expr_to_physical_expr, to_correct_scalar_value,
 };
@@ -110,26 +106,8 @@ impl DeltaTableState {
         expr: impl AsRef<str>,
         df_state: &SessionState,
     ) -> DeltaResult<Expr> {
-        let dialect = &GenericDialect {};
-        let mut tokenizer = Tokenizer::new(dialect, expr.as_ref());
-        let tokens = tokenizer
-            .tokenize()
-            .map_err(|err| DeltaTableError::GenericError {
-                source: Box::new(err),
-            })?;
-        let sql = Parser::new(dialect)
-            .with_tokens(tokens)
-            .parse_expr()
-            .map_err(|err| DeltaTableError::GenericError {
-                source: Box::new(err),
-            })?;
-
-        // TODO should we add the table name as qualifier when available?
-        let df_schema = DFSchema::try_from_qualified_schema("", self.arrow_schema()?.as_ref())?;
-        let context_provider = DeltaContextProvider { state: df_state };
-        let sql_to_rel = SqlToRel::new(&context_provider);
-
-        Ok(sql_to_rel.sql_to_expr(sql, &df_schema, &mut Default::default())?)
+        let schema = DFSchema::try_from(self.arrow_schema()?.as_ref().to_owned())?;
+        parse_predicate_expression(&schema, expr, df_state)
     }
 
     /// Get the physical table schema.
@@ -347,36 +325,6 @@ impl PruningStatistics for DeltaTableState {
     }
 }
 
-pub(crate) struct DeltaContextProvider<'a> {
-    state: &'a SessionState,
-}
-
-impl<'a> ContextProvider for DeltaContextProvider<'a> {
-    fn get_table_provider(&self, _name: TableReference) -> DFResult<Arc<dyn TableSource>> {
-        unimplemented!()
-    }
-
-    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
-        self.state.scalar_functions().get(name).cloned()
-    }
-
-    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
-        self.state.aggregate_functions().get(name).cloned()
-    }
-
-    fn get_variable_type(&self, _var: &[String]) -> Option<DataType> {
-        unimplemented!()
-    }
-
-    fn options(&self) -> &ConfigOptions {
-        self.state.config_options()
-    }
-
-    fn get_window_meta(&self, name: &str) -> Option<Arc<datafusion_expr::WindowUDF>> {
-        self.state.window_functions().get(name).cloned()
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;