From 2f550032140d42d1ee6d8ed86f7790766fa7302e Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 3 Apr 2024 22:20:01 +0200 Subject: [PATCH] Simplify Expr::map_children (#9876) * add map_until_stop_and_collect macro * fix clippy * simplify * Update datafusion/common/src/tree_node.rs Co-authored-by: Andrew Lamb * add documentation * fix macro --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/tree_node.rs | 82 ++++++++-- datafusion/expr/src/tree_node/expr.rs | 226 ++++++++++++-------------- 2 files changed, 171 insertions(+), 137 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 2d653a27c47b..554722f37ba2 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -532,8 +532,20 @@ impl Transformed { } } -/// Transformation helper to process tree nodes that are siblings. +/// Transformation helper to process a sequence of iterable tree nodes that are siblings. pub trait TransformedIterator: Iterator { + /// Applies `f` to each item in this iterator + /// + /// Visits all items in the iterator unless + /// `f` returns an error or `f` returns TreeNodeRecursion::Stop. + /// + /// # Returns + /// Error if `f` returns an error + /// + /// Ok(Transformed) such that: + /// 1. `transformed` is true if any return from `f` had transformed true + /// 2. `data` from the last invocation of `f` + /// 3. 
`tnr` from the last invocation of `f` or `Continue` if the iterator is empty fn map_until_stop_and_collect< F: FnMut(Self::Item) -> Result>, >( @@ -551,22 +563,64 @@ impl TransformedIterator for I { ) -> Result>> { let mut tnr = TreeNodeRecursion::Continue; let mut transformed = false; - let data = self - .map(|item| match tnr { - TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { - f(item).map(|result| { - tnr = result.tnr; - transformed |= result.transformed; - result.data - }) - } - TreeNodeRecursion::Stop => Ok(item), - }) - .collect::>>()?; - Ok(Transformed::new(data, transformed, tnr)) + self.map(|item| match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + f(item).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + result.data + }) + } + TreeNodeRecursion::Stop => Ok(item), + }) + .collect::>>() + .map(|data| Transformed::new(data, transformed, tnr)) } } +/// Transformation helper to process a heterogeneous sequence of tree node containing +/// expressions. +/// This macro is very similar to [TransformedIterator::map_until_stop_and_collect] to +/// process nodes that are siblings, but it accepts an initial transformation (`F0`) and +/// a sequence of pairs. Each pair is made of an expression (`EXPR`) and its +/// transformation (`F`). +/// +/// The macro builds up a tuple that contains `Transformed.data` result of `F0` as the +/// first element and further elements from the sequence of pairs. An element from a pair +/// is either the value of `EXPR` or the `Transformed.data` result of `F`, depending on +/// the `Transformed.tnr` result of previous `F`s (`F0` initially). +/// +/// # Returns +/// Error if any of the transformations returns an error +/// +/// Ok(Transformed<(data0, ..., dataN)>) such that: +/// 1. `transformed` is true if any of the transformations had transformed true +/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and +/// `data1` ... 
`dataN` are from either `EXPR` or the `Transformed.data` of `F` +/// 3. `tnr` from `F0` or the last invocation of `F` +#[macro_export] +macro_rules! map_until_stop_and_collect { + ($F0:expr, $($EXPR:expr, $F:expr),*) => {{ + $F0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| { + let all_datas = ( + data0, + $( + if tnr == TreeNodeRecursion::Continue || tnr == TreeNodeRecursion::Jump { + $F.map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + result.data + })? + } else { + $EXPR + }, + )* + ); + Ok(Transformed::new(all_datas, transformed, tnr)) + }) + }} +} + /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. pub trait TransformedResult { fn data(self) -> Result; diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 0909d8f662f6..df1585e5a598 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -27,7 +27,9 @@ use crate::{Expr, GetFieldAccess}; use datafusion_common::tree_node::{ Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, }; -use datafusion_common::{handle_visit_recursion, internal_err, Result}; +use datafusion_common::{ + handle_visit_recursion, internal_err, map_until_stop_and_collect, Result, +}; impl TreeNode for Expr { fn apply_children Result>( @@ -167,15 +169,14 @@ impl TreeNode for Expr { Expr::InSubquery(InSubquery::new(be, subquery, negated)) }), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - transform_box(left, &mut f)? - .update_data(|new_left| (new_left, right)) - .try_transform_node(|(new_left, right)| { - Ok(transform_box(right, &mut f)? - .update_data(|new_right| (new_left, new_right))) - })? - .update_data(|(new_left, new_right)| { - Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) - }) + map_until_stop_and_collect!( + transform_box(left, &mut f), + right, + transform_box(right, &mut f) + )? 
+ .update_data(|(new_left, new_right)| { + Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) + }) } Expr::Like(Like { negated, @@ -183,42 +184,40 @@ impl TreeNode for Expr { pattern, escape_char, case_insensitive, - }) => transform_box(expr, &mut f)? - .update_data(|new_expr| (new_expr, pattern)) - .try_transform_node(|(new_expr, pattern)| { - Ok(transform_box(pattern, &mut f)? - .update_data(|new_pattern| (new_expr, new_pattern))) - })? - .update_data(|(new_expr, new_pattern)| { - Expr::Like(Like::new( - negated, - new_expr, - new_pattern, - escape_char, - case_insensitive, - )) - }), + }) => map_until_stop_and_collect!( + transform_box(expr, &mut f), + pattern, + transform_box(pattern, &mut f) + )? + .update_data(|(new_expr, new_pattern)| { + Expr::Like(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }), Expr::SimilarTo(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => transform_box(expr, &mut f)? - .update_data(|new_expr| (new_expr, pattern)) - .try_transform_node(|(new_expr, pattern)| { - Ok(transform_box(pattern, &mut f)? - .update_data(|new_pattern| (new_expr, new_pattern))) - })? - .update_data(|(new_expr, new_pattern)| { - Expr::SimilarTo(Like::new( - negated, - new_expr, - new_pattern, - escape_char, - case_insensitive, - )) - }), + }) => map_until_stop_and_collect!( + transform_box(expr, &mut f), + pattern, + transform_box(pattern, &mut f) + )? + .update_data(|(new_expr, new_pattern)| { + Expr::SimilarTo(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }), Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not), Expr::IsNotNull(expr) => { transform_box(expr, &mut f)?.update_data(Expr::IsNotNull) @@ -248,48 +247,38 @@ impl TreeNode for Expr { negated, low, high, - }) => transform_box(expr, &mut f)? - .update_data(|new_expr| (new_expr, low, high)) - .try_transform_node(|(new_expr, low, high)| { - Ok(transform_box(low, &mut f)? 
- .update_data(|new_low| (new_expr, new_low, high))) - })? - .try_transform_node(|(new_expr, new_low, high)| { - Ok(transform_box(high, &mut f)? - .update_data(|new_high| (new_expr, new_low, new_high))) - })? - .update_data(|(new_expr, new_low, new_high)| { - Expr::Between(Between::new(new_expr, negated, new_low, new_high)) - }), + }) => map_until_stop_and_collect!( + transform_box(expr, &mut f), + low, + transform_box(low, &mut f), + high, + transform_box(high, &mut f) + )? + .update_data(|(new_expr, new_low, new_high)| { + Expr::Between(Between::new(new_expr, negated, new_low, new_high)) + }), Expr::Case(Case { expr, when_then_expr, else_expr, - }) => transform_option_box(expr, &mut f)? - .update_data(|new_expr| (new_expr, when_then_expr, else_expr)) - .try_transform_node(|(new_expr, when_then_expr, else_expr)| { - Ok(when_then_expr - .into_iter() - .map_until_stop_and_collect(|(when, then)| { - transform_box(when, &mut f)? - .update_data(|new_when| (new_when, then)) - .try_transform_node(|(new_when, then)| { - Ok(transform_box(then, &mut f)? - .update_data(|new_then| (new_when, new_then))) - }) - })? - .update_data(|new_when_then_expr| { - (new_expr, new_when_then_expr, else_expr) - })) - })? - .try_transform_node(|(new_expr, new_when_then_expr, else_expr)| { - Ok(transform_option_box(else_expr, &mut f)?.update_data( - |new_else_expr| (new_expr, new_when_then_expr, new_else_expr), - )) - })? - .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { - Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) - }), + }) => map_until_stop_and_collect!( + transform_option_box(expr, &mut f), + when_then_expr, + when_then_expr + .into_iter() + .map_until_stop_and_collect(|(when, then)| { + map_until_stop_and_collect!( + transform_box(when, &mut f), + then, + transform_box(then, &mut f) + ) + }), + else_expr, + transform_option_box(else_expr, &mut f) + )? 
+ .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { + Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) + }), Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)? .update_data(|be| Expr::Cast(Cast::new(be, data_type))), Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? @@ -320,30 +309,23 @@ impl TreeNode for Expr { order_by, window_frame, null_treatment, - }) => transform_vec(args, &mut f)? - .update_data(|new_args| (new_args, partition_by, order_by)) - .try_transform_node(|(new_args, partition_by, order_by)| { - Ok(transform_vec(partition_by, &mut f)?.update_data( - |new_partition_by| (new_args, new_partition_by, order_by), - )) - })? - .try_transform_node(|(new_args, new_partition_by, order_by)| { - Ok( - transform_vec(order_by, &mut f)?.update_data(|new_order_by| { - (new_args, new_partition_by, new_order_by) - }), - ) - })? - .update_data(|(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new( - fun, - new_args, - new_partition_by, - new_order_by, - window_frame, - null_treatment, - )) - }), + }) => map_until_stop_and_collect!( + transform_vec(args, &mut f), + partition_by, + transform_vec(partition_by, &mut f), + order_by, + transform_vec(order_by, &mut f) + )? + .update_data(|(new_args, new_partition_by, new_order_by)| { + Expr::WindowFunction(WindowFunction::new( + fun, + new_args, + new_partition_by, + new_order_by, + window_frame, + null_treatment, + )) + }), Expr::AggregateFunction(AggregateFunction { args, func_def, @@ -351,17 +333,15 @@ impl TreeNode for Expr { filter, order_by, null_treatment, - }) => transform_vec(args, &mut f)? - .update_data(|new_args| (new_args, filter, order_by)) - .try_transform_node(|(new_args, filter, order_by)| { - Ok(transform_option_box(filter, &mut f)? - .update_data(|new_filter| (new_args, new_filter, order_by))) - })? 
- .try_transform_node(|(new_args, new_filter, order_by)| { - Ok(transform_option_vec(order_by, &mut f)? - .update_data(|new_order_by| (new_args, new_filter, new_order_by))) - })? - .map_data(|(new_args, new_filter, new_order_by)| match func_def { + }) => map_until_stop_and_collect!( + transform_vec(args, &mut f), + filter, + transform_option_box(filter, &mut f), + order_by, + transform_option_vec(order_by, &mut f) + )? + .map_data( + |(new_args, new_filter, new_order_by)| match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { Ok(Expr::AggregateFunction(AggregateFunction::new( fun, @@ -385,7 +365,8 @@ impl TreeNode for Expr { AggregateFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") } - })?, + }, + )?, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), @@ -402,15 +383,14 @@ impl TreeNode for Expr { expr, list, negated, - }) => transform_box(expr, &mut f)? - .update_data(|new_expr| (new_expr, list)) - .try_transform_node(|(new_expr, list)| { - Ok(transform_vec(list, &mut f)? - .update_data(|new_list| (new_expr, new_list))) - })? - .update_data(|(new_expr, new_list)| { - Expr::InList(InList::new(new_expr, new_list, negated)) - }), + }) => map_until_stop_and_collect!( + transform_box(expr, &mut f), + list, + transform_vec(list, &mut f) + )? + .update_data(|(new_expr, new_list)| { + Expr::InList(InList::new(new_expr, new_list, negated)) + }), Expr::GetIndexedField(GetIndexedField { expr, field }) => { transform_box(expr, &mut f)?.update_data(|be| { Expr::GetIndexedField(GetIndexedField::new(be, field))