From c2b8acdb8760c35d96faaaa26dc11a3516c28650 Mon Sep 17 00:00:00 2001 From: ritchie Date: Thu, 2 Jan 2025 12:17:48 +0100 Subject: [PATCH] feat: Support arbitrary expressions in 'join_where' --- .../polars-plan/src/plans/conversion/join.rs | 336 +----------------- .../unit/operations/test_inequality_join.py | 22 +- 2 files changed, 28 insertions(+), 330 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index dd94ed2d7784..ac7e6e1a9ae4 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -389,6 +389,11 @@ fn resolve_join_where( mut options: Arc, ctxt: &mut DslConversionContext, ) -> PolarsResult<(Node, Node)> { + // If not eager, respect the flag. + if ctxt.opt_flags.eager() { + ctxt.opt_flags.set(OptFlags::PREDICATE_PUSHDOWN, true); + } + ctxt.opt_flags.set(OptFlags::COLLAPSE_JOINS, true); check_join_keys(&predicates)?; let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt) .map_err(|e| e.context(failed_here!(join left)))?; @@ -403,17 +408,6 @@ fn resolve_join_where( .into_owned(); for expr in &predicates { - let mut comparison_count = 0; - for _e in expr - .into_iter() - .filter(|e| matches!(e, Expr::BinaryExpr { op, .. } if op.is_comparison())) - { - comparison_count += 1; - if comparison_count > 1 { - polars_bail!(InvalidOperation: "only one binary comparison allowed in each 'join_where' predicate; found {:?}", expr); - } - } - fn all_in_schema( schema: &Schema, other: Option<&Schema>, @@ -437,317 +431,21 @@ fn resolve_join_where( polars_ensure!( valid, InvalidOperation: "'join_where' predicate only refers to columns from a single table") } - let owned = |e: Arc| (*e).clone(); - - // We do a few things - // First we partition to: - // - IEjoin supported inequality predicates - // - equality predicates - // - remaining predicates - // And then decide to which join we dispatch. 
- // The remaining predicates will be applied as filter. - - // What make things a bit complicated is that duplicate join names - // are referred to in the query with the name post-join, but on joins - // we refer to the names pre-join (e.g. without suffix). So there is some - // bookkeeping. - // - // - First we determine which side of the binary expression refers to the left and right table - // and make sure that lhs of the binary expr, maps to the lhs of the join tables and vice versa. - // Next we ensure the suffixes are removed when we partition. - // - // If a predicate has to be applied as post-join filter, we put the suffixes back if needed. - let mut ie_left_on = vec![]; - let mut ie_right_on = vec![]; - let mut ie_op = vec![]; - - let mut eq_left_on = vec![]; - let mut eq_right_on = vec![]; - - let mut remaining_preds = vec![]; - - fn to_inequality_operator(op: &Operator) -> Option { - match op { - Operator::Lt => Some(InequalityOperator::Lt), - Operator::LtEq => Some(InequalityOperator::LtEq), - Operator::Gt => Some(InequalityOperator::Gt), - Operator::GtEq => Some(InequalityOperator::GtEq), - _ => None, - } - } - - fn rename_expr(e: Expr, old: &str, new: &str) -> Expr { - e.map_expr(|e| match e { - Expr::Column(name) if name.as_str() == old => Expr::Column(new.into()), - e => e, - }) - } - - fn determine_order_and_pre_join_names( - left: Expr, - op: Operator, - right: Expr, - schema_left: &Schema, - schema_right: &Schema, - suffix: &str, - ) -> PolarsResult<(Expr, Operator, Expr)> { - let left_names = expr_to_leaf_column_names_iter(&left).collect::>(); - let right_names = expr_to_leaf_column_names_iter(&right).collect::>(); - - // All left should be in the left schema. 
- let (left_names, right_names, left, op, mut right) = - if !left_names.iter().all(|n| schema_left.contains(n)) { - // If all right names are in left schema -> swap - if right_names.iter().all(|n| schema_left.contains(n)) { - (right_names, left_names, right, op.swap_operands(), left) - } else { - polars_bail!(InvalidOperation: "got ambiguous column names in 'join_where'") - } - } else { - (left_names, right_names, left, op, right) - }; - for name in &left_names { - polars_ensure!(!right_names.contains(name.as_str()), InvalidOperation: "found ambiguous column names in 'join_where'\n\n\ - Note that you should refer to the column names as they are post-join operation.") - } - - // Now we know left belongs to the left schema, rhs suffixes are dealt with. - for post_join_name in right_names { - if let Some(pre_join_name) = post_join_name.strip_suffix(suffix) { - // Name is both sides, so a suffix will be added by the join. - // We rename - if schema_right.contains(pre_join_name) && schema_left.contains(pre_join_name) { - right = rename_expr(right, &post_join_name, pre_join_name); - } - } - } - Ok((left, op, right)) - } - - // Make it a binary comparison and ensure the columns refer to post join names. 
- fn to_binary_post_join( - l: Expr, - op: Operator, - mut r: Expr, - schema_right: &Schema, - suffix: &str, - ) -> Expr { - let names = expr_to_leaf_column_names_iter(&r).collect::>(); - for pre_join_name in &names { - if !schema_right.contains(pre_join_name) { - let post_join_name = _join_suffix_name(pre_join_name, suffix); - r = rename_expr(r, pre_join_name, post_join_name.as_str()); - } - } - - Expr::BinaryExpr { - left: Arc::from(l), - op, - right: Arc::from(r), - } - } - - let suffix = options.args.suffix().clone(); - for pred in predicates.into_iter() { - let Expr::BinaryExpr { left, op, right } = pred.clone() else { - polars_bail!(InvalidOperation: "can only join on binary (in)equality expressions, found {:?}", pred) - }; - polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate"); - let (left, op, right) = determine_order_and_pre_join_names( - owned(left), - op, - owned(right), - &schema_left, - &schema_right, - &suffix, - )?; - - if let Some(ie_op_) = to_inequality_operator(&op) { - fn is_numeric(e: &Expr, schema: &Schema) -> bool { - expr_to_leaf_column_names_iter(e).any(|name| { - if let Some(dt) = schema.get(name.as_str()) { - dt.to_physical().is_numeric() - } else { - false - } - }) - } - - // We fallback to remaining if: - // - we already have an IEjoin or Inner join - // - we already have an Inner join - // - data is not numeric (our iejoin doesn't yet implement that) - if ie_op.len() >= 2 - || !eq_right_on.is_empty() - || !is_numeric(&left, &schema_left) - || !is_numeric(&right, &schema_right) - { - remaining_preds.push(to_binary_post_join(left, op, right, &schema_right, &suffix)) - } else { - ie_left_on.push(left); - ie_right_on.push(right); - ie_op.push(ie_op_) - } - } else if matches!(op, Operator::Eq) { - eq_left_on.push(left); - eq_right_on.push(right); - } else { - remaining_preds.push(to_binary_post_join(left, op, right, &schema_right, &suffix)); - } - } - - // Now choose a primary join and do the 
remaining predicates as filters - // Add the ie predicates to the remaining predicates buffer so that they will be executed in the - // filter node. - fn ie_predicates_to_remaining( - remaining_preds: &mut Vec, - ie_left_on: Vec, - ie_right_on: Vec, - ie_op: Vec, - schema_right: &Schema, - suffix: &str, - ) { - for ((l, op), r) in ie_left_on - .into_iter() - .zip(ie_op.into_iter()) - .zip(ie_right_on.into_iter()) - { - remaining_preds.push(to_binary_post_join(l, op.into(), r, schema_right, suffix)) - } - } - - let (mut last_node, join_node) = if !eq_left_on.is_empty() { - // We found one or more equality predicates. Go into a default equi join - // as those are cheapest on avg. - let (last_node, join_node) = resolve_join( - Either::Right(input_left), - Either::Right(input_right), - eq_left_on, - eq_right_on, - vec![], - options.clone(), - ctxt, - )?; - - ie_predicates_to_remaining( - &mut remaining_preds, - ie_left_on, - ie_right_on, - ie_op, - &schema_right, - &suffix, - ); - (last_node, join_node) - } else if ie_right_on.len() >= 2 { - // Do an IEjoin. - let opts = Arc::make_mut(&mut options); - - opts.args.how = JoinType::IEJoin; - opts.options = Some(JoinTypeOptionsIR::IEJoin(IEJoinOptions { - operator1: ie_op[0], - operator2: Some(ie_op[1]), - })); - - let (last_node, join_node) = resolve_join( - Either::Right(input_left), - Either::Right(input_right), - ie_left_on[..2].to_vec(), - ie_right_on[..2].to_vec(), - vec![], - options.clone(), - ctxt, - )?; - - // The surplus ie-predicates will be added to the remaining predicates so that - // they will be applied in a filter node. - while ie_right_on.len() > 2 { - // Invariant: they all have equal length, so we can pop and unwrap all while len > 2. 
- // The first 2 predicates are used in the - let l = ie_right_on.pop().unwrap(); - let r = ie_left_on.pop().unwrap(); - let op = ie_op.pop().unwrap(); - - remaining_preds.push(to_binary_post_join(l, op.into(), r, &schema_right, &suffix)) - } - (last_node, join_node) - } else if ie_right_on.len() == 1 { - // For a single inequality comparison, we use the piecewise merge join algorithm - let opts = Arc::make_mut(&mut options); - opts.args.how = JoinType::IEJoin; - opts.options = Some(JoinTypeOptionsIR::IEJoin(IEJoinOptions { - operator1: ie_op[0], - operator2: None, - })); - - resolve_join( - Either::Right(input_left), - Either::Right(input_right), - ie_left_on, - ie_right_on, - vec![], - options.clone(), - ctxt, - )? - } else { - // No predicates found that are supported in a fast algorithm. - // Do a cross join and follow up with filters. - let opts = Arc::make_mut(&mut options); - opts.args.how = JoinType::Cross; - - resolve_join( - Either::Right(input_left), - Either::Right(input_right), - vec![], - vec![], - vec![], - options.clone(), - ctxt, - )? - }; - - let IR::Join { - input_left, - input_right, - .. - } = ctxt.lp_arena.get(join_node) - else { - unreachable!() - }; - let schema_right = ctxt - .lp_arena - .get(*input_right) - .schema(ctxt.lp_arena) - .into_owned(); + let opts = Arc::make_mut(&mut options); + opts.args.how = JoinType::Cross; - let schema_left = ctxt - .lp_arena - .get(*input_left) - .schema(ctxt.lp_arena) - .into_owned(); + let (mut last_node, join_node) = resolve_join( + Either::Right(input_left), + Either::Right(input_right), + vec![], + vec![], + vec![], + options.clone(), + ctxt, + )?; - // Ensure that the predicates use the proper suffix - for e in remaining_preds { + for e in predicates { let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?; - let AExpr::BinaryExpr { mut right, .. 
} = *ctxt.expr_arena.get(predicate.node()) else { - unreachable!() - }; - - let original_right = right; - - for name in aexpr_to_leaf_names(right, ctxt.expr_arena) { - polars_ensure!(schema_right.contains(name.as_str()), ColumnNotFound: "could not find column {name} in the right table during join operation"); - if schema_left.contains(name.as_str()) { - let new_name = _join_suffix_name(name.as_str(), suffix.as_str()); - - right = rename_matching_aexpr_leaf_names( - right, - ctxt.expr_arena, - name.as_str(), - new_name, - ); - } - } - ctxt.expr_arena.swap(right, original_right); let ir = IR::Filter { input: last_node, diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 848a4b2b7f85..2495d5b84f2d 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -461,17 +461,6 @@ def test_raise_on_ambiguous_name() -> None: df.join_where(df, pl.col("id") >= pl.col("id")) -def test_raise_on_multiple_binary_comparisons() -> None: - df = pl.DataFrame({"id": [1, 2]}) - with pytest.raises( - pl.exceptions.InvalidOperationError, - match="only one binary comparison allowed in each 'join_where' predicate; found ", - ): - df.join_where( - df, (pl.col("id") < pl.col("id")) ^ (pl.col("id") >= pl.col("id")) - ) - - def test_raise_invalid_input_join_where() -> None: df = pl.DataFrame({"id": [1, 2]}) with pytest.raises( @@ -681,3 +670,14 @@ def test_join_where_literal_20061() -> None: "value_right": [5, 5, 5, 25], "flag_right": [1, 1, 1, 1], } + + +def test_boolean_predicate_join_where() -> None: + urls = pl.LazyFrame({"url": "abcd.com/page"}) + categories = pl.LazyFrame({"base_url": "abcd.com", "category": "landing page"}) + assert ( + "NESTED LOOP JOIN" + in urls.join_where( + categories, pl.col("url").str.starts_with(pl.col("base_url")) + ).explain() + )