Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support LIMIT Push-down logical plan optimization for Extension nodes #12685

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2557,6 +2557,10 @@ mod tests {
) -> Result<Self> {
unimplemented!("NoOp");
}

fn supports_limit_pushdown(&self) -> bool {
false // Disallow limit push-down by default
}
}

#[derive(Debug)]
Expand Down
4 changes: 4 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,10 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode {
expr: replace_sort_expression(self.expr.clone(), exprs.swap_remove(0)),
})
}

fn supports_limit_pushdown(&self) -> bool {
false // Disallow limit push-down by default
}
}

/// Physical planner for TopK nodes
Expand Down
24 changes: 24 additions & 0 deletions datafusion/expr/src/logical_plan/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,16 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
/// directly because it must remain object safe.
fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool;
fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering>;

/// Returns `true` if a limit can be safely pushed down through this
/// `UserDefinedLogicalNode` node.
///
/// If this method returns `true`, and the query plan contains a limit at
/// the output of this node, DataFusion will push the limit to the input
/// of this node.
fn supports_limit_pushdown(&self) -> bool {
false
}
}

impl Hash for dyn UserDefinedLogicalNode {
Expand Down Expand Up @@ -295,6 +305,16 @@ pub trait UserDefinedLogicalNodeCore:
) -> Option<Vec<Vec<usize>>> {
None
}

/// Returns `true` if a limit can be safely pushed down through this
/// `UserDefinedLogicalNode` node.
///
/// If this method returns `true`, and the query plan contains a limit at
/// the output of this node, DataFusion will push the limit to the input
/// of this node.
fn supports_limit_pushdown(&self) -> bool {
false // Disallow limit push-down by default
}
}

/// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode`
Expand Down Expand Up @@ -361,6 +381,10 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
.downcast_ref::<Self>()
.and_then(|other| self.partial_cmp(other))
}

fn supports_limit_pushdown(&self) -> bool {
self.supports_limit_pushdown()
}
}

fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet<String> {
Expand Down
4 changes: 4 additions & 0 deletions datafusion/optimizer/src/analyzer/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ mod test {
empty_schema: Arc::clone(&self.empty_schema),
})
}

fn supports_limit_pushdown(&self) -> bool {
false // Disallow limit push-down by default
}
}

#[test]
Expand Down
8 changes: 8 additions & 0 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,10 @@ mod tests {
// Since schema is same. Output columns requires their corresponding version in the input columns.
Some(vec![output_columns.to_vec()])
}

fn supports_limit_pushdown(&self) -> bool {
false // Disallow limit push-down by default
}
}

#[derive(Debug, Hash, PartialEq, Eq)]
Expand Down Expand Up @@ -991,6 +995,10 @@ mod tests {
}
Some(vec![left_reqs, right_reqs])
}

fn supports_limit_pushdown(&self) -> bool {
false // Disallow limit push-down by default
}
}

#[test]
Expand Down
4 changes: 4 additions & 0 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,10 @@ mod tests {
schema: Arc::clone(&self.schema),
})
}

fn supports_limit_pushdown(&self) -> bool {
false // Disallow limit push-down by default
}
}

#[test]
Expand Down
249 changes: 248 additions & 1 deletion datafusion/optimizer/src/push_down_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,29 @@ impl OptimizerRule for PushDownLimit {
subquery_alias.input = Arc::new(new_limit);
Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias)))
}
LogicalPlan::Extension(extension_plan)
if extension_plan.node.supports_limit_pushdown() =>
{
let new_children = extension_plan
.node
.inputs()
.into_iter()
.map(|child| {
LogicalPlan::Limit(Limit {
skip: 0,
fetch: Some(fetch + skip),
input: Arc::new(child.clone()),
})
})
.collect::<Vec<_>>();

// Create a new extension node with updated inputs
let child_plan = LogicalPlan::Extension(extension_plan);
let new_extension =
child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

transformed_limit(skip, fetch, new_extension)
}
input => original_limit(skip, fetch, input),
}
}
Expand Down Expand Up @@ -258,17 +281,241 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> {

#[cfg(test)]
mod test {
use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use std::vec;

use super::*;
use crate::test::*;
use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder};

use datafusion_common::DFSchemaRef;
use datafusion_expr::{
col, exists, logical_plan::builder::LogicalPlanBuilder, Expr, Extension,
UserDefinedLogicalNodeCore,
};
use datafusion_functions_aggregate::expr_fn::max;

// Runs only the `PushDownLimit` rule over `plan` and compares the optimized
// plan's display output against `expected`.
fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected)
}

