Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Expr::map_children #9876

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ impl<T> Transformed<T> {
}
}

/// Transformation helper to process tree nodes that are siblings.
/// Transformation helper to process sequence of iterable tree nodes that are siblings.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found the reason / logic for continue / jump quite subtle when trying to make TreeNodeMutator. I have some suggested comments below to explain the rationale that might help

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @alamb, merged your suggestion.

pub trait TransformedIterator: Iterator {
fn map_until_stop_and_collect<
peter-toth marked this conversation as resolved.
Show resolved Hide resolved
F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
Expand All @@ -551,22 +551,48 @@ impl<I: Iterator> TransformedIterator for I {
) -> Result<Transformed<Vec<Self::Item>>> {
let mut tnr = TreeNodeRecursion::Continue;
let mut transformed = false;
let data = self
.map(|item| match tnr {
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
f(item).map(|result| {
tnr = result.tnr;
transformed |= result.transformed;
result.data
})
}
TreeNodeRecursion::Stop => Ok(item),
})
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::new(data, transformed, tnr))
self.map(|item| match tnr {
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
f(item).map(|result| {
tnr = result.tnr;
transformed |= result.transformed;
result.data
})
}
TreeNodeRecursion::Stop => Ok(item),
})
.collect::<Result<Vec<_>>>()
.map(|data| Transformed::new(data, transformed, tnr))
}
}

/// Transformation helper to process sequence of tree node containing expressions.
/// This macro is very similar to [TransformedIterator::map_until_stop_and_collect] to
/// process nodes that are siblings, but it accepts an initial transformation and a
/// sequence of pairs of an expression and its transformation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite clever. I have some suggestions on naming that would have helped me:

Use F to mirror the nomenclature of TransformedIterator. I thought $TRANSFORMED_EXPR was an actually expr (but it is actually a closure that is invoked I think 🤔 )

$TRANSFORMED_EXPR_0 --> F0      
$TRANSFORMED_EXPR --> F.

I think it would help to document each parameter and the return value (specifically it looks like it returns Transformed<(data, ..., data)>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, let me fix these and come back yo you.

Copy link
Contributor Author

@peter-toth peter-toth Apr 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed in a36f1aa and added rustdoc to the macro. Let me know if it needs more details.

#[macro_export]
macro_rules! map_until_stop_and_collect {
($TRANSFORMED_EXPR_0:expr, $($EXPR:expr, $TRANSFORMED_EXPR:expr),*) => {{
$TRANSFORMED_EXPR_0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| {
let data = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could name this something other than data so it was clearer what type it is (maybe all_datas?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in a36f1aa

data0,
$(
if tnr == TreeNodeRecursion::Continue || tnr == TreeNodeRecursion::Jump {
$TRANSFORMED_EXPR.map(|result| {
tnr = result.tnr;
transformed |= result.transformed;
result.data
})?
} else {
$EXPR
},
)*
);
Ok(Transformed::new(data, transformed, tnr))
})
}}
}

