-
Notifications
You must be signed in to change notification settings - Fork 174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
perf: filter null join key optimization rule #3583
Changes from 2 commits
f105e22
1ef77bb
aa44aad
7cc2ec7
a5e9e4a
b9865b8
3a0b1a9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
use std::sync::Arc; | ||
|
||
use common_error::DaftResult; | ||
use common_treenode::{Transformed, TreeNode}; | ||
use daft_core::join::JoinType; | ||
use daft_dsl::{null_lit, optimization::conjuct}; | ||
|
||
use super::OptimizerRule; | ||
use crate::{ | ||
ops::{Filter, Join}, | ||
LogicalPlan, | ||
}; | ||
|
||
/// Optimization rule for filtering out nulls from join keys.
///
/// Inserts a filter before each side of the join to remove rows where a join key is null,
/// when it is valid to do so. This reduces the cardinality of the tables before the join,
/// which may improve join performance.
///
/// # Example
///
/// A query that benefits from this rule:
///
/// ```text
/// SELECT * FROM orders JOIN customers ON orders.customer_id = customers.id
/// ```
///
/// Under ordinary (non-null-safe) equality a null key never matches anything, so rows of
/// `orders` whose `customer_id` is null cannot appear in the inner-join result. Dropping
/// them before the join means the join has fewer rows to process.
#[derive(Default, Debug)]
pub struct FilterNullJoinKey {}

