Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encode all join conditions in a single expression field #7612

Merged
merged 6 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 66 additions & 42 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef};

use datafusion::logical_expr::{
aggregate_function, window_function::find_df_window_func, BinaryExpr,
BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
Expand Down Expand Up @@ -129,6 +130,51 @@ fn scalar_function_type_from_str(name: &str) -> Result<ScalarFunctionType> {
}
}

fn split_eq_and_noneq_join_predicate_with_nulls_equality(
filter: &Expr,
) -> (Vec<(Column, Column)>, bool, Option<Expr>) {
let exprs = split_conjunction(filter);

let mut accum_join_keys: Vec<(Column, Column)> = vec![];
let mut accum_filters: Vec<Expr> = vec![];
let mut nulls_equal_nulls = false;

for expr in exprs {
match expr {
Expr::BinaryExpr(binary_expr) => match binary_expr {
x @ (BinaryExpr {
left,
op: Operator::Eq,
right,
}
| BinaryExpr {
left,
op: Operator::IsNotDistinctFrom,
right,
}) => {
nulls_equal_nulls = match x.op {
Operator::Eq => false,
Operator::IsNotDistinctFrom => true,
_ => unreachable!(),
};

match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => {
accum_join_keys.push((l.clone(), r.clone()));
}
_ => accum_filters.push(expr.clone()),
}
}
_ => accum_filters.push(expr.clone()),
},
_ => accum_filters.push(expr.clone()),
}
}

let join_filter = accum_filters.into_iter().reduce(Expr::and);
(accum_join_keys, nulls_equal_nulls, join_filter)
}