/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
pub trait TransformedResult<T> {
fn data(self) -> Result<T>;
Expand Down
226 changes: 103 additions & 123 deletions datafusion/expr/src/tree_node/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use crate::{Expr, GetFieldAccess};
use datafusion_common::tree_node::{
Transformed, TransformedIterator, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{handle_visit_recursion, internal_err, Result};
use datafusion_common::{
handle_visit_recursion, internal_err, map_until_stop_and_collect, Result,
};

impl TreeNode for Expr {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
Expand Down Expand Up @@ -167,58 +169,55 @@ impl TreeNode for Expr {
Expr::InSubquery(InSubquery::new(be, subquery, negated))
}),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
transform_box(left, &mut f)?
.update_data(|new_left| (new_left, right))
.try_transform_node(|(new_left, right)| {
Ok(transform_box(right, &mut f)?
.update_data(|new_right| (new_left, new_right)))
})?
.update_data(|(new_left, new_right)| {
Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
})
map_until_stop_and_collect!(
transform_box(left, &mut f),
right,
transform_box(right, &mut f)
)?
.update_data(|(new_left, new_right)| {
Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
})
}
Expr::Like(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}) => transform_box(expr, &mut f)?
.update_data(|new_expr| (new_expr, pattern))
.try_transform_node(|(new_expr, pattern)| {
Ok(transform_box(pattern, &mut f)?
.update_data(|new_pattern| (new_expr, new_pattern)))
})?
.update_data(|(new_expr, new_pattern)| {
Expr::Like(Like::new(
negated,
new_expr,
new_pattern,
escape_char,
case_insensitive,
))
}),
}) => map_until_stop_and_collect!(
transform_box(expr, &mut f),
pattern,
transform_box(pattern, &mut f)
)?
.update_data(|(new_expr, new_pattern)| {
Expr::Like(Like::new(
negated,
new_expr,
new_pattern,
escape_char,
case_insensitive,
))
}),
Expr::SimilarTo(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}) => transform_box(expr, &mut f)?
.update_data(|new_expr| (new_expr, pattern))
.try_transform_node(|(new_expr, pattern)| {
Ok(transform_box(pattern, &mut f)?
.update_data(|new_pattern| (new_expr, new_pattern)))
})?
.update_data(|(new_expr, new_pattern)| {
Expr::SimilarTo(Like::new(
negated,
new_expr,
new_pattern,
escape_char,
case_insensitive,
))
}),
}) => map_until_stop_and_collect!(
transform_box(expr, &mut f),
pattern,
transform_box(pattern, &mut f)
)?
.update_data(|(new_expr, new_pattern)| {
Expr::SimilarTo(Like::new(
negated,
new_expr,
new_pattern,
escape_char,
case_insensitive,
))
}),
Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not),
Expr::IsNotNull(expr) => {
transform_box(expr, &mut f)?.update_data(Expr::IsNotNull)
Expand Down Expand Up @@ -248,48 +247,38 @@ impl TreeNode for Expr {
negated,
low,
high,
}) => transform_box(expr, &mut f)?
.update_data(|new_expr| (new_expr, low, high))
.try_transform_node(|(new_expr, low, high)| {
Ok(transform_box(low, &mut f)?
.update_data(|new_low| (new_expr, new_low, high)))
})?
.try_transform_node(|(new_expr, new_low, high)| {
Ok(transform_box(high, &mut f)?
.update_data(|new_high| (new_expr, new_low, new_high)))
})?
.update_data(|(new_expr, new_low, new_high)| {
Expr::Between(Between::new(new_expr, negated, new_low, new_high))
}),
}) => map_until_stop_and_collect!(
transform_box(expr, &mut f),
low,
transform_box(low, &mut f),
high,
transform_box(high, &mut f)
)?
.update_data(|(new_expr, new_low, new_high)| {
Expr::Between(Between::new(new_expr, negated, new_low, new_high))
}),
Expr::Case(Case {
expr,
when_then_expr,
else_expr,
}) => transform_option_box(expr, &mut f)?
.update_data(|new_expr| (new_expr, when_then_expr, else_expr))
.try_transform_node(|(new_expr, when_then_expr, else_expr)| {
Ok(when_then_expr
.into_iter()
.map_until_stop_and_collect(|(when, then)| {
transform_box(when, &mut f)?
.update_data(|new_when| (new_when, then))
.try_transform_node(|(new_when, then)| {
Ok(transform_box(then, &mut f)?
.update_data(|new_then| (new_when, new_then)))
})
})?
.update_data(|new_when_then_expr| {
(new_expr, new_when_then_expr, else_expr)
}))
})?
.try_transform_node(|(new_expr, new_when_then_expr, else_expr)| {
Ok(transform_option_box(else_expr, &mut f)?.update_data(
|new_else_expr| (new_expr, new_when_then_expr, new_else_expr),
))
})?
.update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
}),
}) => map_until_stop_and_collect!(
transform_option_box(expr, &mut f),
when_then_expr,
when_then_expr
.into_iter()
.map_until_stop_and_collect(|(when, then)| {
map_until_stop_and_collect!(
transform_box(when, &mut f),
then,
transform_box(then, &mut f)
)
}),
else_expr,
transform_option_box(else_expr, &mut f)
)?
.update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
}),
Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)?
.update_data(|be| Expr::Cast(Cast::new(be, data_type))),
Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)?
Expand Down Expand Up @@ -320,48 +309,39 @@ impl TreeNode for Expr {
order_by,
window_frame,
null_treatment,
}) => transform_vec(args, &mut f)?
.update_data(|new_args| (new_args, partition_by, order_by))
.try_transform_node(|(new_args, partition_by, order_by)| {
Ok(transform_vec(partition_by, &mut f)?.update_data(
|new_partition_by| (new_args, new_partition_by, order_by),
))
})?
.try_transform_node(|(new_args, new_partition_by, order_by)| {
Ok(
transform_vec(order_by, &mut f)?.update_data(|new_order_by| {
(new_args, new_partition_by, new_order_by)
}),
)
})?
.update_data(|(new_args, new_partition_by, new_order_by)| {
Expr::WindowFunction(WindowFunction::new(
fun,
new_args,
new_partition_by,
new_order_by,
window_frame,
null_treatment,
))
}),
}) => map_until_stop_and_collect!(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is very clever

transform_vec(args, &mut f),
partition_by,
transform_vec(partition_by, &mut f),
order_by,
transform_vec(order_by, &mut f)
)?
.update_data(|(new_args, new_partition_by, new_order_by)| {
Expr::WindowFunction(WindowFunction::new(
fun,
new_args,
new_partition_by,
new_order_by,
window_frame,
null_treatment,
))
}),
Expr::AggregateFunction(AggregateFunction {
args,
func_def,
distinct,
filter,
order_by,
null_treatment,
}) => transform_vec(args, &mut f)?
.update_data(|new_args| (new_args, filter, order_by))
.try_transform_node(|(new_args, filter, order_by)| {
Ok(transform_option_box(filter, &mut f)?
.update_data(|new_filter| (new_args, new_filter, order_by)))
})?
.try_transform_node(|(new_args, new_filter, order_by)| {
Ok(transform_option_vec(order_by, &mut f)?
.update_data(|new_order_by| (new_args, new_filter, new_order_by)))
})?
.map_data(|(new_args, new_filter, new_order_by)| match func_def {
}) => map_until_stop_and_collect!(
transform_vec(args, &mut f),
filter,
transform_option_box(filter, &mut f),
order_by,
transform_option_vec(order_by, &mut f)
)?
.map_data(
|(new_args, new_filter, new_order_by)| match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun,
Expand All @@ -384,7 +364,8 @@ impl TreeNode for Expr {
AggregateFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
})?,
},
)?,
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)?
.update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
Expand All @@ -401,15 +382,14 @@ impl TreeNode for Expr {
expr,
list,
negated,
}) => transform_box(expr, &mut f)?
.update_data(|new_expr| (new_expr, list))
.try_transform_node(|(new_expr, list)| {
Ok(transform_vec(list, &mut f)?
.update_data(|new_list| (new_expr, new_list)))
})?
.update_data(|(new_expr, new_list)| {
Expr::InList(InList::new(new_expr, new_list, negated))
}),
}) => map_until_stop_and_collect!(
transform_box(expr, &mut f),
list,
transform_vec(list, &mut f)
)?
.update_data(|(new_expr, new_list)| {
Expr::InList(InList::new(new_expr, new_list, negated))
}),
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
transform_box(expr, &mut f)?.update_data(|be| {
Expr::GetIndexedField(GetIndexedField::new(be, field))
Expand Down
Loading