impl FilterNullJoinKey {
    /// Creates the rule. Equivalent to `FilterNullJoinKey::default()`; the rule carries no
    /// configuration.
    pub fn new() -> Self {
        Self::default()
    }
}
|
||
impl OptimizerRule for FilterNullJoinKey { | ||
fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> { | ||
plan.transform(|node| { | ||
if let LogicalPlan::Join(Join { | ||
left, | ||
right, | ||
left_on, | ||
right_on, | ||
null_equals_nulls, | ||
join_type, | ||
.. | ||
}) = node.as_ref() | ||
{ | ||
let mut null_equals_nulls_iter = null_equals_nulls.as_ref().map_or_else( | ||
|| Box::new(std::iter::repeat(false)) as Box<dyn Iterator<Item = bool>>, | ||
|x| Box::new(x.clone().into_iter()), | ||
); | ||
|
||
let (can_filter_left, can_filter_right) = match join_type { | ||
JoinType::Inner => (true, true), | ||
JoinType::Left => (false, true), | ||
JoinType::Right => (true, false), | ||
JoinType::Outer => (false, false), | ||
JoinType::Anti => (false, true), | ||
JoinType::Semi => (true, true), | ||
}; | ||
|
||
let left_null_pred = if can_filter_left { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how are we actually determining if the join key is nullable? AFAIK, we don't have a concept of nullable in our fields/dtypes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. like this should only push it down if the expr or colum is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This rule creates a filter that removes null rows. The join keys themselves do not have to be a null literal or type. So if the join keys are not nullable or do not have null values, this would essentially be a no-op, but if they had say a row where the value was null, it would be removed prior to the join. |
||
conjuct( | ||
null_equals_nulls_iter | ||
.by_ref() | ||
.zip(left_on) | ||
.filter(|(null_eq_null, _)| !null_eq_null) | ||
.map(|(_, left_key)| left_key.clone().not_eq(null_lit())), | ||
) | ||
} else { | ||
None | ||
}; | ||
|
||
let right_null_pred = if can_filter_right { | ||
conjuct( | ||
null_equals_nulls_iter | ||
.by_ref() | ||
.zip(right_on) | ||
.filter(|(null_eq_null, _)| !null_eq_null) | ||
.map(|(_, right_key)| right_key.clone().not_eq(null_lit())), | ||
) | ||
} else { | ||
None | ||
}; | ||
|
||
if left_null_pred.is_none() && right_null_pred.is_none() { | ||
Ok(Transformed::no(node.clone())) | ||
} else { | ||
let new_left = if let Some(pred) = left_null_pred { | ||
Arc::new(LogicalPlan::Filter(Filter::try_new(left.clone(), pred)?)) | ||
} else { | ||
left.clone() | ||
}; | ||
|
||
let new_right = if let Some(pred) = right_null_pred { | ||
Arc::new(LogicalPlan::Filter(Filter::try_new(right.clone(), pred)?)) | ||
} else { | ||
right.clone() | ||
}; | ||
|
||
let new_join = Arc::new(node.with_new_children(&[new_left, new_right])); | ||
|
||
Ok(Transformed::yes(new_join)) | ||
} | ||
} else { | ||
Ok(Transformed::no(node)) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_error::DaftResult;
    use daft_core::prelude::*;
    use daft_dsl::{col, null_lit};

    use crate::{
        optimization::{
            optimizer::{RuleBatch, RuleExecutionStrategy},
            rules::filter_null_join_key::FilterNullJoinKey,
            test::assert_optimized_plan_with_rules_eq,
        },
        test::{dummy_scan_node, dummy_scan_operator},
        LogicalPlan,
    };

    /// Helper that creates an optimizer with the FilterNullJoinKey rule registered, optimizes
    /// the provided plan with said optimizer, and compares the optimized plan with
    /// the provided expected plan.
    fn assert_optimized_plan_eq(
        plan: Arc<LogicalPlan>,
        expected: Arc<LogicalPlan>,
    ) -> DaftResult<()> {
        assert_optimized_plan_with_rules_eq(
            plan,
            expected,
            vec![RuleBatch::new(
                vec![Box::new(FilterNullJoinKey::new())],
                RuleExecutionStrategy::Once,
            )],
        )
    }

    /// Inner join on a single key pair with no null-safe flags: both inputs should be
    /// wrapped in a `key != null` filter.
    #[test]
    fn filter_keys_basic() -> DaftResult<()> {
        let left_scan = dummy_scan_node(dummy_scan_operator(vec![
            Field::new("a", DataType::Int64),
            Field::new("b", DataType::Utf8),
        ]));

        let right_scan = dummy_scan_node(dummy_scan_operator(vec![
            Field::new("c", DataType::Int64),
            Field::new("d", DataType::Utf8),
        ]));

        let plan = left_scan
            .join(
                right_scan.clone(),
                vec![col("a")],
                vec![col("c")],
                JoinType::Inner,
                None,
                None,
                None,
                false,
            )?
            .build();

        let expected = left_scan
            .filter(col("a").not_eq(null_lit()))?
            .join(
                right_scan.filter(col("c").not_eq(null_lit()))?,
                vec![col("a")],
                vec![col("c")],
                JoinType::Inner,
                None,
                None,
                None,
                false,
            )?
            .build();

        assert_optimized_plan_eq(plan, expected)?;

        Ok(())
    }

    /// Left join with per-key null-safe flags: only the right input is filtered, and only
    /// on the keys whose flag is `false` (`d` and `f`; `e` is null-safe and left alone).
    #[test]
    fn filter_keys_null_equals_nulls() -> DaftResult<()> {
        let left_scan = dummy_scan_node(dummy_scan_operator(vec![
            Field::new("a", DataType::Int64),
            Field::new("b", DataType::Utf8),
            Field::new("c", DataType::Boolean),
        ]));

        let right_scan = dummy_scan_node(dummy_scan_operator(vec![
            Field::new("d", DataType::Int64),
            Field::new("e", DataType::Utf8),
            Field::new("f", DataType::Boolean),
        ]));

        let plan = left_scan
            .join_with_null_safe_equal(
                right_scan.clone(),
                vec![col("a"), col("b"), col("c")],
                vec![col("d"), col("e"), col("f")],
                Some(vec![false, true, false]),
                JoinType::Left,
                None,
                None,
                None,
                false,
            )?
            .build();

        let expected_predicate = col("d").not_eq(null_lit()).and(col("f").not_eq(null_lit()));

        let expected = left_scan
            .join_with_null_safe_equal(
                right_scan.filter(expected_predicate)?,
                vec![col("a"), col("b"), col("c")],
                vec![col("d"), col("e"), col("f")],
                Some(vec![false, true, false]),
                JoinType::Left,
                None,
                None,
                None,
                false,
            )?
            .build();

        assert_optimized_plan_eq(plan, expected)?;

        Ok(())
    }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add an example of a query that would benefit from this optimization.