From 8a66343aa361d396ce5a637f9d3ada5844758481 Mon Sep 17 00:00:00 2001
From: David Blajda
Date: Sun, 19 Nov 2023 18:52:55 -0500
Subject: [PATCH] refactor: merge to use logical plans (#1720)

# Description
This refactors the merge operation to use DataFusion's DataFrame and LogicalPlan APIs.
The nested loop join (NLJ) is eliminated and the query planner can pick the optimal join operator. This also enables the operation to use multiple threads and should result in a significant speed-up. Merge is still limited to a single thread in some areas.

When collecting benchmarks, I encountered multiple OoM issues with DataFusion's hash join implementation. There are multiple upstream tickets open regarding this. For now, I've limited the number of partitions to just 1 to prevent this.

Predicates passed as SQL are also easier to use now. Previously, manual casting was required to ensure data types were aligned; now the logical plan performs type coercion when the plan is optimized.

# Related Issues
- enhances #850
- closes #1790
- closes #1753
---
 .../src/delta_datafusion/logical.rs           |  48 ++
 .../src/delta_datafusion/mod.rs               |   9 +-
 .../src/delta_datafusion/physical.rs          | 180 +++++
 crates/deltalake-core/src/operations/merge.rs | 631 +++++++++---------
 crates/deltalake-core/src/operations/mod.rs   | 136 +---
 .../deltalake-core/src/operations/update.rs   |   5 +-
 6 files changed, 536 insertions(+), 473 deletions(-)
 create mode 100644 crates/deltalake-core/src/delta_datafusion/logical.rs
 create mode 100644 crates/deltalake-core/src/delta_datafusion/physical.rs

diff --git a/crates/deltalake-core/src/delta_datafusion/logical.rs b/crates/deltalake-core/src/delta_datafusion/logical.rs
new file mode 100644
index 0000000000..7b05dd57d9
--- /dev/null
+++ b/crates/deltalake-core/src/delta_datafusion/logical.rs
@@ -0,0 +1,48 @@
+//! Logical Operations for DataFusion
+
+use datafusion_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
+
+// Metric Observer is used to update DataFusion metrics from a record batch.
+// See MetricObserverExec for the physical implementation.
+
+#[derive(Debug, Hash, Eq, PartialEq)]
+pub(crate) struct MetricObserver {
+    // id is preserved during conversion to the physical node
+    pub id: String,
+    pub input: LogicalPlan,
+}
+
+impl UserDefinedLogicalNodeCore for MetricObserver {
+    // Predicate push down is not supported for this node. Try to limit usage
+    // near the end of the plan.
+ fn name(&self) -> &str { + "MetricObserver" + } + + fn inputs(&self) -> Vec<&datafusion_expr::LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &datafusion_common::DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "MetricObserver id={}", &self.id) + } + + fn from_template( + &self, + _exprs: &[datafusion_expr::Expr], + inputs: &[datafusion_expr::LogicalPlan], + ) -> Self { + MetricObserver { + id: self.id.clone(), + input: inputs[0].clone(), + } + } +} diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs index 1410efbfbc..8dea811383 100644 --- a/crates/deltalake-core/src/delta_datafusion/mod.rs +++ b/crates/deltalake-core/src/delta_datafusion/mod.rs @@ -81,6 +81,8 @@ use crate::{open_table, open_table_with_storage_options, DeltaTable}; const PATH_COLUMN: &str = "__delta_rs_path"; pub mod expr; +pub mod logical; +pub mod physical; impl From for DataFusionError { fn from(err: DeltaTableError) -> Self { @@ -351,7 +353,7 @@ pub(crate) fn logical_schema( snapshot: &DeltaTableState, scan_config: &DeltaScanConfig, ) -> DeltaResult { - let input_schema = snapshot.input_schema()?; + let input_schema = snapshot.arrow_schema()?; let mut fields = Vec::new(); for field in input_schema.fields.iter() { fields.push(field.to_owned()); @@ -505,11 +507,6 @@ impl<'a> DeltaScanBuilder<'a> { self } - pub fn with_schema(mut self, schema: SchemaRef) -> Self { - self.schema = Some(schema); - self - } - pub async fn build(self) -> DeltaResult { let config = self.config; let schema = match self.schema { diff --git a/crates/deltalake-core/src/delta_datafusion/physical.rs b/crates/deltalake-core/src/delta_datafusion/physical.rs new file mode 100644 index 0000000000..954df0b046 --- /dev/null +++ b/crates/deltalake-core/src/delta_datafusion/physical.rs @@ -0,0 +1,180 @@ +//! Physical Operations for DataFusion +use std::sync::Arc; + +use arrow_schema::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::error::Result as DataFusionResult; +use datafusion::physical_plan::DisplayAs; +use datafusion::physical_plan::{ + metrics::{ExecutionPlanMetricsSet, MetricsSet}, + ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, +}; +use futures::{Stream, StreamExt}; + +use crate::DeltaTableError; + +// Metric Observer is used to update DataFusion metrics from a record batch. 
+// Typically the null count for a particular column is pulled after performing a +// projection since this count is easy to obtain + +pub(crate) type MetricObserverFunction = fn(&RecordBatch, &ExecutionPlanMetricsSet) -> (); + +pub(crate) struct MetricObserverExec { + parent: Arc, + id: String, + metrics: ExecutionPlanMetricsSet, + update: MetricObserverFunction, +} + +impl MetricObserverExec { + pub fn new(id: String, parent: Arc, f: MetricObserverFunction) -> Self { + MetricObserverExec { + parent, + id, + metrics: ExecutionPlanMetricsSet::new(), + update: f, + } + } + + pub fn try_new( + id: String, + inputs: &[Arc], + f: MetricObserverFunction, + ) -> DataFusionResult> { + match inputs { + [input] => Ok(Arc::new(MetricObserverExec::new(id, input.clone(), f))), + _ => Err(datafusion_common::DataFusionError::External(Box::new( + DeltaTableError::Generic("MetricObserverExec expects only one child".into()), + ))), + } + } + + pub fn id(&self) -> &str { + &self.id + } +} + +impl std::fmt::Debug for MetricObserverExec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MetricObserverExec") + .field("id", &self.id) + .field("metrics", &self.metrics) + .finish() + } +} + +impl DisplayAs for MetricObserverExec { + fn fmt_as( + &self, + _: datafusion::physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "MetricObserverExec id={}", self.id) + } +} + +impl ExecutionPlan for MetricObserverExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow_schema::SchemaRef { + self.parent.schema() + } + + fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { + self.parent.output_partitioning() + } + + fn output_ordering(&self) -> Option<&[datafusion_physical_expr::PhysicalSortExpr]> { + self.parent.output_ordering() + } + + fn children(&self) -> Vec> { + vec![self.parent.clone()] + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> datafusion_common::Result { + let res = self.parent.execute(partition, context)?; + Ok(Box::pin(MetricObserverStream { + schema: self.schema(), + input: res, + metrics: self.metrics.clone(), + update: self.update, + })) + } + + fn statistics(&self) -> DataFusionResult { + self.parent.statistics() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + MetricObserverExec::try_new(self.id.clone(), &children, self.update) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +struct MetricObserverStream { + schema: SchemaRef, + input: SendableRecordBatchStream, + metrics: ExecutionPlanMetricsSet, + update: MetricObserverFunction, +} + +impl Stream for MetricObserverStream { + type Item = DataFusionResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.input.poll_next_unpin(cx).map(|x| match x { + Some(Ok(batch)) => { + (self.update)(&batch, &self.metrics); + Some(Ok(batch)) + } + other => other, + }) + } + + fn size_hint(&self) -> (usize, Option) { + self.input.size_hint() + } +} + +impl RecordBatchStream for MetricObserverStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +pub(crate) fn find_metric_node( + id: &str, + parent: &Arc, +) -> Option> { + //! 
Used to locate the physical MetricCountExec Node after the planner converts the logical node + if let Some(metric) = parent.as_any().downcast_ref::() { + if metric.id().eq(id) { + return Some(parent.to_owned()); + } + } + + for child in &parent.children() { + let res = find_metric_node(id, child); + if res.is_some() { + return res; + } + } + + None +} diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge.rs index a9ad6a8655..8b0dd56708 100644 --- a/crates/deltalake-core/src/operations/merge.rs +++ b/crates/deltalake-core/src/operations/merge.rs @@ -8,8 +8,7 @@ //! specified matter. See [`MergeBuilder`] for more information //! //! *WARNING* The current implementation rewrites the entire delta table so only -//! use on small to medium sized tables. The solution also cannot take advantage -//! of multiple threads and is limited to a single single thread. +//! use on small to medium sized tables. //! Enhancements tracked at #850 //! //! # Example @@ -37,27 +36,25 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; -use arrow_schema::SchemaRef; +use async_trait::async_trait; +use datafusion::datasource::provider_as_source; use datafusion::error::Result as DataFusionResult; +use datafusion::execution::context::{QueryPlanner, SessionConfig}; use datafusion::logical_expr::build_join_schema; -use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}; use datafusion::{ execution::context::SessionState, physical_plan::{ - filter::FilterExec, - joins::{ - utils::{build_join_schema as physical_build_join_schema, JoinFilter}, - NestedLoopJoinExec, - }, metrics::{MetricBuilder, MetricsSet}, - projection::ProjectionExec, ExecutionPlan, }, prelude::{DataFrame, SessionContext}, }; -use datafusion_common::{Column, DFField, DFSchema, ScalarValue, TableReference}; +use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType}; -use datafusion_physical_expr::{create_physical_expr, expressions, PhysicalExpr}; +use datafusion_expr::{ + Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE, +}; use futures::future::BoxFuture; use parquet::file::properties::WriterProperties; use serde::Serialize; @@ -66,15 +63,19 @@ use serde_json::Value; use super::datafusion_utils::{into_expr, maybe_into_expr, Expression}; use super::transaction::{commit, PROTOCOL}; use crate::delta_datafusion::expr::{fmt_expr_to_sql, parse_predicate_expression}; -use crate::delta_datafusion::{register_store, DeltaScanBuilder}; +use crate::delta_datafusion::logical::MetricObserver; +use crate::delta_datafusion::physical::{find_metric_node, MetricObserverExec}; +use crate::delta_datafusion::{register_store, DeltaScanConfig, DeltaTableProvider}; use crate::kernel::{Action, Remove}; use crate::logstore::LogStoreRef; -use crate::operations::datafusion_utils::MetricObserverExec; use crate::operations::write::write_execution_plan; use crate::protocol::{DeltaOperation, MergePredicate}; use crate::table::state::DeltaTableState; use crate::{DeltaResult, DeltaTable, DeltaTableError}; +const SOURCE_COLUMN: &str = "__delta_rs_source"; +const TARGET_COLUMN: &str = "__delta_rs_target"; + const OPERATION_COLUMN: &str = "__delta_rs_operation"; const DELETE_COLUMN: &str = "__delta_rs_delete"; const TARGET_INSERT_COLUMN: 
&str = "__delta_rs_target_insert"; @@ -83,11 +84,16 @@ const TARGET_DELETE_COLUMN: &str = "__delta_rs_target_delete"; const TARGET_COPY_COLUMN: &str = "__delta_rs_target_copy"; const SOURCE_COUNT_METRIC: &str = "num_source_rows"; +const TARGET_COUNT_METRIC: &str = "num_target_rows"; const TARGET_COPY_METRIC: &str = "num_copied_rows"; const TARGET_INSERTED_METRIC: &str = "num_target_inserted_rows"; const TARGET_UPDATED_METRIC: &str = "num_target_updated_rows"; const TARGET_DELETED_METRIC: &str = "num_target_deleted_rows"; +const SOURCE_COUNT_ID: &str = "merge_source_count"; +const TARGET_COUNT_ID: &str = "merge_target_count"; +const OUTPUT_COUNT_ID: &str = "merge_output_count"; + /// Merge records into a Delta Table. pub struct MergeBuilder { /// The join predicate @@ -557,6 +563,89 @@ pub struct MergeMetrics { pub rewrite_time_ms: u64, } +struct MergeMetricExtensionPlanner {} + +#[async_trait] +impl ExtensionPlanner for MergeMetricExtensionPlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + _logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + _session_state: &SessionState, + ) -> DataFusionResult>> { + if let Some(metric_observer) = node.as_any().downcast_ref::() { + if metric_observer.id.eq(SOURCE_COUNT_ID) { + return Ok(Some(MetricObserverExec::try_new( + SOURCE_COUNT_ID.into(), + physical_inputs, + |batch, metrics| { + MetricBuilder::new(metrics) + .global_counter(SOURCE_COUNT_METRIC) + .add(batch.num_rows()); + }, + )?)); + } + + if metric_observer.id.eq(TARGET_COUNT_ID) { + return Ok(Some(MetricObserverExec::try_new( + TARGET_COUNT_ID.into(), + physical_inputs, + |batch, metrics| { + MetricBuilder::new(metrics) + .global_counter(TARGET_COUNT_METRIC) + .add(batch.num_rows()); + }, + )?)); + } + + if metric_observer.id.eq(OUTPUT_COUNT_ID) { + return Ok(Some(MetricObserverExec::try_new( + OUTPUT_COUNT_ID.into(), + physical_inputs, + |batch, metrics| { + MetricBuilder::new(metrics) + .global_counter(TARGET_INSERTED_METRIC) + .add( + batch + .column_by_name(TARGET_INSERT_COLUMN) + .unwrap() + .null_count(), + ); + MetricBuilder::new(metrics) + .global_counter(TARGET_UPDATED_METRIC) + .add( + batch + .column_by_name(TARGET_UPDATE_COLUMN) + .unwrap() + .null_count(), + ); + MetricBuilder::new(metrics) + .global_counter(TARGET_DELETED_METRIC) + .add( + batch + .column_by_name(TARGET_DELETE_COLUMN) + .unwrap() + .null_count(), + ); + MetricBuilder::new(metrics) + .global_counter(TARGET_COPY_METRIC) + .add( + batch + .column_by_name(TARGET_COPY_COLUMN) + .unwrap() + .null_count(), + ); + }, + )?)); + } + } + + Ok(None) + } +} + #[allow(clippy::too_many_arguments)] async fn execute( predicate: Expression, @@ -589,83 +678,61 @@ async fn execute( // If the user specified any not_source_match operations then those // predicates also need to be considered when pruning - let target = Arc::new( - DeltaScanBuilder::new(snapshot, log_store.clone(), &state) - .with_schema(snapshot.input_schema()?) 
- .build() - .await?, - ); - - let source = source.create_physical_plan().await?; - - let source_count = Arc::new(MetricObserverExec::new(source, |batch, metrics| { - MetricBuilder::new(metrics) - .global_counter(SOURCE_COUNT_METRIC) - .add(batch.num_rows()); - })); - - let mut expressions: Vec<(Arc, String)> = Vec::new(); - let source_schema = source_count.schema(); - - for (i, field) in source_schema.fields().into_iter().enumerate() { - expressions.push(( - Arc::new(expressions::Column::new(field.name(), i)), - field.name().clone(), - )); - } - expressions.push(( - Arc::new(expressions::Literal::new(true.into())), - "__delta_rs_source".to_owned(), - )); - let source = Arc::new(ProjectionExec::try_new(expressions, source_count.clone())?); - - let mut expressions: Vec<(Arc, String)> = Vec::new(); - let target_schema = target.schema(); - for (i, field) in target_schema.fields().into_iter().enumerate() { - expressions.push(( - Arc::new(expressions::Column::new(field.name(), i)), - field.name().to_owned(), - )); - } - expressions.push(( - Arc::new(expressions::Literal::new(true.into())), - "__delta_rs_target".to_owned(), - )); - let target = Arc::new(ProjectionExec::try_new(expressions, target.clone())?); - - // TODO: Currently a NestedLoopJoin is used but we should target to support SortMergeJoin - // This would require rewriting the join predicate to only contain equality between left and right columns and pushing some filters down - // Ideally it would be nice if the optimizer / planner can pick the best join so maybe explore rewriting the entire operation using logical plans. - - // NLJ requires both sides to have one partition for outer joins - let target = Arc::new(CoalescePartitionsExec::new(target)); - let source = Arc::new(CoalescePartitionsExec::new(source)); - - let source_schema = match &source_alias { - Some(alias) => { - DFSchema::try_from_qualified_schema(TableReference::bare(alias), &source.schema())? - } - None => DFSchema::try_from(source.schema().as_ref().to_owned())?, + let source_name = match &source_alias { + Some(alias) => TableReference::bare(alias.to_string()), + None => TableReference::bare(UNNAMED_TABLE), }; - let target_schema = match &target_alias { - Some(alias) => { - DFSchema::try_from_qualified_schema(TableReference::bare(alias), &target.schema())? - } - None => DFSchema::try_from(target.schema().as_ref().to_owned())?, + let target_name = match &target_alias { + Some(alias) => TableReference::bare(alias.to_string()), + None => TableReference::bare(UNNAMED_TABLE), }; - let join_schema_df = build_join_schema(&source_schema, &target_schema, &JoinType::Full)?; + // This is only done to provide the source columns with a correct table reference. Just renaming the columns does not work + let source = + LogicalPlanBuilder::scan(source_name, provider_as_source(source.into_view()), None)? 
+ .build()?; + + let source = LogicalPlan::Extension(Extension { + node: Arc::new(MetricObserver { + id: SOURCE_COUNT_ID.into(), + input: source, + }), + }); + + let source = DataFrame::new(state.clone(), source); + let source = source.with_column(SOURCE_COLUMN, lit(true))?; + + let target_provider = Arc::new(DeltaTableProvider::try_new( + snapshot.clone(), + log_store.clone(), + DeltaScanConfig::default(), + )?); + let target_provider = provider_as_source(target_provider); + + let target = LogicalPlanBuilder::scan(target_name, target_provider, None)?.build()?; - let join_schema = - physical_build_join_schema(&source.schema(), &target.schema(), &JoinType::Full); - let (join_schema, join_order) = (join_schema.0, join_schema.1); + // TODO: This is here to prevent predicate pushdowns. In the future we can replace this node to allow pushdowns depending on which operations are being used. + let target = LogicalPlan::Extension(Extension { + node: Arc::new(MetricObserver { + id: TARGET_COUNT_ID.into(), + input: target, + }), + }); + let target = DataFrame::new(state.clone(), target); + let target = target.with_column(TARGET_COLUMN, lit(true))?; + let source_schema = source.schema(); + let target_schema = target.schema(); + let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?; let predicate = match predicate { Expression::DataFusion(expr) => expr, Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?, }; + let join = source.join(target, JoinType::Full, &[], &[], Some(predicate.clone()))?; + let join_schema_df = join.schema().to_owned(); + let match_operations: Vec = match_operations .into_iter() .map(|op| MergeOperation::try_from(op, &join_schema_df, &state, &target_alias)) @@ -681,40 +748,15 @@ async fn execute( .map(|op| MergeOperation::try_from(op, &join_schema_df, &state, &target_alias)) .collect::, DeltaTableError>>()?; - let predicate_expr = create_physical_expr( - &predicate, - &join_schema_df, - &join_schema, - state.execution_props(), - )?; - - let join_filter = JoinFilter::new(predicate_expr, join_order, join_schema); - let join: Arc = Arc::new(NestedLoopJoinExec::try_new( - source.clone(), - target.clone(), - Some(join_filter), - &datafusion_expr::JoinType::Full, - )?); - - // Project to include __delta_rs_operation which indicates which particular operation to perform on the column. 
- let mut expressions: Vec<(Arc, String)> = Vec::new(); - let schema = join.schema(); - for (i, field) in schema.fields().into_iter().enumerate() { - expressions.push(( - Arc::new(expressions::Column::new(field.name(), i)), - field.name().to_owned(), - )); - } - - let matched = col("__delta_rs_source") + let matched = col(SOURCE_COLUMN) .is_true() - .and(col("__delta_rs_target").is_true()); - let not_matched_target = col("__delta_rs_source") + .and(col(TARGET_COLUMN).is_true()); + let not_matched_target = col(SOURCE_COLUMN) .is_true() - .and(col("__delta_rs_target").is_null()); - let not_matched_source = col("__delta_rs_source") + .and(col(TARGET_COLUMN).is_null()); + let not_matched_source = col(SOURCE_COLUMN) .is_null() - .and(col("__delta_rs_target")) + .and(col(TARGET_COLUMN)) .is_true(); // Plus 3 for the default operations for each match category @@ -811,35 +853,10 @@ async fn execute( let case = CaseBuilder::new(None, when_expr, then_expr, None).end()?; - let case = create_physical_expr( - &case, - &join_schema_df, - &join.schema(), - state.execution_props(), - )?; - expressions.push((case, OPERATION_COLUMN.to_owned())); - let projection = Arc::new(ProjectionExec::try_new(expressions, join.clone())?); - - let mut f = join_schema_df.fields().to_owned(); - f.push(DFField::new_unqualified( - OPERATION_COLUMN, - arrow_schema::DataType::Int64, - false, - )); - let project_schema_df = DFSchema::new_with_metadata(f, HashMap::new())?; - - // Project again and include the original table schema plus a column to mark if row needs to be filtered before write - let mut expressions: Vec<(Arc, String)> = Vec::new(); - let schema = projection.schema(); - for (i, field) in schema.fields().into_iter().enumerate() { - expressions.push(( - Arc::new(expressions::Column::new(field.name(), i)), - field.name().to_owned(), - )); - } + let projection = join.with_column(OPERATION_COLUMN, case)?; - let mut projection_map = HashMap::new(); - let mut f = project_schema_df.fields().clone(); + let mut new_columns = projection; + let mut write_projection = Vec::new(); for delta_field in snapshot.schema().unwrap().fields() { let mut when_expr = Vec::with_capacity(operations_size); @@ -853,7 +870,6 @@ async fn execute( }; let name = delta_field.name(); let column = Column::new(qualifier.clone(), name); - let field = project_schema_df.field_with_name(qualifier.as_ref(), name)?; for (idx, (operations, _)) in ops.iter().enumerate() { let op = operations @@ -873,22 +889,9 @@ async fn execute( ) .end()?; - let case = create_physical_expr( - &case, - &project_schema_df, - &projection.schema(), - state.execution_props(), - )?; - - projection_map.insert(delta_field.name(), expressions.len()); let name = "__delta_rs_c_".to_owned() + delta_field.name(); - - f.push(DFField::new_unqualified( - &name, - field.data_type().clone(), - true, - )); - expressions.push((case, name)); + write_projection.push(col(name.clone()).alias(delta_field.name())); + new_columns = new_columns.with_column(&name, case)?; } let mut insert_when = Vec::with_capacity(ops.len()); @@ -954,168 +957,47 @@ async fn execute( ); } - fn build_case( - when: Vec, - then: Vec, - schema: SchemaRef, - input_dfschema: &DFSchema, - state: &SessionState, - ) -> DataFusionResult> { - let case = CaseBuilder::new( + fn build_case(when: Vec, then: Vec) -> DataFusionResult { + CaseBuilder::new( Some(Box::new(col(OPERATION_COLUMN))), when, then, Some(Box::new(lit(false))), ) - .end()?; - - create_physical_expr(&case, input_dfschema, &schema, state.execution_props()) + .end() } 
- let schema = projection.schema(); - let input_dfschema = project_schema_df; - expressions.push(( - build_case( - delete_when, - delete_then, - schema.clone(), - &input_dfschema, - &state, - )?, - DELETE_COLUMN.to_owned(), - )); - f.push(DFField::new_unqualified( - DELETE_COLUMN, - arrow_schema::DataType::Boolean, - true, - )); - - expressions.push(( - build_case( - insert_when, - insert_then, - schema.clone(), - &input_dfschema, - &state, - )?, - TARGET_INSERT_COLUMN.to_owned(), - )); - f.push(DFField::new_unqualified( - TARGET_INSERT_COLUMN, - arrow_schema::DataType::Boolean, - true, - )); - - expressions.push(( - build_case( - update_when, - update_then, - schema.clone(), - &input_dfschema, - &state, - )?, - TARGET_UPDATE_COLUMN.to_owned(), - )); - f.push(DFField::new_unqualified( - TARGET_UPDATE_COLUMN, - arrow_schema::DataType::Boolean, - true, - )); - - expressions.push(( - build_case( - target_delete_when, - target_delete_then, - schema.clone(), - &input_dfschema, - &state, - )?, - TARGET_DELETE_COLUMN.to_owned(), - )); - f.push(DFField::new_unqualified( + new_columns = new_columns.with_column(DELETE_COLUMN, build_case(delete_when, delete_then)?)?; + new_columns = + new_columns.with_column(TARGET_INSERT_COLUMN, build_case(insert_when, insert_then)?)?; + new_columns = + new_columns.with_column(TARGET_UPDATE_COLUMN, build_case(update_when, update_then)?)?; + new_columns = new_columns.with_column( TARGET_DELETE_COLUMN, - arrow_schema::DataType::Boolean, - true, - )); - - expressions.push(( - build_case( - copy_when, - copy_then, - schema.clone(), - &input_dfschema, - &state, - )?, - TARGET_COPY_COLUMN.to_owned(), - )); - f.push(DFField::new_unqualified( - TARGET_COPY_COLUMN, - arrow_schema::DataType::Boolean, - true, - )); - - let projection = Arc::new(ProjectionExec::try_new(expressions, projection.clone())?); - - let target_count_plan = Arc::new(MetricObserverExec::new(projection, |batch, metrics| { - MetricBuilder::new(metrics) - .global_counter(TARGET_INSERTED_METRIC) - .add( - batch - .column_by_name(TARGET_INSERT_COLUMN) - .unwrap() - .null_count(), - ); - MetricBuilder::new(metrics) - .global_counter(TARGET_UPDATED_METRIC) - .add( - batch - .column_by_name(TARGET_UPDATE_COLUMN) - .unwrap() - .null_count(), - ); - MetricBuilder::new(metrics) - .global_counter(TARGET_DELETED_METRIC) - .add( - batch - .column_by_name(TARGET_DELETE_COLUMN) - .unwrap() - .null_count(), - ); - MetricBuilder::new(metrics) - .global_counter(TARGET_COPY_METRIC) - .add( - batch - .column_by_name(TARGET_COPY_COLUMN) - .unwrap() - .null_count(), - ); - })); - - let write_schema_df = DFSchema::new_with_metadata(f, HashMap::new())?; - - let write_predicate = create_physical_expr( - &(col(DELETE_COLUMN).is_false()), - &write_schema_df, - &target_count_plan.schema(), - state.execution_props(), + build_case(target_delete_when, target_delete_then)?, )?; - let filter: Arc = Arc::new(FilterExec::try_new( - write_predicate, - target_count_plan.clone(), - )?); + new_columns = new_columns.with_column(TARGET_COPY_COLUMN, build_case(copy_when, copy_then)?)?; - let mut expressions: Vec<(Arc, String)> = Vec::new(); - for (key, value) in projection_map { - expressions.push(( - Arc::new(expressions::Column::new( - &("__delta_rs_c_".to_owned() + key), - value, - )), - key.to_owned(), - )); - } - // project filtered records to delta schema - let projection = Arc::new(ProjectionExec::try_new(expressions, filter.clone())?); + let new_columns = new_columns.into_optimized_plan()?; + let operation_count = 
LogicalPlan::Extension(Extension { + node: Arc::new(MetricObserver { + id: OUTPUT_COUNT_ID.into(), + input: new_columns, + }), + }); + + let operation_count = DataFrame::new(state.clone(), operation_count); + let filtered = operation_count.filter(col(DELETE_COLUMN).is_false())?; + + let project = filtered.select(write_projection)?; + let optimized = &project.into_optimized_plan()?; + + let state = state.with_query_planner(Arc::new(MergePlanner {})); + let write = state.create_physical_plan(optimized).await?; + + let err = || DeltaTableError::Generic("Unable to locate expected metric node".into()); + let source_count = find_metric_node(SOURCE_COUNT_ID, &write).ok_or_else(err)?; + let op_count = find_metric_node(OUTPUT_COUNT_ID, &write).ok_or_else(err)?; // write projected records let table_partition_cols = current_metadata.partition_columns.clone(); @@ -1124,9 +1006,9 @@ async fn execute( let add_actions = write_execution_plan( snapshot, state.clone(), - projection.clone(), + write, table_partition_cols.clone(), - log_store.object_store().clone(), + log_store.object_store(), Some(snapshot.table_config().target_file_size() as usize), None, writer_properties, @@ -1163,7 +1045,7 @@ async fn execute( let mut version = snapshot.version(); let source_count_metrics = source_count.metrics().unwrap(); - let target_count_metrics = target_count_plan.metrics().unwrap(); + let target_count_metrics = op_count.metrics().unwrap(); fn get_metric(metrics: &MetricsSet, name: &str) -> usize { metrics.sum_by_name(name).map(|m| m.as_usize()).unwrap_or(0) } @@ -1200,6 +1082,25 @@ async fn execute( Ok(((actions, version), metrics)) } +// TODO: Abstract MergePlanner into DeltaPlanner to support other delta operations in the future. +struct MergePlanner {} + +#[async_trait] +impl QueryPlanner for MergePlanner { + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + session_state: &SessionState, + ) -> DataFusionResult> { + let planner = Arc::new(Box::new(DefaultPhysicalPlanner::with_extension_planners( + vec![Arc::new(MergeMetricExtensionPlanner {})], + ))); + planner + .create_physical_plan(logical_plan, session_state) + .await + } +} + impl std::future::IntoFuture for MergeBuilder { type Output = DeltaResult<(DeltaTable, MergeMetrics)>; type IntoFuture = BoxFuture<'static, Self::Output>; @@ -1211,7 +1112,9 @@ impl std::future::IntoFuture for MergeBuilder { PROTOCOL.can_write_to(&this.snapshot)?; let state = this.state.unwrap_or_else(|| { - let session = SessionContext::new(); + //TODO: Datafusion's Hashjoin has some memory issues. Running with all cores results in a OoM. Can be removed when upstream improvemetns are made. 
+ let config = SessionConfig::new().with_target_partitions(1); + let session = SessionContext::new_with_config(config); // If a user provides their own their DF state then they must register the store themselves register_store(this.log_store.clone(), session.runtime_env()); @@ -1349,8 +1252,8 @@ mod tests { async fn assert_merge(table: DeltaTable, metrics: MergeMetrics) { assert_eq!(table.version(), 2); - assert_eq!(table.get_file_uris().count(), 1); - assert_eq!(metrics.num_target_files_added, 1); + assert!(table.get_file_uris().count() >= 1); + assert!(metrics.num_target_files_added >= 1); assert_eq!(metrics.num_target_files_removed, 1); assert_eq!(metrics.num_target_rows_copied, 1); assert_eq!(metrics.num_target_rows_updated, 3); @@ -1442,7 +1345,7 @@ mod tests { .unwrap() .when_not_matched_by_source_update(|update| { update - .predicate("target.value = arrow_cast(1, 'Int32')") + .predicate("target.value = 1") .update("value", "target.value + cast(1 as int)") }) .unwrap() @@ -1470,9 +1373,7 @@ mod tests { ); assert_eq!( parameters["notMatchedBySourcePredicates"], - json!( - r#"[{"actionType":"update","predicate":"target.value = arrow_cast(1, 'Int32')"}]"# - ) + json!(r#"[{"actionType":"update","predicate":"target.value = 1"}]"#) ); assert_merge(table, metrics).await; @@ -1500,9 +1401,7 @@ mod tests { }) .unwrap() .when_not_matched_by_source_update(|update| { - update - .predicate("value = arrow_cast(1, 'Int32')") - .update("value", "value + cast(1 as int)") + update.predicate("value = 1").update("value", "value + 1") }) .unwrap() .when_not_matched_insert(|insert| { @@ -1543,8 +1442,8 @@ mod tests { .unwrap() .when_not_matched_by_source_update(|update| { update - .predicate("value = arrow_cast(1, 'Int32')") - .update("value", "target.value + cast(1 as int)") + .predicate("value = 1") + .update("value", "target.value + 1") }) .unwrap() .when_not_matched_insert(|insert| { @@ -1657,8 +1556,8 @@ mod tests { .unwrap(); assert_eq!(table.version(), 2); - assert_eq!(table.get_file_uris().count(), 3); - assert_eq!(metrics.num_target_files_added, 3); + assert!(table.get_file_uris().count() >= 3); + assert!(metrics.num_target_files_added >= 3); assert_eq!(metrics.num_target_files_removed, 2); assert_eq!(metrics.num_target_rows_copied, 1); assert_eq!(metrics.num_target_rows_updated, 3); @@ -1720,8 +1619,8 @@ mod tests { .unwrap(); assert_eq!(table.version(), 2); - assert_eq!(table.get_file_uris().count(), 2); - assert_eq!(metrics.num_target_files_added, 2); + assert!(table.get_file_uris().count() >= 2); + assert!(metrics.num_target_files_added >= 2); assert_eq!(metrics.num_target_files_removed, 2); assert_eq!(metrics.num_target_rows_copied, 2); assert_eq!(metrics.num_target_rows_updated, 0); @@ -1784,8 +1683,8 @@ mod tests { .unwrap(); assert_eq!(table.version(), 2); - assert_eq!(table.get_file_uris().count(), 2); - assert_eq!(metrics.num_target_files_added, 2); + assert!(table.get_file_uris().count() >= 2); + assert!(metrics.num_target_files_added >= 2); assert_eq!(metrics.num_target_files_removed, 2); assert_eq!(metrics.num_target_rows_copied, 3); assert_eq!(metrics.num_target_rows_updated, 0); @@ -1918,8 +1817,7 @@ mod tests { .unwrap(); assert_eq!(table.version(), 2); - assert_eq!(table.get_file_uris().count(), 2); - assert_eq!(metrics.num_target_files_added, 2); + assert!(metrics.num_target_files_added >= 2); assert_eq!(metrics.num_target_files_removed, 2); assert_eq!(metrics.num_target_rows_copied, 3); assert_eq!(metrics.num_target_rows_updated, 0); @@ -1949,4 +1847,77 @@ mod tests { let 
actual = get_data(&table).await; assert_batches_sorted_eq!(&expected, &actual); } + + #[tokio::test] + async fn test_merge_empty_table() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["modified"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_file_uris().count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let (table, metrics) = DeltaOps(table) + .merge( + source, + col("target.id") + .eq(col("source.id")) + .and(col("target.modified").eq(lit("2021-02-02"))), + ) + .with_source_alias("source") + .with_target_alias("target") + .when_matched_update(|update| { + update + .update("value", col("source.value")) + .update("modified", col("source.modified")) + }) + .unwrap() + .when_not_matched_insert(|insert| { + insert + .set("id", col("source.id")) + .set("value", col("source.value")) + .set("modified", col("source.modified")) + }) + .unwrap() + .await + .unwrap(); + + assert_eq!(table.version(), 1); + assert!(table.get_file_uris().count() >= 2); + assert!(metrics.num_target_files_added >= 2); + assert_eq!(metrics.num_target_files_removed, 0); + assert_eq!(metrics.num_target_rows_copied, 0); + assert_eq!(metrics.num_target_rows_updated, 0); + assert_eq!(metrics.num_target_rows_inserted, 3); + assert_eq!(metrics.num_target_rows_deleted, 0); + assert_eq!(metrics.num_output_rows, 3); + assert_eq!(metrics.num_source_rows, 3); + + let expected = vec![ + "+----+-------+------------+", + "| id | value | modified |", + "+----+-------+------------+", + "| B | 10 | 2021-02-02 |", + "| C | 20 | 2023-07-04 |", + "| X | 30 | 2023-07-04 |", + "+----+-------+------------+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } } diff --git a/crates/deltalake-core/src/operations/mod.rs b/crates/deltalake-core/src/operations/mod.rs index a0dbfd0239..a81e16578f 100644 --- a/crates/deltalake-core/src/operations/mod.rs +++ b/crates/deltalake-core/src/operations/mod.rs @@ -192,20 +192,9 @@ impl AsRef for DeltaOps { #[cfg(feature = "datafusion")] mod datafusion_utils { - use std::sync::Arc; - - use arrow_schema::SchemaRef; - use datafusion::arrow::record_batch::RecordBatch; - use datafusion::error::Result as DataFusionResult; use datafusion::execution::context::SessionState; - use datafusion::physical_plan::DisplayAs; - use datafusion::physical_plan::{ - metrics::{ExecutionPlanMetricsSet, MetricsSet}, - ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, - }; - use datafusion_common::{DFSchema, Statistics}; + use datafusion_common::DFSchema; use datafusion_expr::Expr; - use futures::{Stream, StreamExt}; use crate::{delta_datafusion::expr::parse_predicate_expression, DeltaResult}; @@ -255,127 +244,4 @@ mod datafusion_utils { None => None, }) } - - pub(crate) type MetricObserverFunction = fn(&RecordBatch, &ExecutionPlanMetricsSet) -> (); - - pub(crate) struct MetricObserverExec { - parent: Arc, - metrics: ExecutionPlanMetricsSet, - update: MetricObserverFunction, - } - - impl MetricObserverExec { - pub fn new(parent: Arc, f: MetricObserverFunction) -> Self { - MetricObserverExec { - parent, - metrics: ExecutionPlanMetricsSet::new(), - update: f, - } - } 
- } - - impl std::fmt::Debug for MetricObserverExec { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MergeStatsExec") - .field("parent", &self.parent) - .field("metrics", &self.metrics) - .finish() - } - } - - impl DisplayAs for MetricObserverExec { - fn fmt_as( - &self, - _: datafusion::physical_plan::DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - write!(f, "MetricObserverExec") - } - } - - impl ExecutionPlan for MetricObserverExec { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn schema(&self) -> arrow_schema::SchemaRef { - self.parent.schema() - } - - fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning { - self.parent.output_partitioning() - } - - fn output_ordering(&self) -> Option<&[datafusion_physical_expr::PhysicalSortExpr]> { - self.parent.output_ordering() - } - - fn children(&self) -> Vec> { - vec![self.parent.clone()] - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> datafusion_common::Result - { - let res = self.parent.execute(partition, context)?; - Ok(Box::pin(MetricObserverStream { - schema: self.schema(), - input: res, - metrics: self.metrics.clone(), - update: self.update, - })) - } - - fn statistics(&self) -> DataFusionResult { - self.parent.statistics() - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> datafusion_common::Result> { - ExecutionPlan::with_new_children(self.parent.clone(), children) - } - - fn metrics(&self) -> Option { - Some(self.metrics.clone_inner()) - } - } - - struct MetricObserverStream { - schema: SchemaRef, - input: SendableRecordBatchStream, - metrics: ExecutionPlanMetricsSet, - update: MetricObserverFunction, - } - - impl Stream for MetricObserverStream { - type Item = DataFusionResult; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.input.poll_next_unpin(cx).map(|x| match x { - Some(Ok(batch)) => { - (self.update)(&batch, &self.metrics); - Some(Ok(batch)) - } - other => other, - }) - } - - fn size_hint(&self) -> (usize, Option) { - self.input.size_hint() - } - } - - impl RecordBatchStream for MetricObserverStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - } } diff --git a/crates/deltalake-core/src/operations/update.rs b/crates/deltalake-core/src/operations/update.rs index 7583ed6b39..907dec5998 100644 --- a/crates/deltalake-core/src/operations/update.rs +++ b/crates/deltalake-core/src/operations/update.rs @@ -43,10 +43,10 @@ use parquet::file::properties::WriterProperties; use serde::Serialize; use serde_json::Value; -use super::datafusion_utils::{Expression, MetricObserverExec}; +use super::datafusion_utils::Expression; use super::transaction::{commit, PROTOCOL}; use super::write::write_execution_plan; -use crate::delta_datafusion::expr::fmt_expr_to_sql; +use crate::delta_datafusion::{expr::fmt_expr_to_sql, physical::MetricObserverExec}; use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::kernel::{Action, Remove}; use crate::logstore::LogStoreRef; @@ -275,6 +275,7 @@ async fn execute( Arc::new(ProjectionExec::try_new(expressions, scan)?); let count_plan = Arc::new(MetricObserverExec::new( + "update_count".into(), projection_predicate.clone(), |batch, metrics| { let array = batch.column_by_name("__delta_rs_update_predicate").unwrap();
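As a usage note for reviewers: with the logical-plan rewrite, merge predicates written as SQL no longer need manual `arrow_cast` calls, since type coercion happens when the plan is optimized (this is what the test changes above from `arrow_cast(1, 'Int32')` to plain `1` exercise). Below is a minimal sketch of calling the merge builder this way; the crate path and re-exports (`deltalake_core::{DeltaOps, DeltaResult, DeltaTable}`) are assumptions for illustration, not something this patch prescribes.

```rust
use datafusion::prelude::{col, DataFrame};
use deltalake_core::{DeltaOps, DeltaResult, DeltaTable};

// Sketch under stated assumptions: `table` is an existing Delta table and
// `source` is a DataFusion DataFrame whose columns line up with the target.
async fn merge_example(table: DeltaTable, source: DataFrame) -> DeltaResult<DeltaTable> {
    let (table, metrics) = DeltaOps(table)
        .merge(source, col("target.id").eq(col("source.id")))
        .with_source_alias("source")
        .with_target_alias("target")
        // No arrow_cast needed: the optimizer coerces the literal `1` to the
        // column's Int32 type when the logical plan is optimized.
        .when_not_matched_by_source_update(|update| {
            update
                .predicate("target.value = 1")
                .update("value", "target.value + 1")
        })?
        .when_matched_update(|update| {
            update
                .update("value", "source.value")
                .update("modified", "source.modified")
        })?
        .await?;

    // These counters are now fed by the MetricObserver nodes injected into the plan.
    println!("source rows scanned: {}", metrics.num_source_rows);
    println!("target rows updated: {}", metrics.num_target_rows_updated);
    Ok(table)
}
```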
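On the same note, the single-partition workaround only applies when the builder constructs its own session (the `unwrap_or_else` branch above); callers that pass their own `SessionState` keep whatever partitioning they configured and must register the object store themselves, per the comment in `IntoFuture`. A minimal sketch of mirroring the builder's defaults follows; the function name is illustrative and not part of this patch.

```rust
use datafusion::execution::context::{SessionConfig, SessionState};
use datafusion::prelude::SessionContext;

// Sketch only: mirrors the session defaults the merge builder applies when it
// builds its own state, for callers that want to supply their own instead.
fn merge_session_state() -> SessionState {
    // A single target partition works around the upstream hash-join memory
    // issues noted in the TODO above; lift this once those are resolved.
    let config = SessionConfig::new().with_target_partitions(1);
    let session = SessionContext::new_with_config(config);
    // Callers providing their own state must also register the Delta object
    // store (register_store) on session.runtime_env() before executing.
    session.state()
}
```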