/// Test-only extension node that wraps one or more child plans; its
/// `UserDefinedLogicalNodeCore` impl reports `supports_limit_pushdown() == true`,
/// so the optimizer may push a `Limit` through it.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct NoopPlan {
// Child plans wrapped by this node.
input: Vec<LogicalPlan>,
// Output schema; in these tests it mirrors the child table scan's schema.
schema: DFSchemaRef,
}

// Manual implementation needed because of `schema` field. Comparison excludes this field.
impl PartialOrd for NoopPlan {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
// NOTE(review): derived `PartialEq` compares `schema` as well, so equality
// and ordering can disagree when inputs match but schemas differ — confirm
// this asymmetry is intended (it matches sibling test nodes in this crate).
self.input.partial_cmp(&other.input)
}
}

impl UserDefinedLogicalNodeCore for NoopPlan {
    /// Display name used when the node appears in plan output.
    fn name(&self) -> &str {
        "NoopPlan"
    }

    /// Children of this node, in order.
    fn inputs(&self) -> Vec<&LogicalPlan> {
        let mut children = Vec::with_capacity(self.input.len());
        for child in &self.input {
            children.push(child);
        }
        children
    }

    /// Output schema stored on the node.
    fn schema(&self) -> &DFSchemaRef {
        &self.schema
    }

    /// Expressions of all children, concatenated in child order.
    fn expressions(&self) -> Vec<Expr> {
        let mut exprs = Vec::new();
        for child in &self.input {
            exprs.extend(child.expressions());
        }
        exprs
    }

    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("NoopPlan")
    }

    /// Rebuilds the node with new children. The expression list is ignored
    /// because this node derives its expressions from its inputs.
    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> Result<Self> {
        Ok(Self {
            schema: Arc::clone(&self.schema),
            input: inputs,
        })
    }

    /// This test node is transparent to limits, so push-down is allowed.
    fn supports_limit_pushdown(&self) -> bool {
        true
    }
}

/// Test-only extension node identical in shape to `NoopPlan`, except that its
/// `UserDefinedLogicalNodeCore` impl reports `supports_limit_pushdown() == false`,
/// so the optimizer must NOT push a `Limit` through it.
#[derive(Debug, PartialEq, Eq, Hash)]
struct NoLimitNoopPlan {
// Child plans wrapped by this node.
input: Vec<LogicalPlan>,
// Output schema; in these tests it mirrors the child table scan's schema.
schema: DFSchemaRef,
}

// Manual implementation needed because of `schema` field. Comparison excludes this field.
impl PartialOrd for NoLimitNoopPlan {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
// NOTE(review): derived `PartialEq` compares `schema` as well, so equality
// and ordering can disagree when inputs match but schemas differ — confirm
// this asymmetry is intended (it matches sibling test nodes in this crate).
self.input.partial_cmp(&other.input)
}
}

impl UserDefinedLogicalNodeCore for NoLimitNoopPlan {
    /// Display name used when the node appears in plan output.
    fn name(&self) -> &str {
        "NoLimitNoopPlan"
    }

    /// Children of this node, in order.
    fn inputs(&self) -> Vec<&LogicalPlan> {
        let mut children = Vec::with_capacity(self.input.len());
        for child in &self.input {
            children.push(child);
        }
        children
    }

    /// Output schema stored on the node.
    fn schema(&self) -> &DFSchemaRef {
        &self.schema
    }

    /// Expressions of all children, concatenated in child order.
    fn expressions(&self) -> Vec<Expr> {
        let mut exprs = Vec::new();
        for child in &self.input {
            exprs.extend(child.expressions());
        }
        exprs
    }

    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("NoLimitNoopPlan")
    }

    /// Rebuilds the node with new children. The expression list is ignored
    /// because this node derives its expressions from its inputs.
    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> Result<Self> {
        Ok(Self {
            schema: Arc::clone(&self.schema),
            input: inputs,
        })
    }

    /// This node blocks limit push-down, mirroring the trait's default.
    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
