diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 0433b2b14e69..9e33f2079d11 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -132,6 +132,10 @@ fn resolve_join_where( let owned = |e: Arc| (*e).clone(); + // Partition to: + // - IEjoin supported inequality predicates + // - equality predicates + // - remaining predicates let mut ie_left_on = vec![]; let mut ie_right_on = vec![]; let mut ie_op = vec![]; @@ -174,7 +178,34 @@ fn resolve_join_where( } } + // Now choose a primary join and do the remaining predicates as filters + fn to_binary(l: Expr, op: Operator, r: Expr) -> Expr { + Expr::BinaryExpr { + left: Arc::from(l), + op, + right: Arc::from(r), + } + } + // Add the ie predicates to the remaining predicates buffer so that they will be executed in the + // filter node. + fn ie_predicates_to_remaining( + remaining_preds: &mut Vec, + ie_left_on: Vec, + ie_right_on: Vec, + ie_op: Vec, + ) { + for ((l, op), r) in ie_left_on + .into_iter() + .zip(ie_op.into_iter()) + .zip(ie_right_on.into_iter()) + { + remaining_preds.push(to_binary(l, op.into(), r)) + } + } + let join_node = if !eq_left_on.is_empty() { + // We found one or more equality predicates. Go into a default equi join + // as those are cheapest on avg. let join_node = resolve_join( input_left, input_right, @@ -185,39 +216,47 @@ fn resolve_join_where( ctxt, )?; - for ((l, op), r) in ie_left_on - .into_iter() - .zip(ie_op.into_iter()) - .zip(ie_right_on.into_iter()) - { - remaining_preds.push(Expr::BinaryExpr { - left: Arc::from(l), - op: op.into(), - right: Arc::from(r), - }) - } + ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op); join_node - } else if ie_right_on.len() == 2 { + } + // TODO! once we support single IEjoin predicates, we must add a branch for the singe ie_pred case. + else if ie_right_on.len() >= 2 { + // Do an IEjoin. let opts = Arc::make_mut(&mut options); opts.args.how = JoinType::IEJoin(IEJoinOptions { operator1: ie_op[0], operator2: ie_op[1], }); - resolve_join( + let join_node = resolve_join( input_left, input_right, - ie_left_on, - ie_right_on, + ie_left_on[..2].to_vec(), + ie_right_on[..2].to_vec(), vec![], options.clone(), ctxt, - )? + )?; + + // The surplus ie-predicates will be added to the remaining predicates so that + // they will be applied in a filter node. + while ie_right_on.len() > 2 { + // Invariant: they all have equal length, so we can pop and unwrap all while len > 2. + // The first 2 predicates are used in the + let l = ie_right_on.pop().unwrap(); + let r = ie_left_on.pop().unwrap(); + let op = ie_op.pop().unwrap(); + + remaining_preds.push(to_binary(l, op.into(), r)) + } + join_node } else { + // No predicates found that are supported in a fast algorithm. + // Do a cross join and follow up with filters. let opts = Arc::make_mut(&mut options); opts.args.how = JoinType::Cross; - resolve_join( + let join_node = resolve_join( input_left, input_right, vec![], @@ -225,7 +264,10 @@ fn resolve_join_where( vec![], options.clone(), ctxt, - )? + )?; + // TODO: This can be removed once we support the single IEjoin. + ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op); + join_node }; let IR::Join { diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 7a0108eeb6db..428afa4f5c01 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -170,7 +170,7 @@ def test_ie_join_with_expressions() -> None: assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) -def test_join_between() -> None: +def test_join_where_predicates() -> None: left = pl.DataFrame( { "id": [0, 1, 2, 3, 4, 5], @@ -274,6 +274,26 @@ def test_join_between() -> None: ) assert_frame_equal(actual, expected, check_exact=True) + q = ( + left.lazy() + .join_where( + right.lazy(), + pl.col("group") != pl.col("group"), + ) + .select("id", "group", "group_right") + .sort("id") + .select("group", "group_right") + ) + + explained = q.explain() + assert "CROSS" in explained + assert "FILTER" in explained + actual = q.collect() + assert actual.to_dict(as_series=False) == { + "group": [0, 0, 0, 0, 0, 0, 1, 1, 1], + "group_right": [1, 1, 1, 1, 1, 1, 0, 0, 0], + } + def _inequality_expression(col1: str, op: str, col2: str) -> pl.Expr: if op == "<": diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index 2087387b1a8a..07b98d9d8111 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -349,7 +349,7 @@ def test_arr_eval_named_cols() -> None: def test_alias_in_join_keys() -> None: df = pl.DataFrame({"A": ["a", "b"], "B": [["a", "b"], ["c", "d"]]}) with pytest.raises( - ComputeError, + InvalidOperationError, match=r"'alias' is not allowed in a join key, use 'with_columns' first", ): df.join(df, on=pl.col("A").alias("foo"))