Skip to content

Commit

Permalink
Simply expression rewrite in ProjectionPushdown, make more general
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Nov 9, 2023
1 parent 43cc870 commit d524da9
Showing 1 changed file with 35 additions and 115 deletions.
150 changes: 35 additions & 115 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,9 @@ use arrow_schema::SchemaRef;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::JoinSide;
use datafusion_physical_expr::expressions::{
BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr,
};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{
Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
ScalarFunctionExpr,
};
use datafusion_physical_plan::union::UnionExec;

Expand Down Expand Up @@ -791,119 +788,42 @@ fn update_expr(
projected_exprs: &[(Arc<dyn PhysicalExpr>, String)],
sync_with_child: bool,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
let expr_any = expr.as_any();
if let Some(column) = expr_any.downcast_ref::<Column>() {
if sync_with_child {
// Update the index of `column`:
Ok(Some(projected_exprs[column.index()].0.clone()))
} else {
// Determine how to update `column` to accommodate `projected_exprs`:
Ok(projected_exprs.iter().enumerate().find_map(
|(index, (projected_expr, alias))| {
projected_expr.as_any().downcast_ref::<Column>().and_then(
|projected_column| {
column
.name()
.eq(projected_column.name())
.then(|| Arc::new(Column::new(alias, index)) as _)
},
)
},
))
}
} else if let Some(binary) = expr_any.downcast_ref::<BinaryExpr>() {
match (
update_expr(binary.left(), projected_exprs, sync_with_child)?,
update_expr(binary.right(), projected_exprs, sync_with_child)?,
) {
(Some(left), Some(right)) => {
Ok(Some(Arc::new(BinaryExpr::new(left, *binary.op(), right))))
}
_ => Ok(None),
}
} else if let Some(cast) = expr_any.downcast_ref::<CastExpr>() {
update_expr(cast.expr(), projected_exprs, sync_with_child).map(|maybe_expr| {
maybe_expr.map(|expr| {
Arc::new(CastExpr::new(
expr,
cast.cast_type().clone(),
Some(cast.cast_options().clone()),
)) as _
})
})
} else if expr_any.is::<Literal>() {
Ok(Some(expr.clone()))
} else if let Some(negative) = expr_any.downcast_ref::<NegativeExpr>() {
update_expr(negative.arg(), projected_exprs, sync_with_child).map(|maybe_expr| {
maybe_expr.map(|expr| Arc::new(NegativeExpr::new(expr)) as _)
})
} else if let Some(scalar_func) = expr_any.downcast_ref::<ScalarFunctionExpr>() {
scalar_func
.args()
.iter()
.map(|expr| update_expr(expr, projected_exprs, sync_with_child))
.collect::<Result<Option<Vec<_>>>>()
.map(|maybe_args| {
maybe_args.map(|new_args| {
Arc::new(ScalarFunctionExpr::new(
scalar_func.name(),
scalar_func.fun().clone(),
new_args,
scalar_func.return_type(),
scalar_func.monotonicity().clone(),
)) as _
})
})
} else if let Some(case) = expr_any.downcast_ref::<CaseExpr>() {
update_case_expr(case, projected_exprs, sync_with_child)
} else {
Ok(None)
}
}

/// Updates the indices `case` refers to according to `projected_exprs`.
fn update_case_expr(
case: &CaseExpr,
projected_exprs: &[(Arc<dyn PhysicalExpr>, String)],
sync_with_child: bool,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
let new_case = case
.expr()
.map(|expr| update_expr(expr, projected_exprs, sync_with_child))
.transpose()?
.flatten();

let new_else = case
.else_expr()
.map(|expr| update_expr(expr, projected_exprs, sync_with_child))
.transpose()?
.flatten();

let new_when_then = case
.when_then_expr()
.iter()
.map(|(when, then)| {
Ok((
update_expr(when, projected_exprs, sync_with_child)?,
update_expr(then, projected_exprs, sync_with_child)?,
))
})
.collect::<Result<Vec<_>>>()?
.into_iter()
.filter_map(|(maybe_when, maybe_then)| match (maybe_when, maybe_then) {
(Some(when), Some(then)) => Some((when, then)),
_ => None,
})
.collect::<Vec<_>>();
let mut rewritten = false;

if new_when_then.len() != case.when_then_expr().len()
|| case.expr().is_some() && new_case.is_none()
|| case.else_expr().is_some() && new_else.is_none()
{
return Ok(None);
}
let new_expr = expr
.clone()
.transform_down_mut(&mut |expr: Arc<dyn PhysicalExpr>| {
let Some(column) = expr.as_any().downcast_ref::<Column>() else {
return Ok(Transformed::No(expr));
};
if sync_with_child {
rewritten = true;
// Update the index of `column`:
Ok(Transformed::Yes(projected_exprs[column.index()].0.clone()))
} else {
// Determine how to update `column` to accommodate `projected_exprs`:
let new_col = projected_exprs.iter().enumerate().find_map(
|(index, (projected_expr, alias))| {
projected_expr.as_any().downcast_ref::<Column>().and_then(
|projected_column| {
column
.name()
.eq(projected_column.name())
.then(|| Arc::new(Column::new(alias, index)) as _)
},
)
},
);
if let Some(new_col) = new_col {
rewritten = true;
Ok(Transformed::Yes(new_col))
} else {
Ok(Transformed::No(expr))
}
}
});

CaseExpr::try_new(new_case, new_when_then, new_else).map(|e| Some(Arc::new(e) as _))
new_expr.map(|new_expr| if rewritten { Some(new_expr) } else { None })
}

/// Creates a new [`ProjectionExec`] instance with the given child plan and
Expand Down

0 comments on commit d524da9

Please sign in to comment.