From 519ccb3935faf5d782d4bee2e746f5c2b81fbb80 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 31 Oct 2024 17:54:22 +0400 Subject: [PATCH] feat: Support use of `is_between` range predicate with IEJoin operations (`join_where`) (#19547) --- crates/polars-lazy/src/frame/mod.rs | 40 ++++++++++++ .../polars-plan/src/plans/conversion/join.rs | 27 ++++---- .../unit/operations/test_inequality_join.py | 64 +++++++++++-------- 3 files changed, 92 insertions(+), 39 deletions(-) diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index cf90c5232450..59c70cc78932 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -35,6 +35,8 @@ use polars_expr::{create_physical_expr, ExpressionConversionState}; use polars_io::RowIndex; use polars_mem_engine::{create_physical_plan, Executor}; use polars_ops::frame::JoinCoalesce; +#[cfg(feature = "is_between")] +use polars_ops::prelude::ClosedInterval; pub use polars_plan::frame::{AllowedOptimizations, OptFlags}; use polars_plan::global::FETCH_ROWS; use polars_utils::pl_str::PlSmallStr; @@ -2160,6 +2162,44 @@ impl JoinBuilder { opt_state |= OptFlags::FILE_CACHING; } + // Decompose `is_between` predicates to allow for cleaner expression of range joins + #[cfg(feature = "is_between")] + let predicates: Vec = { + let mut expanded_predicates = Vec::with_capacity(predicates.len() * 2); + for predicate in predicates { + if let Expr::Function { + function: FunctionExpr::Boolean(BooleanFunction::IsBetween { closed }), + input, + .. + } = &predicate + { + if let [expr, lower, upper] = input.as_slice() { + match closed { + ClosedInterval::Both => { + expanded_predicates.push(expr.clone().gt_eq(lower.clone())); + expanded_predicates.push(expr.clone().lt_eq(upper.clone())); + }, + ClosedInterval::Right => { + expanded_predicates.push(expr.clone().gt(lower.clone())); + expanded_predicates.push(expr.clone().lt_eq(upper.clone())); + }, + ClosedInterval::Left => { + expanded_predicates.push(expr.clone().gt_eq(lower.clone())); + expanded_predicates.push(expr.clone().lt(upper.clone())); + }, + ClosedInterval::None => { + expanded_predicates.push(expr.clone().gt(lower.clone())); + expanded_predicates.push(expr.clone().lt(upper.clone())); + }, + } + continue; + } + } + expanded_predicates.push(predicate); + } + expanded_predicates + }; + let args = JoinArgs { how: self.how, validation: self.validation, diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 9d63d18b0a46..7684062de23f 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -163,15 +163,18 @@ fn resolve_join_where( .get(input_right) .schema(ctxt.lp_arena) .into_owned(); - for e in &predicates { - let no_binary_comparisons = e + + for expr in &predicates { + let mut comparison_count = 0; + for _e in expr .into_iter() - .filter(|e| match e { - Expr::BinaryExpr { op, .. } => op.is_comparison(), - _ => false, - }) - .count(); - polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition"); + .filter(|e| matches!(e, Expr::BinaryExpr { op, .. } if op.is_comparison())) + { + comparison_count += 1; + if comparison_count > 1 { + polars_bail!(InvalidOperation: "only one binary comparison allowed in each 'join_where' predicate, found: {:?}", expr); + } + } fn all_in_schema( schema: &Schema, @@ -186,14 +189,14 @@ fn resolve_join_where( }) } - let valid = e.into_iter().all(|e| match e { + let valid = expr.into_iter().all(|e| match e { Expr::BinaryExpr { left, op, right } if op.is_comparison() => { !(all_in_schema(&schema_left, None, left, right) || all_in_schema(&schema_right, Some(&schema_left), left, right)) }, _ => true, }); - polars_ensure!( valid, InvalidOperation: "join predicate in 'join_where' only refers to columns of a single table") + polars_ensure!( valid, InvalidOperation: "'join_where' predicate only refers to columns from a single table") } let owned = |e: Arc| (*e).clone(); @@ -266,7 +269,7 @@ fn resolve_join_where( (left_names, right_names, left, op, right) }; for name in &left_names { - polars_ensure!(!right_names.contains(name.as_str()), InvalidOperation: "got ambiguous column names in 'join_where'\n\n\ + polars_ensure!(!right_names.contains(name.as_str()), InvalidOperation: "found ambiguous column names in 'join_where'\n\n\ Note that you should refer to the column names as they are post-join operation.") } @@ -309,7 +312,7 @@ fn resolve_join_where( let suffix = options.args.suffix().clone(); for pred in predicates.into_iter() { let Expr::BinaryExpr { left, op, right } = pred.clone() else { - polars_bail!(InvalidOperation: "can only join on binary expressions") + polars_bail!(InvalidOperation: "can only join on binary (in)equality expressions, found {:?}", pred) }; polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate"); let (left, op, right) = determine_order_and_pre_join_names( diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 891ac32fa0ba..fd9fca28f72d 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -177,7 +177,21 @@ def test_ie_join_with_expressions() -> None: assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) -def test_join_where_predicates() -> None: +@pytest.mark.parametrize( + "range_constraint", + [ + [ + # can write individual components + pl.col("time") >= pl.col("start_time"), + pl.col("time") < pl.col("end_time"), + ], + [ + # or a single `is_between` expression + pl.col("time").is_between("start_time", "end_time", closed="left") + ], + ], +) +def test_join_where_predicates(range_constraint: list[pl.Expr]) -> None: left = pl.DataFrame( { "id": [0, 1, 2, 3, 4, 5], @@ -209,11 +223,7 @@ def test_join_where_predicates() -> None: } ) - actual = left.join_where( - right, - pl.col("time") >= pl.col("start_time"), - pl.col("time") < pl.col("end_time"), - ).select("id", "id_right") + actual = left.join_where(right, *range_constraint).select("id", "id_right") expected = pl.DataFrame( { @@ -227,9 +237,8 @@ def test_join_where_predicates() -> None: left.lazy() .join_where( right.lazy(), - pl.col("time") >= pl.col("start_time"), - pl.col("time") < pl.col("end_time"), pl.col("group_right") == pl.col("group"), + *range_constraint, ) .select("id", "id_right", "group") .sort("id") @@ -242,11 +251,7 @@ def test_join_where_predicates() -> None: expected = ( left.join(right, how="cross") - .filter( - pl.col("time") >= pl.col("start_time"), - pl.col("time") < pl.col("end_time"), - pl.col("group") == pl.col("group_right"), - ) + .filter(pl.col("group") == pl.col("group_right"), *range_constraint) .select("id", "id_right", "group") .sort("id") ) @@ -255,10 +260,7 @@ def test_join_where_predicates() -> None: q = ( left.lazy() .join_where( - right.lazy(), - pl.col("time") >= pl.col("start_time"), - pl.col("time") < pl.col("end_time"), - pl.col("group") != pl.col("group_right"), + right.lazy(), pl.col("group") != pl.col("group_right"), *range_constraint ) .select("id", "id_right", "group") .sort("id") @@ -271,11 +273,7 @@ def test_join_where_predicates() -> None: expected = ( left.join(right, how="cross") - .filter( - pl.col("time") >= pl.col("start_time"), - pl.col("time") < pl.col("end_time"), - pl.col("group") != pl.col("group_right"), - ) + .filter(pl.col("group") != pl.col("group_right"), *range_constraint) .select("id", "id_right", "group") .sort("id") ) @@ -451,21 +449,30 @@ def test_ie_join_with_floats( def test_raise_on_ambiguous_name() -> None: df = pl.DataFrame({"id": [1, 2]}) - with pytest.raises(pl.exceptions.InvalidOperationError): + with pytest.raises( + pl.exceptions.InvalidOperationError, + match="'join_where' predicate only refers to columns from a single table", + ): df.join_where(df, pl.col("id") >= pl.col("id")) def test_raise_on_multiple_binary_comparisons() -> None: df = pl.DataFrame({"id": [1, 2]}) - with pytest.raises(pl.exceptions.InvalidOperationError): + with pytest.raises( + pl.exceptions.InvalidOperationError, + match="only one binary comparison allowed in each 'join_where' predicate, found: ", + ): df.join_where( - df, (pl.col("id") < pl.col("id")) & (pl.col("id") >= pl.col("id")) + df, (pl.col("id") < pl.col("id")) ^ (pl.col("id") >= pl.col("id")) ) def test_raise_invalid_input_join_where() -> None: df = pl.DataFrame({"id": [1, 2]}) - with pytest.raises(pl.exceptions.InvalidOperationError): + with pytest.raises( + pl.exceptions.InvalidOperationError, + match="expected join keys/predicates", + ): df.join_where(df) @@ -573,7 +580,10 @@ def test_raise_invalid_predicate() -> None: left = pl.LazyFrame({"a": [1, 2]}).with_row_index() right = pl.LazyFrame({"b": [1, 2]}).with_row_index() - with pytest.raises(pl.exceptions.InvalidOperationError): + with pytest.raises( + pl.exceptions.InvalidOperationError, + match="'join_where' predicate only refers to columns from a single table", + ): left.join_where(right, pl.col.index >= pl.col.a).collect()