Skip to content

Commit

Permalink
fix(rust): Tighten up error checking on join keys (#17517)
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- authored Jul 15, 2024
1 parent d76609a commit e570e77
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 25 deletions.
11 changes: 11 additions & 0 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,9 @@ pub fn to_alp_impl(
}
if turn_off_coalesce {
let options = Arc::make_mut(&mut options);
if matches!(options.args.coalesce, JoinCoalesce::CoalesceColumns) {
polars_warn!("Coalescing join requested but not all join keys are column references, turning off key coalescing");
}
options.args.coalesce = JoinCoalesce::KeepColumns;
}

Expand Down Expand Up @@ -523,6 +526,14 @@ pub fn to_alp_impl(
convert.fill_scratch(&left_on, expr_arena);
convert.fill_scratch(&right_on, expr_arena);

// Every expression must be elementwise so that we are
// guaranteed the keys for a join are all the same length.
let all_elementwise =
|aexprs: &[ExprIR]| all_streamable(aexprs, &*expr_arena, Context::Default);
polars_ensure!(
all_elementwise(&left_on) && all_elementwise(&right_on),
InvalidOperation: "All join key expressions must be elementwise."
);
let lp = IR::Join {
input_left,
input_right,
Expand Down
28 changes: 16 additions & 12 deletions crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,27 @@ fn can_pushdown_slice_past_projections(exprs: &[ExprIR], arena: &Arena<AExpr>) -
for expr_ir in exprs.iter() {
// `select(c = Literal([1, 2, 3])).slice(0, 0)` must block slice pushdown,
// because `c` projects to a height independent from the input height. We check
// this by observing that `c` does not have any columns in its input notes.
// this by observing that `c` does not have any columns in its input nodes.
//
// TODO: Simply checking that a column node is present does not handle e.g.:
// `select(c = Literal([1, 2, 3]).is_in(col(a)))`, for functions like `is_in`,
// `str.contains`, `str.contains_many` etc. - observe a column node is present
// but the output height is not dependent on it.
let mut has_column = false;
let mut literals_all_scalar = true;
let is_elementwise = arena.iter(expr_ir.node()).all(|(_node, ae)| {
has_column |= matches!(ae, AExpr::Column(_));
literals_all_scalar &= if let AExpr::Literal(v) = ae {
v.projects_as_scalar()
} else {
true
};
single_aexpr_is_elementwise(ae)
});
let is_elementwise = is_streamable(expr_ir.node(), arena, Context::Default);
let (has_column, literals_all_scalar) = arena.iter(expr_ir.node()).fold(
(false, true),
|(has_column, lit_scalar), (_node, ae)| {
(
has_column | matches!(ae, AExpr::Column(_)),
lit_scalar
& if let AExpr::Literal(v) = ae {
v.projects_as_scalar()
} else {
true
},
)
},
);

// If there is no column then all literals must be scalar
if !is_elementwise || !(has_column || literals_all_scalar) {
Expand Down
13 changes: 0 additions & 13 deletions crates/polars-plan/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,6 @@ pub(crate) fn aexpr_is_simple_projection(current_node: Node, arena: &Arena<AExpr
.all(|(_node, e)| matches!(e, AExpr::Column(_) | AExpr::Alias(_, _)))
}

pub(crate) fn single_aexpr_is_elementwise(ae: &AExpr) -> bool {
use AExpr::*;
match ae {
AnonymousFunction { options, .. } | Function { options, .. } => {
!matches!(options.collect_groups, ApplyOptions::GroupWise)
},
Column(_) | Alias(_, _) | Literal(_) | BinaryExpr { .. } | Ternary { .. } | Cast { .. } => {
true
},
_ => false,
}
}

pub fn has_aexpr<F>(current_node: Node, arena: &Arena<AExpr>, matches: F) -> bool
where
F: Fn(&AExpr) -> bool,
Expand Down
44 changes: 44 additions & 0 deletions py-polars/tests/unit/operations/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,3 +994,47 @@ def test_join_empty_literal_17027() -> None:
.height
== 1
)


@pytest.mark.parametrize(
("left_on", "right_on"),
zip(
[pl.col("a"), pl.col("a").sort(), [pl.col("a"), pl.col("b")]],
[pl.col("a").slice(0, 2) * 2, pl.col("b"), [pl.col("a"), pl.col("b").head()]],
),
)
def test_join_non_elementwise_keys_raises(left_on: pl.Expr, right_on: pl.Expr) -> None:
# https://github.com/pola-rs/polars/issues/17184
left = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
right = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})

q = left.join(
right,
left_on=left_on,
right_on=right_on,
how="inner",
)

with pytest.raises(pl.exceptions.InvalidOperationError):
q.collect()


def test_join_coalesce_not_supported_warning() -> None:
# https://github.com/pola-rs/polars/issues/17184
left = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
right = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5]})

q = left.join(
right,
left_on=[pl.col("a") * 2],
right_on=[pl.col("a") * 2],
how="inner",
coalesce=True,
)
with pytest.warns(UserWarning, match="turning off key coalescing"):
got = q.collect()
expect = pl.DataFrame(
{"a": [1, 2, 3], "b": [3, 4, 5], "a_right": [1, 2, 3], "b_right": [3, 4, 5]}
)

assert_frame_equal(expect, got, check_row_order=False)

0 comments on commit e570e77

Please sign in to comment.