Skip to content

Commit

Permalink
feat: Support arbitrary expressions in 'join_where' (#20525)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 2, 2025
1 parent 91d04b8 commit 9d7a7d3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 330 deletions.
336 changes: 17 additions & 319 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,11 @@ fn resolve_join_where(
mut options: Arc<JoinOptions>,
ctxt: &mut DslConversionContext,
) -> PolarsResult<(Node, Node)> {
// If not eager, respect the flag.
if ctxt.opt_flags.eager() {
ctxt.opt_flags.set(OptFlags::PREDICATE_PUSHDOWN, true);
}
ctxt.opt_flags.set(OptFlags::COLLAPSE_JOINS, true);
check_join_keys(&predicates)?;
let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt)
.map_err(|e| e.context(failed_here!(join left)))?;
Expand All @@ -403,17 +408,6 @@ fn resolve_join_where(
.into_owned();

for expr in &predicates {
let mut comparison_count = 0;
for _e in expr
.into_iter()
.filter(|e| matches!(e, Expr::BinaryExpr { op, .. } if op.is_comparison()))
{
comparison_count += 1;
if comparison_count > 1 {
polars_bail!(InvalidOperation: "only one binary comparison allowed in each 'join_where' predicate; found {:?}", expr);
}
}

fn all_in_schema(
schema: &Schema,
other: Option<&Schema>,
Expand All @@ -437,317 +431,21 @@ fn resolve_join_where(
polars_ensure!( valid, InvalidOperation: "'join_where' predicate only refers to columns from a single table")
}

let owned = |e: Arc<Expr>| (*e).clone();

// We do a few things
// First we partition to:
// - IEjoin supported inequality predicates
// - equality predicates
// - remaining predicates
// And then decide to which join we dispatch.
// The remaining predicates will be applied as filter.

// What make things a bit complicated is that duplicate join names
// are referred to in the query with the name post-join, but on joins
// we refer to the names pre-join (e.g. without suffix). So there is some
// bookkeeping.
//
// - First we determine which side of the binary expression refers to the left and right table
// and make sure that lhs of the binary expr, maps to the lhs of the join tables and vice versa.
// Next we ensure the suffixes are removed when we partition.
//
// If a predicate has to be applied as post-join filter, we put the suffixes back if needed.
let mut ie_left_on = vec![];
let mut ie_right_on = vec![];
let mut ie_op = vec![];

let mut eq_left_on = vec![];
let mut eq_right_on = vec![];

let mut remaining_preds = vec![];

fn to_inequality_operator(op: &Operator) -> Option<InequalityOperator> {
match op {
Operator::Lt => Some(InequalityOperator::Lt),
Operator::LtEq => Some(InequalityOperator::LtEq),
Operator::Gt => Some(InequalityOperator::Gt),
Operator::GtEq => Some(InequalityOperator::GtEq),
_ => None,
}
}

fn rename_expr(e: Expr, old: &str, new: &str) -> Expr {
e.map_expr(|e| match e {
Expr::Column(name) if name.as_str() == old => Expr::Column(new.into()),
e => e,
})
}

fn determine_order_and_pre_join_names(
left: Expr,
op: Operator,
right: Expr,
schema_left: &Schema,
schema_right: &Schema,
suffix: &str,
) -> PolarsResult<(Expr, Operator, Expr)> {
let left_names = expr_to_leaf_column_names_iter(&left).collect::<PlHashSet<_>>();
let right_names = expr_to_leaf_column_names_iter(&right).collect::<PlHashSet<_>>();

// All left should be in the left schema.
let (left_names, right_names, left, op, mut right) =
if !left_names.iter().all(|n| schema_left.contains(n)) {
// If all right names are in left schema -> swap
if right_names.iter().all(|n| schema_left.contains(n)) {
(right_names, left_names, right, op.swap_operands(), left)
} else {
polars_bail!(InvalidOperation: "got ambiguous column names in 'join_where'")
}
} else {
(left_names, right_names, left, op, right)
};
for name in &left_names {
polars_ensure!(!right_names.contains(name.as_str()), InvalidOperation: "found ambiguous column names in 'join_where'\n\n\
Note that you should refer to the column names as they are post-join operation.")
}

// Now we know left belongs to the left schema, rhs suffixes are dealt with.
for post_join_name in right_names {
if let Some(pre_join_name) = post_join_name.strip_suffix(suffix) {
// Name is both sides, so a suffix will be added by the join.
// We rename
if schema_right.contains(pre_join_name) && schema_left.contains(pre_join_name) {
right = rename_expr(right, &post_join_name, pre_join_name);
}
}
}
Ok((left, op, right))
}

// Make it a binary comparison and ensure the columns refer to post join names.
fn to_binary_post_join(
l: Expr,
op: Operator,
mut r: Expr,
schema_right: &Schema,
suffix: &str,
) -> Expr {
let names = expr_to_leaf_column_names_iter(&r).collect::<Vec<_>>();
for pre_join_name in &names {
if !schema_right.contains(pre_join_name) {
let post_join_name = _join_suffix_name(pre_join_name, suffix);
r = rename_expr(r, pre_join_name, post_join_name.as_str());
}
}

Expr::BinaryExpr {
left: Arc::from(l),
op,
right: Arc::from(r),
}
}

let suffix = options.args.suffix().clone();
for pred in predicates.into_iter() {
let Expr::BinaryExpr { left, op, right } = pred.clone() else {
polars_bail!(InvalidOperation: "can only join on binary (in)equality expressions, found {:?}", pred)
};
polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate");
let (left, op, right) = determine_order_and_pre_join_names(
owned(left),
op,
owned(right),
&schema_left,
&schema_right,
&suffix,
)?;

if let Some(ie_op_) = to_inequality_operator(&op) {
fn is_numeric(e: &Expr, schema: &Schema) -> bool {
expr_to_leaf_column_names_iter(e).any(|name| {
if let Some(dt) = schema.get(name.as_str()) {
dt.to_physical().is_numeric()
} else {
false
}
})
}

// We fallback to remaining if:
// - we already have an IEjoin or Inner join
// - we already have an Inner join
// - data is not numeric (our iejoin doesn't yet implement that)
if ie_op.len() >= 2
|| !eq_right_on.is_empty()
|| !is_numeric(&left, &schema_left)
|| !is_numeric(&right, &schema_right)
{
remaining_preds.push(to_binary_post_join(left, op, right, &schema_right, &suffix))
} else {
ie_left_on.push(left);
ie_right_on.push(right);
ie_op.push(ie_op_)
}
} else if matches!(op, Operator::Eq) {
eq_left_on.push(left);
eq_right_on.push(right);
} else {
remaining_preds.push(to_binary_post_join(left, op, right, &schema_right, &suffix));
}
}

// Now choose a primary join and do the remaining predicates as filters
// Add the ie predicates to the remaining predicates buffer so that they will be executed in the
// filter node.
fn ie_predicates_to_remaining(
remaining_preds: &mut Vec<Expr>,
ie_left_on: Vec<Expr>,
ie_right_on: Vec<Expr>,
ie_op: Vec<InequalityOperator>,
schema_right: &Schema,
suffix: &str,
) {
for ((l, op), r) in ie_left_on
.into_iter()
.zip(ie_op.into_iter())
.zip(ie_right_on.into_iter())
{
remaining_preds.push(to_binary_post_join(l, op.into(), r, schema_right, suffix))
}
}

let (mut last_node, join_node) = if !eq_left_on.is_empty() {
// We found one or more equality predicates. Go into a default equi join
// as those are cheapest on avg.
let (last_node, join_node) = resolve_join(
Either::Right(input_left),
Either::Right(input_right),
eq_left_on,
eq_right_on,
vec![],
options.clone(),
ctxt,
)?;

ie_predicates_to_remaining(
&mut remaining_preds,
ie_left_on,
ie_right_on,
ie_op,
&schema_right,
&suffix,
);
(last_node, join_node)
} else if ie_right_on.len() >= 2 {
// Do an IEjoin.
let opts = Arc::make_mut(&mut options);

opts.args.how = JoinType::IEJoin;
opts.options = Some(JoinTypeOptionsIR::IEJoin(IEJoinOptions {
operator1: ie_op[0],
operator2: Some(ie_op[1]),
}));

let (last_node, join_node) = resolve_join(
Either::Right(input_left),
Either::Right(input_right),
ie_left_on[..2].to_vec(),
ie_right_on[..2].to_vec(),
vec![],
options.clone(),
ctxt,
)?;

// The surplus ie-predicates will be added to the remaining predicates so that
// they will be applied in a filter node.
while ie_right_on.len() > 2 {
// Invariant: they all have equal length, so we can pop and unwrap all while len > 2.
// The first 2 predicates are used in the
let l = ie_right_on.pop().unwrap();
let r = ie_left_on.pop().unwrap();
let op = ie_op.pop().unwrap();

remaining_preds.push(to_binary_post_join(l, op.into(), r, &schema_right, &suffix))
}
(last_node, join_node)
} else if ie_right_on.len() == 1 {
// For a single inequality comparison, we use the piecewise merge join algorithm
let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::IEJoin;
opts.options = Some(JoinTypeOptionsIR::IEJoin(IEJoinOptions {
operator1: ie_op[0],
operator2: None,
}));

resolve_join(
Either::Right(input_left),
Either::Right(input_right),
ie_left_on,
ie_right_on,
vec![],
options.clone(),
ctxt,
)?
} else {
// No predicates found that are supported in a fast algorithm.
// Do a cross join and follow up with filters.
let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::Cross;

resolve_join(
Either::Right(input_left),
Either::Right(input_right),
vec![],
vec![],
vec![],
options.clone(),
ctxt,
)?
};

let IR::Join {
input_left,
input_right,
..
} = ctxt.lp_arena.get(join_node)
else {
unreachable!()
};
let schema_right = ctxt
.lp_arena
.get(*input_right)
.schema(ctxt.lp_arena)
.into_owned();
let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::Cross;

let schema_left = ctxt
.lp_arena
.get(*input_left)
.schema(ctxt.lp_arena)
.into_owned();
let (mut last_node, join_node) = resolve_join(
Either::Right(input_left),
Either::Right(input_right),
vec![],
vec![],
vec![],
options.clone(),
ctxt,
)?;

// Ensure that the predicates use the proper suffix
for e in remaining_preds {
for e in predicates {
let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?;
let AExpr::BinaryExpr { mut right, .. } = *ctxt.expr_arena.get(predicate.node()) else {
unreachable!()
};

let original_right = right;

for name in aexpr_to_leaf_names(right, ctxt.expr_arena) {
polars_ensure!(schema_right.contains(name.as_str()), ColumnNotFound: "could not find column {name} in the right table during join operation");
if schema_left.contains(name.as_str()) {
let new_name = _join_suffix_name(name.as_str(), suffix.as_str());

right = rename_matching_aexpr_leaf_names(
right,
ctxt.expr_arena,
name.as_str(),
new_name,
);
}
}
ctxt.expr_arena.swap(right, original_right);

let ir = IR::Filter {
input: last_node,
Expand Down
22 changes: 11 additions & 11 deletions py-polars/tests/unit/operations/test_inequality_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,17 +461,6 @@ def test_raise_on_ambiguous_name() -> None:
df.join_where(df, pl.col("id") >= pl.col("id"))


def test_raise_on_multiple_binary_comparisons() -> None:
df = pl.DataFrame({"id": [1, 2]})
with pytest.raises(
pl.exceptions.InvalidOperationError,
match="only one binary comparison allowed in each 'join_where' predicate; found ",
):
df.join_where(
df, (pl.col("id") < pl.col("id")) ^ (pl.col("id") >= pl.col("id"))
)


def test_raise_invalid_input_join_where() -> None:
df = pl.DataFrame({"id": [1, 2]})
with pytest.raises(
Expand Down Expand Up @@ -681,3 +670,14 @@ def test_join_where_literal_20061() -> None:
"value_right": [5, 5, 5, 25],
"flag_right": [1, 1, 1, 1],
}


def test_boolean_predicate_join_where() -> None:
urls = pl.LazyFrame({"url": "abcd.com/page"})
categories = pl.LazyFrame({"base_url": "abcd.com", "category": "landing page"})
assert (
"NESTED LOOP JOIN"
in urls.join_where(
categories, pl.col("url").str.starts_with(pl.col("base_url"))
).explain()
)

0 comments on commit 9d7a7d3

Please sign in to comment.