Skip to content

Commit

Permalink
Add TreeNodeMutator API
Browse files Browse the repository at this point in the history
Use TreeNode API in Optimizer
  • Loading branch information
alamb committed Apr 1, 2024
1 parent 9487ca0 commit 1bedfec
Show file tree
Hide file tree
Showing 35 changed files with 1,049 additions and 536 deletions.
2 changes: 1 addition & 1 deletion datafusion-examples/examples/rewrite_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub fn main() -> Result<()> {

// then run the optimizer with our custom rule
let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]);
let optimized_plan = optimizer.optimize(&analyzed_plan, &config, observe)?;
let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?;
println!(
"Optimized Logical Plan:\n\n{}\n",
optimized_plan.display_indent()
Expand Down
162 changes: 159 additions & 3 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

use std::sync::Arc;

use crate::Result;
use crate::{error::_not_impl_err, Result};

/// This macro is used to control continuation behaviors during tree traversals
/// based on the specified direction. Depending on `$DIRECTION` and the value of
Expand Down Expand Up @@ -100,6 +100,10 @@ pub trait TreeNode: Sized {
/// Visit the tree node using the given [`TreeNodeVisitor`], performing a
/// depth-first walk of the node and its children.
///
/// See also:
/// * [`Self::mutate`] to rewrite `TreeNode`s in place
/// * [`Self::rewrite`] to rewrite owned `TreeNode`s
///
/// Consider the following tree structure:
/// ```text
/// ParentNode
Expand Down Expand Up @@ -144,6 +148,10 @@ pub trait TreeNode: Sized {
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
/// recursively transforming [`TreeNode`]s.
///
/// See also:
/// * [`Self::mutate`] to rewrite `TreeNode`s in place
/// * [`Self::visit`] for inspecting (without modification) `TreeNode`s
///
/// Consider the following tree structure:
/// ```text
/// ParentNode
Expand Down Expand Up @@ -174,6 +182,70 @@ pub trait TreeNode: Sized {
})
}

/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
/// recursively mutating / rewriting [`TreeNode`]s in place.
///
/// See also:
/// * [`Self::rewrite`] to rewrite owned `TreeNode`s
/// * [`Self::visit`] for inspecting (without modification) `TreeNode`s
///
/// Consider the following tree structure:
/// ```text
/// ParentNode
/// left: ChildNode1
/// right: ChildNode2
/// ```
///
/// Here, the nodes would be mutataed in the following order:
/// ```text
/// TreeNodeMutator::f_down(ParentNode)
/// TreeNodeMutator::f_down(ChildNode1)
/// TreeNodeMutator::f_up(ChildNode1)
/// TreeNodeMutator::f_down(ChildNode2)
/// TreeNodeMutator::f_up(ChildNode2)
/// TreeNodeMutator::f_up(ParentNode)
/// ```
///
/// See [`TreeNodeRecursion`] for more details on controlling the traversal.
///
/// # Error Handling
///
/// If [`TreeNodeVisitor::f_down()`] or [`TreeNodeVisitor::f_up()`] returns [`Err`],
/// the recursion stops immediately and the tree may be left partially changed
///
/// # Changing Children During Traversal
///
/// If `f_down` changes the nodes children, the new children are visited
/// (not the old children prior to rewrite)
fn mutate<M: TreeNodeMutator<Node = Self>>(
&mut self,
mutator: &mut M,
) -> Result<Transformed<()>> {
// Note this is an inlined version of handle_transform_recursion!
let pre_visited = mutator.f_down(self)?;

// Traverse children and then call f_up on self if necessary
match pre_visited.tnr {
TreeNodeRecursion::Continue => {
// rewrite children recursively with mutator
self.mutate_children(|c| c.mutate(mutator))?
.try_transform_node_with(
|_: ()| mutator.f_up(self),
TreeNodeRecursion::Jump,
)
}
TreeNodeRecursion::Jump => {
// skip other children and start back up
mutator.f_up(self)
}
TreeNodeRecursion::Stop => return Ok(pre_visited),
}
.map(|mut post_visited| {
post_visited.transformed |= pre_visited.transformed;
post_visited
})
}

/// Applies `f` to the node and its children. `f` is applied in a pre-order
/// way, and it is controlled by [`TreeNodeRecursion`], which means result
/// of the `f` on a node can cause an early return.
Expand Down Expand Up @@ -353,13 +425,38 @@ pub trait TreeNode: Sized {
}

/// Apply the closure `F` to the node's children.
///
/// See `mutate_children` for rewriting in place
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: &mut F,
) -> Result<TreeNodeRecursion>;

/// Apply transform `F` to the node's children. Note that the transform `F`
/// might have a direction (pre-order or post-order).
/// Rewrite the node's children in place using `F`.
///
/// On error, `self` is left partially rewritten.
///
/// # Notes
///
/// Using [`Self::map_children`], the owned API, has clearer semantics on
/// error (the node is consumed). However, it requires copying the interior
/// fields of the tree node during rewrite.
///
/// This API writes the nodes in place, which can be faster as it avoids
/// copying, but leaves the tree node in an partially rewritten state when
/// an error occurs.
fn mutate_children<F: FnMut(&mut Self) -> Result<Transformed<()>>>(
&mut self,
_f: F,
) -> Result<Transformed<()>> {
_not_impl_err!(
"mutate_children not implemented for {} yet",
std::any::type_name::<Self>()
)
}

