From f105e222934c8e5bb2bce2de2db823ae81f38076 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Mon, 16 Dec 2024 13:05:04 -0800 Subject: [PATCH 1/6] perf: filter null join key optimization rule --- .../src/optimization/optimizer.rs | 8 +- .../rules/filter_null_join_key.rs | 233 ++++++++++++++++++ .../src/optimization/rules/mod.rs | 2 + 3 files changed, 240 insertions(+), 3 deletions(-) create mode 100644 src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs diff --git a/src/daft-logical-plan/src/optimization/optimizer.rs b/src/daft-logical-plan/src/optimization/optimizer.rs index 76f6251438..64ed0a65c9 100644 --- a/src/daft-logical-plan/src/optimization/optimizer.rs +++ b/src/daft-logical-plan/src/optimization/optimizer.rs @@ -6,9 +6,10 @@ use common_treenode::Transformed; use super::{ logical_plan_tracker::LogicalPlanTracker, rules::{ - DropRepartition, EliminateCrossJoin, EnrichWithStats, LiftProjectFromAgg, MaterializeScans, - OptimizerRule, PushDownFilter, PushDownLimit, PushDownProjection, SimplifyExpressionsRule, - SplitActorPoolProjects, UnnestPredicateSubquery, UnnestScalarSubquery, + DropRepartition, EliminateCrossJoin, EnrichWithStats, FilterNullJoinKey, + LiftProjectFromAgg, MaterializeScans, OptimizerRule, PushDownFilter, PushDownLimit, + PushDownProjection, SimplifyExpressionsRule, SplitActorPoolProjects, + UnnestPredicateSubquery, UnnestScalarSubquery, }, }; use crate::LogicalPlan; @@ -109,6 +110,7 @@ impl Optimizer { RuleBatch::new( vec![ Box::new(DropRepartition::new()), + Box::new(FilterNullJoinKey::new()), Box::new(PushDownFilter::new()), Box::new(PushDownProjection::new()), Box::new(EliminateCrossJoin::new()), diff --git a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs new file mode 100644 index 0000000000..37755f57f0 --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs @@ -0,0 +1,233 @@ +use 
std::sync::Arc; + +use common_error::DaftResult; +use common_treenode::{Transformed, TreeNode}; +use daft_core::join::JoinType; +use daft_dsl::{null_lit, optimization::conjuct, Expr}; + +use super::OptimizerRule; +use crate::{ + ops::{Filter, Join}, + LogicalPlan, +}; + +/// Optimization rule for filtering out nulls from join keys. +/// +/// Inserts a filter before each side of the join to remove rows where a join key is null when it is valid to do so. +/// This will reduce the cardinality of the tables before a join, which may improve join performance. +#[derive(Default, Debug)] +pub struct FilterNullJoinKey {} + +impl FilterNullJoinKey { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for FilterNullJoinKey { + fn try_optimize(&self, plan: Arc) -> DaftResult>> { + plan.transform(|node| { + if let LogicalPlan::Join(Join { + left, + right, + left_on, + right_on, + null_equals_nulls, + join_type, + .. + }) = node.as_ref() + { + let mut null_equals_nulls_iter = null_equals_nulls.as_ref().map_or_else( + || Box::new(std::iter::repeat(false)) as Box>, + |x| Box::new(x.clone().into_iter()), + ); + + let (can_filter_left, can_filter_right) = match join_type { + JoinType::Inner => (true, true), + JoinType::Left => (false, true), + JoinType::Right => (true, false), + JoinType::Outer => (false, false), + JoinType::Anti => (false, true), + JoinType::Semi => (true, true), + }; + + let left_null_pred = if can_filter_left { + conjuct( + null_equals_nulls_iter + .by_ref() + .zip(left_on) + .filter(|(null_eq_null, _)| !null_eq_null) + .map(|(_, left_key)| Expr::eq(left_key.clone(), null_lit())), + ) + } else { + None + }; + + let right_null_pred = if can_filter_right { + conjuct( + null_equals_nulls_iter + .by_ref() + .zip(right_on) + .filter(|(null_eq_null, _)| !null_eq_null) + .map(|(_, right_key)| Expr::eq(right_key.clone(), null_lit())), + ) + } else { + None + }; + + if left_null_pred.is_none() && right_null_pred.is_none() { + Ok(Transformed::no(node.clone())) 
+ } else { + let new_left = if let Some(pred) = left_null_pred { + Arc::new(LogicalPlan::Filter(Filter::try_new(left.clone(), pred)?)) + } else { + left.clone() + }; + + let new_right = if let Some(pred) = right_null_pred { + Arc::new(LogicalPlan::Filter(Filter::try_new(right.clone(), pred)?)) + } else { + right.clone() + }; + + let new_join = Arc::new(node.with_new_children(&[new_left, new_right])); + + Ok(Transformed::yes(new_join)) + } + } else { + Ok(Transformed::no(node)) + } + }) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use common_error::DaftResult; + use daft_core::prelude::*; + use daft_dsl::{col, null_lit}; + + use crate::{ + optimization::{ + optimizer::{RuleBatch, RuleExecutionStrategy}, + rules::filter_null_join_key::FilterNullJoinKey, + test::assert_optimized_plan_with_rules_eq, + }, + test::{dummy_scan_node, dummy_scan_operator}, + LogicalPlan, + }; + + /// Helper that creates an optimizer with the FilterNullJoinKey rule registered, optimizes + /// the provided plan with said optimizer, and compares the optimized plan with + /// the provided expected plan. + fn assert_optimized_plan_eq( + plan: Arc, + expected: Arc, + ) -> DaftResult<()> { + assert_optimized_plan_with_rules_eq( + plan, + expected, + vec![RuleBatch::new( + vec![Box::new(FilterNullJoinKey::new())], + RuleExecutionStrategy::Once, + )], + ) + } + + #[test] + fn filter_keys_basic() -> DaftResult<()> { + let left_scan = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Utf8), + ])); + + let right_scan = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("c", DataType::Int64), + Field::new("d", DataType::Utf8), + ])); + + let plan = left_scan + .join( + right_scan.clone(), + vec![col("a")], + vec![col("c")], + JoinType::Inner, + None, + None, + None, + false, + )? + .build(); + + let expected = left_scan + .filter(col("a").eq(null_lit()))? 
+ .clone() + .join( + right_scan.filter(col("c").eq(null_lit()))?, + vec![col("a")], + vec![col("c")], + JoinType::Inner, + None, + None, + None, + false, + )? + .build(); + + assert_optimized_plan_eq(plan, expected)?; + + Ok(()) + } + + #[test] + fn filter_keys_null_equals_nulls() -> DaftResult<()> { + let left_scan = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Utf8), + Field::new("c", DataType::Boolean), + ])); + + let right_scan = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("d", DataType::Int64), + Field::new("e", DataType::Utf8), + Field::new("f", DataType::Boolean), + ])); + + let plan = left_scan + .join_with_null_safe_equal( + right_scan.clone(), + vec![col("a"), col("b"), col("c")], + vec![col("d"), col("e"), col("f")], + Some(vec![false, true, false]), + JoinType::Left, + None, + None, + None, + false, + )? + .build(); + + let expected_predicate = col("d").eq(null_lit()).and(col("f").eq(null_lit())); + + let expected = left_scan + .clone() + .join_with_null_safe_equal( + right_scan.filter(expected_predicate)?, + vec![col("a"), col("b"), col("c")], + vec![col("d"), col("e"), col("f")], + Some(vec![false, true, false]), + JoinType::Left, + None, + None, + None, + false, + )? 
+ .build(); + + assert_optimized_plan_eq(plan, expected)?; + + Ok(()) + } +} diff --git a/src/daft-logical-plan/src/optimization/rules/mod.rs b/src/daft-logical-plan/src/optimization/rules/mod.rs index f540a77cb0..787c0e7ac6 100644 --- a/src/daft-logical-plan/src/optimization/rules/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/mod.rs @@ -1,6 +1,7 @@ mod drop_repartition; mod eliminate_cross_join; mod enrich_with_stats; +mod filter_null_join_key; mod lift_project_from_agg; mod materialize_scans; mod push_down_filter; @@ -15,6 +16,7 @@ mod unnest_subquery; pub use drop_repartition::DropRepartition; pub use eliminate_cross_join::EliminateCrossJoin; pub use enrich_with_stats::EnrichWithStats; +pub use filter_null_join_key::FilterNullJoinKey; pub use lift_project_from_agg::LiftProjectFromAgg; pub use materialize_scans::MaterializeScans; pub use push_down_filter::PushDownFilter; From 1ef77bb029ad7847b67524f891a4e8099e00c106 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Mon, 16 Dec 2024 13:36:04 -0800 Subject: [PATCH 2/6] not equal --- .../src/optimization/rules/filter_null_join_key.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs index 37755f57f0..a4b454a1fc 100644 --- a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs +++ b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use common_error::DaftResult; use common_treenode::{Transformed, TreeNode}; use daft_core::join::JoinType; -use daft_dsl::{null_lit, optimization::conjuct, Expr}; +use daft_dsl::{null_lit, optimization::conjuct}; use super::OptimizerRule; use crate::{ @@ -57,7 +57,7 @@ impl OptimizerRule for FilterNullJoinKey { .by_ref() .zip(left_on) .filter(|(null_eq_null, _)| !null_eq_null) - .map(|(_, left_key)| Expr::eq(left_key.clone(), null_lit())), + 
.map(|(_, left_key)| left_key.clone().not_eq(null_lit())), ) } else { None @@ -69,7 +69,7 @@ impl OptimizerRule for FilterNullJoinKey { .by_ref() .zip(right_on) .filter(|(null_eq_null, _)| !null_eq_null) - .map(|(_, right_key)| Expr::eq(right_key.clone(), null_lit())), + .map(|(_, right_key)| right_key.clone().not_eq(null_lit())), ) } else { None @@ -162,10 +162,10 @@ mod tests { .build(); let expected = left_scan - .filter(col("a").eq(null_lit()))? + .filter(col("a").not_eq(null_lit()))? .clone() .join( - right_scan.filter(col("c").eq(null_lit()))?, + right_scan.filter(col("c").not_eq(null_lit()))?, vec![col("a")], vec![col("c")], JoinType::Inner, @@ -209,7 +209,7 @@ mod tests { )? .build(); - let expected_predicate = col("d").eq(null_lit()).and(col("f").eq(null_lit())); + let expected_predicate = col("d").not_eq(null_lit()).and(col("f").not_eq(null_lit())); let expected = left_scan .clone() From aa44aadfd1f4ed2aa278b988f8fc3ceb5fd51a8f Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Mon, 16 Dec 2024 13:51:04 -0800 Subject: [PATCH 3/6] is_null instead of == null --- .../src/optimization/rules/filter_null_join_key.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs index a4b454a1fc..7298fcc561 100644 --- a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs +++ b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use common_error::DaftResult; use common_treenode::{Transformed, TreeNode}; use daft_core::join::JoinType; -use daft_dsl::{null_lit, optimization::conjuct}; +use daft_dsl::optimization::conjuct; use super::OptimizerRule; use crate::{ @@ -57,7 +57,7 @@ impl OptimizerRule for FilterNullJoinKey { .by_ref() .zip(left_on) .filter(|(null_eq_null, _)| !null_eq_null) - .map(|(_, left_key)| left_key.clone().not_eq(null_lit())), 
+ .map(|(_, left_key)| left_key.clone().is_null().not()), ) } else { None @@ -69,7 +69,7 @@ impl OptimizerRule for FilterNullJoinKey { .by_ref() .zip(right_on) .filter(|(null_eq_null, _)| !null_eq_null) - .map(|(_, right_key)| right_key.clone().not_eq(null_lit())), + .map(|(_, right_key)| right_key.clone().is_null().not()), ) } else { None @@ -162,10 +162,10 @@ mod tests { .build(); let expected = left_scan - .filter(col("a").not_eq(null_lit()))? + .filter(col("a").is_null().not())? .clone() .join( - right_scan.filter(col("c").not_eq(null_lit()))?, + right_scan.filter(col("c").is_null().not())?, vec![col("a")], vec![col("c")], JoinType::Inner, @@ -209,7 +209,7 @@ mod tests { )? .build(); - let expected_predicate = col("d").not_eq(null_lit()).and(col("f").not_eq(null_lit())); + let expected_predicate = col("d").is_null().not().and(col("f").is_null().not()); let expected = left_scan .clone() From 7cc2ec75a64ace9be5a170f860a4a7ca28d64f13 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Mon, 16 Dec 2024 14:31:05 -0800 Subject: [PATCH 4/6] add docs --- .../rules/filter_null_join_key.rs | 71 ++++++++++++++++++- 1 file changed, 69 insertions(+), 2 deletions(-) diff --git a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs index 7298fcc561..acd32cd397 100644 --- a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs +++ b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs @@ -13,8 +13,75 @@ use crate::{ /// Optimization rule for filtering out nulls from join keys. /// -/// Inserts a filter before each side of the join to remove rows where a join key is null when it is valid to do so. -/// This will reduce the cardinality of the tables before a join, which may improve join performance. +/// When a join will always discard null keys from a join side, +/// this rule inserts a filter before that side to remove rows where a join key is null. 
+/// This reduces the cardinality of the tables before a join to improve join performance, +/// and can also be pushed down with other rules to reduce source and intermediate output sizes. +/// +/// # Example +/// ```sql +/// SELECT * FROM left JOIN right ON left.x = right.y +/// ``` +/// turns into +/// ```sql +/// SELECT * +/// FROM (SELECT * FROM left WHERE x IS NOT NULL) AS non_null_left +/// JOIN (SELECT * FROM right WHERE y IS NOT NULL) AS non_null_right +/// ON non_null_left.x = non_null_right.y +/// ``` +/// +/// So if `left` was: +/// ``` +/// ╭───────╮ +/// │ x │ +/// │ --- │ +/// │ Int64 │ +/// ╞═══════╡ +/// │ 1 │ +/// ├╌╌╌╌╌╌╌┤ +/// │ 2 │ +/// ├╌╌╌╌╌╌╌┤ +/// │ None │ +/// ╰───────╯ +/// ``` +/// And `right` was: +/// ``` +/// ╭───────╮ +/// │ y │ +/// │ --- │ +/// │ Int64 │ +/// ╞═══════╡ +/// │ 1 │ +/// ├╌╌╌╌╌╌╌┤ +/// │ None │ +/// ├╌╌╌╌╌╌╌┤ +/// │ None │ +/// ╰───────╯ +/// ``` +/// the original query would join on all rows, whereas the new query would first filter out null rows and join on the following: +/// +/// `non_null_left`: +/// ``` +/// ╭───────╮ +/// │ x │ +/// │ --- │ +/// │ Int64 │ +/// ╞═══════╡ +/// │ 1 │ +/// ├╌╌╌╌╌╌╌┤ +/// │ 2 │ +/// ╰───────╯ +/// ``` +/// `non_null_right`: +/// ``` +/// ╭───────╮ +/// │ y │ +/// │ --- │ +/// │ Int64 │ +/// ╞═══════╡ +/// │ 1 │ +/// ╰───────╯ +/// ``` #[derive(Default, Debug)] pub struct FilterNullJoinKey {} From a5e9e4a7ddf916dc86c2555fc915b93ee74e282d Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Mon, 16 Dec 2024 14:31:54 -0800 Subject: [PATCH 5/6] remove import --- .../src/optimization/rules/filter_null_join_key.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs index acd32cd397..e93db867e5 100644 --- a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs +++
b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs @@ -174,7 +174,7 @@ mod tests { use common_error::DaftResult; use daft_core::prelude::*; - use daft_dsl::{col, null_lit}; + use daft_dsl::col; use crate::{ optimization::{ From 3a0b1a9e51164c1373ff02d00ed6dc7331f272e6 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Tue, 17 Dec 2024 13:22:03 -0800 Subject: [PATCH 6/6] merge fix --- .../src/optimization/rules/filter_null_join_key.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs index e93db867e5..80aab76c18 100644 --- a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs +++ b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs @@ -2,8 +2,8 @@ use std::sync::Arc; use common_error::DaftResult; use common_treenode::{Transformed, TreeNode}; +use daft_algebra::boolean::combine_conjunction; use daft_core::join::JoinType; -use daft_dsl::optimization::conjuct; use super::OptimizerRule; use crate::{ @@ -119,7 +119,7 @@ impl OptimizerRule for FilterNullJoinKey { }; let left_null_pred = if can_filter_left { - conjuct( + combine_conjunction( null_equals_nulls_iter .by_ref() .zip(left_on) @@ -131,7 +131,7 @@ impl OptimizerRule for FilterNullJoinKey { }; let right_null_pred = if can_filter_right { - conjuct( + combine_conjunction( null_equals_nulls_iter .by_ref() .zip(right_on)