/// Convert Substrait Plan to DataFusion DataFrame
pub async fn from_substrait_plan(
ctx: &mut SessionContext,
Expand Down Expand Up @@ -336,7 +382,13 @@ pub async fn from_substrait_rel(
}
}
Some(RelType::Join(join)) => {
let left = LogicalPlanBuilder::from(
if join.post_join_filter.is_some() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

return not_impl_err!(
"JoinRel with post_join_filter is not yet supported"
);
}

let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?,
);
let right = LogicalPlanBuilder::from(
Expand All @@ -346,60 +398,32 @@ pub async fn from_substrait_rel(
// The join condition expression needs full input schema and not the output schema from join since we lose columns from
// certain join types such as semi and anti joins
let in_join_schema = left.schema().join(right.schema())?;
// Parse post join filter if exists
let join_filter = match &join.post_join_filter {
Some(filter) => {
let parsed_filter =
from_substrait_rex(filter, &in_join_schema, extensions).await?;
Some(parsed_filter.as_ref().clone())
}
None => None,
};

// If join expression exists, parse the `on` condition expression, build join and return
// Otherwise, build join with koin filter, without join keys
// Otherwise, build join with only the filter, without join keys
match &join.expression.as_ref() {
Some(expr) => {
let on =
from_substrait_rex(expr, &in_join_schema, extensions).await?;
let predicates = split_conjunction(&on);
// TODO: collect only one null_eq_null
let join_exprs: Vec<(Column, Column, bool)> = predicates
.iter()
.map(|p| match p {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => match op {
Operator::Eq => Ok((l.clone(), r.clone(), false)),
Operator::IsNotDistinctFrom => {
Ok((l.clone(), r.clone(), true))
}
_ => plan_err!("invalid join condition op"),
},
_ => plan_err!("invalid join condition expression"),
}
}
_ => plan_err!(
"Non-binary expression is not supported in join condition"
),
})
.collect::<Result<Vec<_>>>()?;
let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, Vec<_>) =
itertools::multiunzip(join_exprs);
// The join expression can contain both equal and non-equal ops.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that the ExtractEquijoinPredicate optimizer pass already splits up join predicates into equijoin predicates and "other" predicates, I wonder if simply create the LogicalPlan::Join using join.expression (and let the subsequent optimizer pass sort it out)?

Something like

left.join(
  right.build()?,
  join_type,
  (vec![], vec![]),
  on, // <-- use the filter directly here, let optimizer pass extract the equijoin columns
  nulls_equal_nulls,
)?

It makes me realize when looking at the API for LogicalPlanBuilder::join that the API is super confusing. It would be nice to improve that API to make it clear that a join can just take a single Expr and DataFusion will sort out figuring out the join columns, etc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out this is exactly what DataFrame::join_on does -- I have filed a ticket with a way to make this clearer: #7766 (comment)

// As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields.
// So we extract each part as follows:
// - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector
// - Otherwise we add the expression to join_filter (use conjunction if filter already exists)
let (join_ons, nulls_equal_nulls, join_filter) =
split_eq_and_noneq_join_predicate_with_nulls_equality(&on);
let (left_cols, right_cols): (Vec<_>, Vec<_>) =
itertools::multiunzip(join_ons);
left.join_detailed(
right.build()?,
join_type,
(left_cols, right_cols),
join_filter,
null_eq_nulls[0],
nulls_equal_nulls,
)?
.build()
}
None => match &join_filter {
Some(_) => left
.join_on(right.build()?, join_type, join_filter)?
.build(),
None => plan_err!("Join without join keys require a valid filter"),
},
None => plan_err!("JoinRel without join condition is not allowed"),
}
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Expand Down
33 changes: 25 additions & 8 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,39 +278,56 @@ pub fn to_substrait_rel(
// parse filter if exists
let in_join_schema = join.left.schema().join(join.right.schema())?;
let join_filter = match &join.filter {
Some(filter) => Some(Box::new(to_substrait_rex(
Some(filter) => Some(to_substrait_rex(
filter,
&Arc::new(in_join_schema),
0,
extension_info,
)?)),
)?),
None => None,
};

// map the left and right columns to binary expressions in the form `l = r`
// build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
let eq_op = if join.null_equals_null {
Operator::IsNotDistinctFrom
} else {
Operator::Eq
};

let join_expr = to_substrait_join_expr(
let join_on = to_substrait_join_expr(
&join.on,
eq_op,
join.left.schema(),
join.right.schema(),
extension_info,
)?
.map(Box::new);
)?;

// create conjunction between `join_on` and `join_filter` to embed all join conditions,
// whether equal or non-equal in a single expression
let join_expr = match &join_on {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps you could use conjunction here and simplify the code https://docs.rs/datafusion/latest/datafusion/optimizer/utils/fn.conjunction.html

Some(on_expr) => match &join_filter {
Some(filter) => Some(Box::new(make_binary_op_scalar_func(
on_expr,
filter,
Operator::And,
extension_info,
))),
None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist
},
None => match &join_filter {
Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist
None => None,
},
};

Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
common: None,
left: Some(left),
right: Some(right),
r#type: join_type as i32,
expression: join_expr,
post_join_filter: join_filter,
expression: join_expr.clone(),
post_join_filter: None,
advanced_extension: None,
}))),
}))
Expand Down
118 changes: 114 additions & 4 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@ use std::hash::Hash;
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::{DFSchema, DFSchemaRef};
use datafusion::error::Result;
use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef};
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::SessionState;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode};
use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
use datafusion::prelude::*;

use substrait::proto::extensions::simple_extension_declaration::MappingType;
use substrait::proto::rel::RelType;
use substrait::proto::{plan_rel, Plan, Rel};

struct MockSerializerRegistry;