#[test]
fn limit_pushdown_basic() -> Result<()> {
    // A limit above an extension node that allows push-down should be
    // replicated below the node and absorbed into the table scan's fetch.
    let scan = test_table_scan()?;
    let schema = Arc::clone(scan.schema());
    let plan = LogicalPlanBuilder::from(LogicalPlan::Extension(Extension {
        node: Arc::new(NoopPlan {
            input: vec![scan],
            schema,
        }),
    }))
    .limit(0, Some(1000))?
    .build()?;

    let expected = "Limit: skip=0, fetch=1000\
    \n NoopPlan\
    \n Limit: skip=0, fetch=1000\
    \n TableScan: test, fetch=1000";

    assert_optimized_plan_equal(plan, expected)
}

#[test]
fn limit_pushdown_with_skip() -> Result<()> {
    // With a non-zero skip, the limit pushed below the extension node must
    // fetch `skip + fetch` rows so the outer limit can still discard `skip`.
    let scan = test_table_scan()?;
    let schema = Arc::clone(scan.schema());
    let plan = LogicalPlanBuilder::from(LogicalPlan::Extension(Extension {
        node: Arc::new(NoopPlan {
            input: vec![scan],
            schema,
        }),
    }))
    .limit(10, Some(1000))?
    .build()?;

    let expected = "Limit: skip=10, fetch=1000\
    \n NoopPlan\
    \n Limit: skip=0, fetch=1010\
    \n TableScan: test, fetch=1010";

    assert_optimized_plan_equal(plan, expected)
}

#[test]
fn limit_pushdown_multiple_limits() -> Result<()> {
    // Stacked limits are merged (skip=10+20, fetch=min-adjusted to 500) and
    // the combined `skip + fetch` (530) is pushed through the extension node.
    let scan = test_table_scan()?;
    let schema = Arc::clone(scan.schema());
    let plan = LogicalPlanBuilder::from(LogicalPlan::Extension(Extension {
        node: Arc::new(NoopPlan {
            input: vec![scan],
            schema,
        }),
    }))
    .limit(10, Some(1000))?
    .limit(20, Some(500))?
    .build()?;

    let expected = "Limit: skip=30, fetch=500\
    \n NoopPlan\
    \n Limit: skip=0, fetch=530\
    \n TableScan: test, fetch=530";

    assert_optimized_plan_equal(plan, expected)
}

#[test]
fn limit_pushdown_multiple_inputs() -> Result<()> {
    // An extension node with several children gets a copy of the limit pushed
    // to every child.
    let scan = test_table_scan()?;
    let schema = Arc::clone(scan.schema());
    let children = vec![scan.clone(), scan];
    let plan = LogicalPlanBuilder::from(LogicalPlan::Extension(Extension {
        node: Arc::new(NoopPlan {
            input: children,
            schema,
        }),
    }))
    .limit(0, Some(1000))?
    .build()?;

    let expected = "Limit: skip=0, fetch=1000\
    \n NoopPlan\
    \n Limit: skip=0, fetch=1000\
    \n TableScan: test, fetch=1000\
    \n Limit: skip=0, fetch=1000\
    \n TableScan: test, fetch=1000";

    assert_optimized_plan_equal(plan, expected)
}

#[test]
fn limit_pushdown_disallowed_noop_plan() -> Result<()> {
    // When the extension node reports `supports_limit_pushdown() == false`,
    // the limit must stay above the node and the scan remains unfetched.
    let scan = test_table_scan()?;
    let schema = Arc::clone(scan.schema());
    let plan = LogicalPlanBuilder::from(LogicalPlan::Extension(Extension {
        node: Arc::new(NoLimitNoopPlan {
            input: vec![scan],
            schema,
        }),
    }))
    .limit(0, Some(1000))?
    .build()?;

    let expected = "Limit: skip=0, fetch=1000\
    \n NoLimitNoopPlan\
    \n TableScan: test";

    assert_optimized_plan_equal(plan, expected)
}

#[test]
fn limit_pushdown_projection_table_provider() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down
4 changes: 4 additions & 0 deletions datafusion/optimizer/src/test/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,8 @@ impl UserDefinedLogicalNodeCore for TestUserDefinedPlanNode {
input: inputs.swap_remove(0),
})
}

fn supports_limit_pushdown(&self) -> bool {
false // Disallow limit push-down by default
}
}
Loading