Skip to content

Commit

Permalink
wrap-up
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 7, 2024
1 parent 773e61d commit 3e2121d
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 20 deletions.
78 changes: 60 additions & 18 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ fn resolve_join_where(

let owned = |e: Arc<Expr>| (*e).clone();

// Partition to:
// - IEjoin supported inequality predicates
// - equality predicates
// - remaining predicates
let mut ie_left_on = vec![];
let mut ie_right_on = vec![];
let mut ie_op = vec![];
Expand Down Expand Up @@ -174,7 +178,34 @@ fn resolve_join_where(
}
}

// Now choose a primary join and do the remaining predicates as filters
/// Combine two expressions into the binary expression `l <op> r`.
fn to_binary(l: Expr, op: Operator, r: Expr) -> Expr {
    let left = Arc::new(l);
    let right = Arc::new(r);
    Expr::BinaryExpr { left, op, right }
}
// Move every IEjoin candidate predicate into the remaining-predicates buffer, rebuilt as
// `left <op> right`, so that it will be executed in the filter node.
fn ie_predicates_to_remaining(
    remaining_preds: &mut Vec<Expr>,
    ie_left_on: Vec<Expr>,
    ie_right_on: Vec<Expr>,
    ie_op: Vec<InequalityOperator>,
) {
    let rebuilt = ie_left_on
        .into_iter()
        .zip(ie_op)
        .zip(ie_right_on)
        .map(|((lhs, cmp), rhs)| to_binary(lhs, cmp.into(), rhs));
    remaining_preds.extend(rebuilt);
}

let join_node = if !eq_left_on.is_empty() {
// We found one or more equality predicates. Go into a default equi join
// as those are cheapest on avg.
let join_node = resolve_join(
input_left,
input_right,
Expand All @@ -185,47 +216,58 @@ fn resolve_join_where(
ctxt,
)?;

for ((l, op), r) in ie_left_on
.into_iter()
.zip(ie_op.into_iter())
.zip(ie_right_on.into_iter())
{
remaining_preds.push(Expr::BinaryExpr {
left: Arc::from(l),
op: op.into(),
right: Arc::from(r),
})
}
ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op);
join_node
} else if ie_right_on.len() == 2 {
}
// TODO! once we support single IEjoin predicates, we must add a branch for the single ie_pred case.
else if ie_right_on.len() >= 2 {
// Do an IEjoin.
let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::IEJoin(IEJoinOptions {
operator1: ie_op[0],
operator2: ie_op[1],
});

resolve_join(
let join_node = resolve_join(
input_left,
input_right,
ie_left_on,
ie_right_on,
ie_left_on[..2].to_vec(),
ie_right_on[..2].to_vec(),
vec![],
options.clone(),
ctxt,
)?
)?;

// The surplus ie-predicates will be added to the remaining predicates so that
// they will be applied in a filter node.
while ie_right_on.len() > 2 {
    // Invariant: they all have equal length, so we can pop and unwrap all while len > 2.
    // The first 2 predicates are used in the IEjoin itself and must stay in place.
    //
    // Keep the operand order intact: `l` comes from the left keys and `r` from the
    // right keys. Swapping them would build `right <op> left` and flip the inequality.
    let l = ie_left_on.pop().unwrap();
    let r = ie_right_on.pop().unwrap();
    let op = ie_op.pop().unwrap();

    remaining_preds.push(to_binary(l, op.into(), r))
}
join_node
} else {
// No predicates found that are supported in a fast algorithm.
// Do a cross join and follow up with filters.
let opts = Arc::make_mut(&mut options);
opts.args.how = JoinType::Cross;

resolve_join(
let join_node = resolve_join(
input_left,
input_right,
vec![],
vec![],
vec![],
options.clone(),
ctxt,
)?
)?;
// TODO: This can be removed once we support the single IEjoin.
ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op);
join_node
};

let IR::Join {
Expand Down
22 changes: 21 additions & 1 deletion py-polars/tests/unit/operations/test_inequality_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_ie_join_with_expressions() -> None:
assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)


def test_join_between() -> None:
def test_join_where_predicates() -> None:
left = pl.DataFrame(
{
"id": [0, 1, 2, 3, 4, 5],
Expand Down Expand Up @@ -274,6 +274,26 @@ def test_join_between() -> None:
)
assert_frame_equal(actual, expected, check_exact=True)

q = (
left.lazy()
.join_where(
right.lazy(),
pl.col("group") != pl.col("group"),
)
.select("id", "group", "group_right")
.sort("id")
.select("group", "group_right")
)

explained = q.explain()
assert "CROSS" in explained
assert "FILTER" in explained
actual = q.collect()
assert actual.to_dict(as_series=False) == {
"group": [0, 0, 0, 0, 0, 0, 1, 1, 1],
"group_right": [1, 1, 1, 1, 1, 1, 0, 0, 0],
}


def _inequality_expression(col1: str, op: str, col2: str) -> pl.Expr:
if op == "<":
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def test_arr_eval_named_cols() -> None:
def test_alias_in_join_keys() -> None:
df = pl.DataFrame({"A": ["a", "b"], "B": [["a", "b"], ["c", "d"]]})
with pytest.raises(
ComputeError,
InvalidOperationError,
match=r"'alias' is not allowed in a join key, use 'with_columns' first",
):
df.join(df, on=pl.col("A").alias("foo"))
Expand Down

0 comments on commit 3e2121d

Please sign in to comment.