Expand Down Expand Up @@ -383,12 +386,15 @@ async fn roundtrip_inner_join() -> Result<()> {

#[tokio::test]
async fn roundtrip_non_equi_inner_join() -> Result<()> {
roundtrip("SELECT data.a FROM data JOIN data2 ON data.a <> data2.a").await
roundtrip_verify_post_join_filter(
"SELECT data.a FROM data JOIN data2 ON data.a <> data2.a",
)
.await
}

#[tokio::test]
async fn roundtrip_non_equi_join() -> Result<()> {
roundtrip(
roundtrip_verify_post_join_filter(
"SELECT data.a FROM data, data2 WHERE data.a = data2.a AND data.e > data2.a",
)
.await
Expand Down Expand Up @@ -620,6 +626,91 @@ async fn extension_logical_plan() -> Result<()> {
Ok(())
}

fn check_post_join_filters(rel: &Rel) -> Result<()> {
// search for target_rel and field value in proto
match &rel.rel_type {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it might be helpful (eventually) do define TreeNode for Rel to implement walking efficiently.

Some(RelType::Join(join)) => {
// check if join filter is None
if join.post_join_filter.is_some() {
plan_err!(
"DataFusion generated Susbtrait plan cannot have post_join_filter in JoinRel"
)
} else {
// recursively check JoinRels
match check_post_join_filters(join.left.as_ref().unwrap().as_ref()) {
Err(e) => Err(e),
Ok(_) => {
check_post_join_filters(join.right.as_ref().unwrap().as_ref())
}
}
}
}
Some(RelType::Project(p)) => {
check_post_join_filters(p.input.as_ref().unwrap().as_ref())
}
Some(RelType::Filter(filter)) => {
check_post_join_filters(filter.input.as_ref().unwrap().as_ref())
}
Some(RelType::Fetch(fetch)) => {
check_post_join_filters(fetch.input.as_ref().unwrap().as_ref())
}
Some(RelType::Sort(sort)) => {
check_post_join_filters(sort.input.as_ref().unwrap().as_ref())
}
Some(RelType::Aggregate(agg)) => {
check_post_join_filters(agg.input.as_ref().unwrap().as_ref())
}
Some(RelType::Set(set)) => {
for input in &set.inputs {
match check_post_join_filters(input) {
Err(e) => return Err(e),
Ok(_) => continue,
}
}
Ok(())
}
Some(RelType::ExtensionSingle(ext)) => {
check_post_join_filters(ext.input.as_ref().unwrap().as_ref())
}
Some(RelType::ExtensionMulti(ext)) => {
for input in &ext.inputs {
match check_post_join_filters(input) {
Err(e) => return Err(e),
Ok(_) => continue,
}
}
Ok(())
}
Some(RelType::ExtensionLeaf(_)) | Some(RelType::Read(_)) => Ok(()),
_ => not_impl_err!(
"Unsupported RelType: {:?} in post join filter check",
rel.rel_type
),
}
}

async fn verify_post_join_filter_value(proto: Box<Plan>) -> Result<()> {
for relation in &proto.relations {
match relation.rel_type.as_ref() {
Some(rt) => match rt {
plan_rel::RelType::Rel(rel) => match check_post_join_filters(rel) {
Err(e) => return Err(e),
Ok(_) => continue,
},
plan_rel::RelType::Root(root) => {
match check_post_join_filters(root.input.as_ref().unwrap()) {
Err(e) => return Err(e),
Ok(_) => continue,
}
}
},
None => return plan_err!("Cannot parse plan relation: None"),
}
}

Ok(())
}

async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> {
let mut ctx = create_context().await?;
let df = ctx.sql(sql).await?;
Expand Down Expand Up @@ -688,6 +779,25 @@ async fn roundtrip(sql: &str) -> Result<()> {
Ok(())
}

async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
let mut ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan, &ctx)?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;

println!("{plan:#?}");
println!("{plan2:#?}");

let plan1str = format!("{plan:?}");
let plan2str = format!("{plan2:?}");
assert_eq!(plan1str, plan2str);

// verify that the join filters are None
verify_post_join_filter_value(proto).await
}

async fn roundtrip_all_types(sql: &str) -> Result<()> {
let mut ctx = create_all_type_context().await?;
let df = ctx.sql(sql).await?;
Expand Down