Skip to content

Commit

Permalink
feat: Support use of is_between range predicate with IEJoin operations (`join_where`) (#19547)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Oct 31, 2024
1 parent bda606e commit 519ccb3
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 39 deletions.
40 changes: 40 additions & 0 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ use polars_expr::{create_physical_expr, ExpressionConversionState};
use polars_io::RowIndex;
use polars_mem_engine::{create_physical_plan, Executor};
use polars_ops::frame::JoinCoalesce;
#[cfg(feature = "is_between")]
use polars_ops::prelude::ClosedInterval;
pub use polars_plan::frame::{AllowedOptimizations, OptFlags};
use polars_plan::global::FETCH_ROWS;
use polars_utils::pl_str::PlSmallStr;
Expand Down Expand Up @@ -2160,6 +2162,44 @@ impl JoinBuilder {
opt_state |= OptFlags::FILE_CACHING;
}

// Rewrite every three-input `is_between` predicate as its two underlying bound
// comparisons (lower bound first, then upper bound), so a range join can be
// written as a single expression. All other predicates are forwarded unchanged
// and in their original order.
#[cfg(feature = "is_between")]
let predicates: Vec<Expr> = {
    // Worst case: every predicate is an `is_between` and yields two comparisons.
    let mut decomposed = Vec::with_capacity(predicates.len() * 2);
    for pred in predicates {
        // `Some((lower_cmp, upper_cmp))` when `pred` is a decomposable `is_between`.
        let bounds = match &pred {
            Expr::Function {
                function: FunctionExpr::Boolean(BooleanFunction::IsBetween { closed }),
                input,
                ..
            } => match input.as_slice() {
                [expr, lower, upper] => {
                    let (lo, hi) = (lower.clone(), upper.clone());
                    // The closed side(s) of the interval decide strict vs. inclusive.
                    Some(match closed {
                        ClosedInterval::Both => (expr.clone().gt_eq(lo), expr.clone().lt_eq(hi)),
                        ClosedInterval::Right => (expr.clone().gt(lo), expr.clone().lt_eq(hi)),
                        ClosedInterval::Left => (expr.clone().gt_eq(lo), expr.clone().lt(hi)),
                        ClosedInterval::None => (expr.clone().gt(lo), expr.clone().lt(hi)),
                    })
                },
                // Unexpected input arity: leave the predicate untouched.
                _ => None,
            },
            _ => None,
        };
        match bounds {
            Some((lower_cmp, upper_cmp)) => decomposed.extend([lower_cmp, upper_cmp]),
            None => decomposed.push(pred),
        }
    }
    decomposed
};

let args = JoinArgs {
how: self.how,
validation: self.validation,
Expand Down
27 changes: 15 additions & 12 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,18 @@ fn resolve_join_where(
.get(input_right)
.schema(ctxt.lp_arena)
.into_owned();
for e in &predicates {
let no_binary_comparisons = e

for expr in &predicates {
let mut comparison_count = 0;
for _e in expr
.into_iter()
.filter(|e| match e {
Expr::BinaryExpr { op, .. } => op.is_comparison(),
_ => false,
})
.count();
polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition");
.filter(|e| matches!(e, Expr::BinaryExpr { op, .. } if op.is_comparison()))
{
comparison_count += 1;
if comparison_count > 1 {
polars_bail!(InvalidOperation: "only one binary comparison allowed in each 'join_where' predicate, found: {:?}", expr);
}
}

fn all_in_schema(
schema: &Schema,
Expand All @@ -186,14 +189,14 @@ fn resolve_join_where(
})
}

let valid = e.into_iter().all(|e| match e {
let valid = expr.into_iter().all(|e| match e {
Expr::BinaryExpr { left, op, right } if op.is_comparison() => {
!(all_in_schema(&schema_left, None, left, right)
|| all_in_schema(&schema_right, Some(&schema_left), left, right))
},
_ => true,
});
polars_ensure!( valid, InvalidOperation: "join predicate in 'join_where' only refers to columns of a single table")
polars_ensure!( valid, InvalidOperation: "'join_where' predicate only refers to columns from a single table")
}

let owned = |e: Arc<Expr>| (*e).clone();
Expand Down Expand Up @@ -266,7 +269,7 @@ fn resolve_join_where(
(left_names, right_names, left, op, right)
};
for name in &left_names {
polars_ensure!(!right_names.contains(name.as_str()), InvalidOperation: "got ambiguous column names in 'join_where'\n\n\
polars_ensure!(!right_names.contains(name.as_str()), InvalidOperation: "found ambiguous column names in 'join_where'\n\n\
Note that you should refer to the column names as they are post-join operation.")
}

Expand Down Expand Up @@ -309,7 +312,7 @@ fn resolve_join_where(
let suffix = options.args.suffix().clone();
for pred in predicates.into_iter() {
let Expr::BinaryExpr { left, op, right } = pred.clone() else {
polars_bail!(InvalidOperation: "can only join on binary expressions")
polars_bail!(InvalidOperation: "can only join on binary (in)equality expressions, found {:?}", pred)
};
polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate");
let (left, op, right) = determine_order_and_pre_join_names(
Expand Down
64 changes: 37 additions & 27 deletions py-polars/tests/unit/operations/test_inequality_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,21 @@ def test_ie_join_with_expressions() -> None:
assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)


def test_join_where_predicates() -> None:
@pytest.mark.parametrize(
"range_constraint",
[
[
# can write individual components
pl.col("time") >= pl.col("start_time"),
pl.col("time") < pl.col("end_time"),
],
[
# or a single `is_between` expression
pl.col("time").is_between("start_time", "end_time", closed="left")
],
],
)
def test_join_where_predicates(range_constraint: list[pl.Expr]) -> None:
left = pl.DataFrame(
{
"id": [0, 1, 2, 3, 4, 5],
Expand Down Expand Up @@ -209,11 +223,7 @@ def test_join_where_predicates() -> None:
}
)

actual = left.join_where(
right,
pl.col("time") >= pl.col("start_time"),
pl.col("time") < pl.col("end_time"),
).select("id", "id_right")
actual = left.join_where(right, *range_constraint).select("id", "id_right")

expected = pl.DataFrame(
{
Expand All @@ -227,9 +237,8 @@ def test_join_where_predicates() -> None:
left.lazy()
.join_where(
right.lazy(),
pl.col("time") >= pl.col("start_time"),
pl.col("time") < pl.col("end_time"),
pl.col("group_right") == pl.col("group"),
*range_constraint,
)
.select("id", "id_right", "group")
.sort("id")
Expand All @@ -242,11 +251,7 @@ def test_join_where_predicates() -> None:

expected = (
left.join(right, how="cross")
.filter(
pl.col("time") >= pl.col("start_time"),
pl.col("time") < pl.col("end_time"),
pl.col("group") == pl.col("group_right"),
)
.filter(pl.col("group") == pl.col("group_right"), *range_constraint)
.select("id", "id_right", "group")
.sort("id")
)
Expand All @@ -255,10 +260,7 @@ def test_join_where_predicates() -> None:
q = (
left.lazy()
.join_where(
right.lazy(),
pl.col("time") >= pl.col("start_time"),
pl.col("time") < pl.col("end_time"),
pl.col("group") != pl.col("group_right"),
right.lazy(), pl.col("group") != pl.col("group_right"), *range_constraint
)
.select("id", "id_right", "group")
.sort("id")
Expand All @@ -271,11 +273,7 @@ def test_join_where_predicates() -> None:

expected = (
left.join(right, how="cross")
.filter(
pl.col("time") >= pl.col("start_time"),
pl.col("time") < pl.col("end_time"),
pl.col("group") != pl.col("group_right"),
)
.filter(pl.col("group") != pl.col("group_right"), *range_constraint)
.select("id", "id_right", "group")
.sort("id")
)
Expand Down Expand Up @@ -451,21 +449,30 @@ def test_ie_join_with_floats(

def test_raise_on_ambiguous_name() -> None:
df = pl.DataFrame({"id": [1, 2]})
with pytest.raises(pl.exceptions.InvalidOperationError):
with pytest.raises(
pl.exceptions.InvalidOperationError,
match="'join_where' predicate only refers to columns from a single table",
):
df.join_where(df, pl.col("id") >= pl.col("id"))


def test_raise_on_multiple_binary_comparisons() -> None:
df = pl.DataFrame({"id": [1, 2]})
with pytest.raises(pl.exceptions.InvalidOperationError):
with pytest.raises(
pl.exceptions.InvalidOperationError,
match="only one binary comparison allowed in each 'join_where' predicate, found: ",
):
df.join_where(
df, (pl.col("id") < pl.col("id")) & (pl.col("id") >= pl.col("id"))
df, (pl.col("id") < pl.col("id")) ^ (pl.col("id") >= pl.col("id"))
)


def test_raise_invalid_input_join_where() -> None:
df = pl.DataFrame({"id": [1, 2]})
with pytest.raises(pl.exceptions.InvalidOperationError):
with pytest.raises(
pl.exceptions.InvalidOperationError,
match="expected join keys/predicates",
):
df.join_where(df)


Expand Down Expand Up @@ -573,7 +580,10 @@ def test_raise_invalid_predicate() -> None:
left = pl.LazyFrame({"a": [1, 2]}).with_row_index()
right = pl.LazyFrame({"b": [1, 2]}).with_row_index()

with pytest.raises(pl.exceptions.InvalidOperationError):
with pytest.raises(
pl.exceptions.InvalidOperationError,
match="'join_where' predicate only refers to columns from a single table",
):
left.join_where(right, pl.col.index >= pl.col.a).collect()


Expand Down

0 comments on commit 519ccb3

Please sign in to comment.