Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: filter null join key optimization rule #3583

Merged
merged 7 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/daft-logical-plan/src/optimization/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use common_treenode::Transformed;
use super::{
logical_plan_tracker::LogicalPlanTracker,
rules::{
DropRepartition, EliminateCrossJoin, EnrichWithStats, LiftProjectFromAgg, MaterializeScans,
OptimizerRule, PushDownFilter, PushDownLimit, PushDownProjection, SimplifyExpressionsRule,
SplitActorPoolProjects, UnnestPredicateSubquery, UnnestScalarSubquery,
DropRepartition, EliminateCrossJoin, EnrichWithStats, FilterNullJoinKey,
LiftProjectFromAgg, MaterializeScans, OptimizerRule, PushDownFilter, PushDownLimit,
PushDownProjection, SimplifyExpressionsRule, SplitActorPoolProjects,
UnnestPredicateSubquery, UnnestScalarSubquery,
},
};
use crate::LogicalPlan;
Expand Down Expand Up @@ -109,6 +110,7 @@ impl Optimizer {
RuleBatch::new(
vec![
Box::new(DropRepartition::new()),
Box::new(FilterNullJoinKey::new()),
Box::new(PushDownFilter::new()),
Box::new(PushDownProjection::new()),
Box::new(EliminateCrossJoin::new()),
Expand Down
233 changes: 233 additions & 0 deletions src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
use std::sync::Arc;

use common_error::DaftResult;
use common_treenode::{Transformed, TreeNode};
use daft_core::join::JoinType;
use daft_dsl::{null_lit, optimization::conjuct};

use super::OptimizerRule;
use crate::{
ops::{Filter, Join},
LogicalPlan,
};

/// Optimization rule for filtering out nulls from join keys.
///
/// Inserts a filter before each side of the join to remove rows where a join key is null when it is valid to do so.
/// This will reduce the cardinality of the tables before a join, which may improve join performance.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add an example of a query that would benefit from this optimization

#[derive(Default, Debug)]
pub struct FilterNullJoinKey {}

impl FilterNullJoinKey {
pub fn new() -> Self {
Self {}
}
}

impl OptimizerRule for FilterNullJoinKey {
fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
plan.transform(|node| {
if let LogicalPlan::Join(Join {
left,
right,
left_on,
right_on,
null_equals_nulls,
join_type,
..
}) = node.as_ref()
{
let mut null_equals_nulls_iter = null_equals_nulls.as_ref().map_or_else(
|| Box::new(std::iter::repeat(false)) as Box<dyn Iterator<Item = bool>>,
|x| Box::new(x.clone().into_iter()),
);

let (can_filter_left, can_filter_right) = match join_type {
JoinType::Inner => (true, true),
JoinType::Left => (false, true),
JoinType::Right => (true, false),
JoinType::Outer => (false, false),
JoinType::Anti => (false, true),
JoinType::Semi => (true, true),
};

let left_null_pred = if can_filter_left {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how are we actually determining if the join key is nullable? AFAIK, we don't have a concept of nullable in our fields/dtypes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like this should only push it down if the expr or colum is null, but we don't have a way to determine that. Maybe I'm missing something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This rule creates a filter that removes null rows. The join keys themselves do not have to be a null literal or type. So if the join keys are not nullable or do not have null values, this would essentially be a no-op, but if they had say a row where the value was null, it would be removed prior to the join.

conjuct(
null_equals_nulls_iter
.by_ref()
.zip(left_on)
.filter(|(null_eq_null, _)| !null_eq_null)
.map(|(_, left_key)| left_key.clone().not_eq(null_lit())),
)
} else {
None
};

let right_null_pred = if can_filter_right {
conjuct(
null_equals_nulls_iter
.by_ref()
.zip(right_on)
.filter(|(null_eq_null, _)| !null_eq_null)
.map(|(_, right_key)| right_key.clone().not_eq(null_lit())),
)
} else {
None
};

if left_null_pred.is_none() && right_null_pred.is_none() {
Ok(Transformed::no(node.clone()))
} else {
let new_left = if let Some(pred) = left_null_pred {
Arc::new(LogicalPlan::Filter(Filter::try_new(left.clone(), pred)?))
} else {
left.clone()
};

let new_right = if let Some(pred) = right_null_pred {
Arc::new(LogicalPlan::Filter(Filter::try_new(right.clone(), pred)?))
} else {
right.clone()
};

let new_join = Arc::new(node.with_new_children(&[new_left, new_right]));

Ok(Transformed::yes(new_join))
}
} else {
Ok(Transformed::no(node))
}
})
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use common_error::DaftResult;
use daft_core::prelude::*;
use daft_dsl::{col, null_lit};

use crate::{
optimization::{
optimizer::{RuleBatch, RuleExecutionStrategy},
rules::filter_null_join_key::FilterNullJoinKey,
test::assert_optimized_plan_with_rules_eq,
},
test::{dummy_scan_node, dummy_scan_operator},
LogicalPlan,
};

/// Helper that creates an optimizer with the FilterNullJoinKey rule registered, optimizes
/// the provided plan with said optimizer, and compares the optimized plan with
/// the provided expected plan.
fn assert_optimized_plan_eq(
plan: Arc<LogicalPlan>,
expected: Arc<LogicalPlan>,
) -> DaftResult<()> {
assert_optimized_plan_with_rules_eq(
plan,
expected,
vec![RuleBatch::new(
vec![Box::new(FilterNullJoinKey::new())],
RuleExecutionStrategy::Once,
)],
)
}

#[test]
fn filter_keys_basic() -> DaftResult<()> {
let left_scan = dummy_scan_node(dummy_scan_operator(vec![
Field::new("a", DataType::Int64),
Field::new("b", DataType::Utf8),
]));

let right_scan = dummy_scan_node(dummy_scan_operator(vec![
Field::new("c", DataType::Int64),
Field::new("d", DataType::Utf8),
]));

let plan = left_scan
.join(
right_scan.clone(),
vec![col("a")],
vec![col("c")],
JoinType::Inner,
None,
None,
None,
false,
)?
.build();

let expected = left_scan
.filter(col("a").not_eq(null_lit()))?
.clone()
.join(
right_scan.filter(col("c").not_eq(null_lit()))?,
vec![col("a")],
vec![col("c")],
JoinType::Inner,
None,
None,
None,
false,
)?
.build();

assert_optimized_plan_eq(plan, expected)?;

Ok(())
}

#[test]
fn filter_keys_null_equals_nulls() -> DaftResult<()> {
let left_scan = dummy_scan_node(dummy_scan_operator(vec![
Field::new("a", DataType::Int64),
Field::new("b", DataType::Utf8),
Field::new("c", DataType::Boolean),
]));

let right_scan = dummy_scan_node(dummy_scan_operator(vec![
Field::new("d", DataType::Int64),
Field::new("e", DataType::Utf8),
Field::new("f", DataType::Boolean),
]));

let plan = left_scan
.join_with_null_safe_equal(
right_scan.clone(),
vec![col("a"), col("b"), col("c")],
vec![col("d"), col("e"), col("f")],
Some(vec![false, true, false]),
JoinType::Left,
None,
None,
None,
false,
)?
.build();

let expected_predicate = col("d").not_eq(null_lit()).and(col("f").not_eq(null_lit()));

let expected = left_scan
.clone()
.join_with_null_safe_equal(
right_scan.filter(expected_predicate)?,
vec![col("a"), col("b"), col("c")],
vec![col("d"), col("e"), col("f")],
Some(vec![false, true, false]),
JoinType::Left,
None,
None,
None,
false,
)?
.build();

assert_optimized_plan_eq(plan, expected)?;

Ok(())
}
}
2 changes: 2 additions & 0 deletions src/daft-logical-plan/src/optimization/rules/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod drop_repartition;
mod eliminate_cross_join;
mod enrich_with_stats;
mod filter_null_join_key;
mod lift_project_from_agg;
mod materialize_scans;
mod push_down_filter;
Expand All @@ -15,6 +16,7 @@ mod unnest_subquery;
pub use drop_repartition::DropRepartition;
pub use eliminate_cross_join::EliminateCrossJoin;
pub use enrich_with_stats::EnrichWithStats;
pub use filter_null_join_key::FilterNullJoinKey;
pub use lift_project_from_agg::LiftProjectFromAgg;
pub use materialize_scans::MaterializeScans;
pub use push_down_filter::PushDownFilter;
Expand Down
Loading