/// Apply transform `F` to potentially rewrite the node's children. Note
/// that the transform `F` might have a direction (pre-order or post-order).
fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: F,
Expand Down Expand Up @@ -411,6 +508,41 @@ pub trait TreeNodeRewriter: Sized {
}
}

/// Trait for mutating (rewriting in place) [`TreeNode`]s
///
/// # See Also:
/// * [`TreeNodeRewriter`] for rewriting owned `TreeNode`e
/// * [`TreeNodeVisitor`] for visiting, but not changing, `TreeNode`s
pub trait TreeNodeMutator: Sized {
/// The node type to mutating.
type Node: TreeNode;

/// Invoked while traversing down the tree before any children are mutated.
/// Default implementation does nothing to the node and continues recursion.
///
/// # Notes
///
/// As the node maybe mutated in place, the returned [`Transformed`] object
/// returns `()` (no data).
///
/// If the node's children are changed by `f_down`, the *new* children are
/// visited, not the original children.
fn f_down(&mut self, _node: &mut Self::Node) -> Result<Transformed<()>> {
Ok(Transformed::no(()))
}

/// Invoked while traversing up the tree after all children have been mutated.
/// Default implementation does nothing to the node and continues recursion.
///
/// # Notes
///
/// As the node maybe mutated in place, the returned [`Transformed`] object
/// returns `()` (no data).
fn f_up(&mut self, _node: &mut Self::Node) -> Result<Transformed<()>> {
Ok(Transformed::no(()))
}
}

/// Controls how [`TreeNode`] recursions should proceed.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum TreeNodeRecursion {
Expand Down Expand Up @@ -489,6 +621,11 @@ impl<T> Transformed<T> {
f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr))
}

/// Invokes f(), depending on the value of self.tnr.
///
/// This is used to conditionally apply a function during a f_up tree
/// traversal, if the result of children traversal was `[`TreeNodeRecursion::Continue`].
///
/// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`]
/// is straightforward, but [`TreeNodeRecursion::Jump`] can behave differently
/// when we are traversing down or up on a tree. If [`TreeNodeRecursion`] of
Expand Down Expand Up @@ -532,6 +669,25 @@ impl<T> Transformed<T> {
}
}

impl Transformed<()> {
/// Invoke the given function `f` and combine the transformed state with
/// the current state:
///
/// * if `f` returns an Err, returns that err
///
/// * If `f` returns Ok, sets `self.transformed` to `true` if either self or
/// the result of `f` were transformed.
pub fn and_then<F>(self, f: F) -> Result<Transformed<()>>
where
F: FnOnce() -> Result<Transformed<()>>,
{
f().map(|mut t| {
t.transformed |= self.transformed;
t
})
}
}

/// Transformation helper to process tree nodes that are siblings.
pub trait TransformedIterator: Iterator {
fn map_until_stop_and_collect<
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1877,7 +1877,7 @@ impl SessionState {

// optimize the child plan, capturing the output of each optimizer
let optimized_plan = self.optimizer.optimize(
&analyzed_plan,
analyzed_plan,
self,
|optimized_plan, optimizer| {
let optimizer_name = optimizer.name().to_string();
Expand Down Expand Up @@ -1907,7 +1907,7 @@ impl SessionState {
let analyzed_plan =
self.analyzer
.execute_and_check(plan, self.options(), |_, _| {})?;
self.optimizer.optimize(&analyzed_plan, self, |_, _| {})
self.optimizer.optimize(analyzed_plan, self, |_, _| {})
}
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
let optimizer = Optimizer::new();
// analyze and optimize the logical plan
let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?;
optimizer.optimize(&plan, &config, |_, _| {})
optimizer.optimize(plan, &config, |_, _| {})
}

#[derive(Default)]
Expand Down
18 changes: 18 additions & 0 deletions datafusion/expr/src/logical_plan/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,24 @@ impl DdlStatement {
}
}

/// Return a mutable reference to the input `LogicalPlan`, if any
pub fn input_mut(&mut self) -> Option<&mut Arc<LogicalPlan>> {
match self {
DdlStatement::CreateMemoryTable(CreateMemoryTable { input, .. }) => {
Some(input)
}
DdlStatement::CreateExternalTable(_) => None,
DdlStatement::CreateView(CreateView { input, .. }) => Some(input),
DdlStatement::CreateCatalogSchema(_) => None,
DdlStatement::CreateCatalog(_) => None,
DdlStatement::DropTable(_) => None,
DdlStatement::DropView(_) => None,
DdlStatement::DropCatalogSchema(_) => None,
DdlStatement::CreateFunction(_) => None,
DdlStatement::DropFunction(_) => None,
}
}

/// Return a `format`able structure with the a human readable
/// description of this LogicalPlan node per node, not including
/// children.
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod ddl;
pub mod display;
pub mod dml;
mod extension;
mod mutate;
mod plan;
mod statement;

Expand Down
Loading

0 comments on commit 1bedfec

Please sign in to comment.