diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index 95dee57901e..e2a686c37bb 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -263,4 +263,8 @@ impl AExpr { pub(crate) fn is_leaf(&self) -> bool { matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len) } + + pub(crate) fn is_col(&self) -> bool { + matches!(self, AExpr::Column(_)) + } } diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 9c973d446ea..d177119dc82 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -148,55 +148,60 @@ pub fn resolve_join( }; } + // If we do a full join and keys are coalesced, the casted keys must be added up front. + let key_cols_coalesced = + options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full); let mut to_cast_left = vec![]; let mut to_cast_right = vec![]; let mut to_cast_indices = vec![]; - for (i, (lnode, rnode)) in left_on.iter().zip(right_on.iter()).enumerate() { + for (i, (lnode, rnode)) in left_on.iter_mut().zip(right_on.iter_mut()).enumerate() { let ltype = get_dtype!(lnode, &schema_left)?; let rtype = get_dtype!(rnode, &schema_right)?; - if let (AExpr::Column(lname), AExpr::Column(rname)) = ( - ctxt.expr_arena.get(lnode.node()).clone(), - ctxt.expr_arena.get(rnode.node()).clone(), - ) { - if let Some(dtype) = get_numeric_upcast_supertype_lossless(&ltype, &rtype) { - { - let expr = ctxt.expr_arena.add(AExpr::Column(lname.clone())); - to_cast_left.push(ExprIR::new( - ctxt.expr_arena.add(AExpr::Cast { - expr, - dtype: dtype.clone(), - options: CastOptions::Strict, - }), - OutputName::ColumnLhs(lname), - )) - }; - - { - let expr = ctxt.expr_arena.add(AExpr::Column(rname.clone())); - to_cast_right.push(ExprIR::new( - ctxt.expr_arena.add(AExpr::Cast { - expr, - dtype: dtype.clone(), - options: CastOptions::Strict, 
- }), - OutputName::ColumnLhs(rname), - )) - }; + if let Some(dtype) = get_numeric_upcast_supertype_lossless(&ltype, &rtype) { + let casted_l = ctxt.expr_arena.add(AExpr::Cast { + expr: lnode.node(), + dtype: dtype.clone(), + options: CastOptions::Strict, + }); + let casted_r = ctxt.expr_arena.add(AExpr::Cast { + expr: rnode.node(), + dtype, + options: CastOptions::Strict, + }); + + if key_cols_coalesced { + let mut lnode = lnode.clone(); + let mut rnode = rnode.clone(); + + let ael = ctxt.expr_arena.get(lnode.node()); + let aer = ctxt.expr_arena.get(rnode.node()); + + polars_ensure!( + ael.is_col() && aer.is_col(), + SchemaMismatch: "can only 'coalesce' full join if join keys are column expressions", + ); - to_cast_indices.push(i); + lnode.set_node(casted_l); + rnode.set_node(casted_r); - continue; + to_cast_indices.push(i); + to_cast_right.push(rnode); + to_cast_left.push(lnode); + } else { + lnode.set_node(casted_l); + rnode.set_node(casted_r); } + } else { + polars_ensure!( + ltype == rtype, + SchemaMismatch: "datatypes of join keys don't match - `{}`: {} on left does not match `{}`: {} on right", + lnode.output_name(), ltype, rnode.output_name(), rtype + ) } - - polars_ensure!( - ltype == rtype, - SchemaMismatch: "datatypes of join keys don't match - `{}`: {} on left does not match `{}`: {} on right", - lnode.output_name(), ltype, rnode.output_name(), rtype - ) } + // Every expression must be elementwise so that we are // guaranteed the keys for a join are all the same length. 
@@ -234,15 +239,6 @@ pub fn resolve_join( options: ProjectionOptions::default(), }) }; - } else { - for ((i, ir_left), ir_right) in to_cast_indices - .into_iter() - .zip(to_cast_left) - .zip(to_cast_right) - { - left_on[i] = ir_left; - right_on[i] = ir_right; - } } let lp = IR::Join { diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index fd384408e24..95d16438cda 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -1320,48 +1320,48 @@ def test_join_preserve_order_full() -> None: @pytest.mark.parametrize( "dtypes", [ - ["Int128" , "Int128" , "Int64" ], - ["Int128" , "Int128" , "Int32" ], - ["Int128" , "Int128" , "Int16" ], - ["Int128" , "Int128" , "Int8" ], - ["Int128" , "UInt64" , "Int128" ], - ["Int128" , "UInt64" , "Int64" ], - ["Int128" , "UInt64" , "Int32" ], - ["Int128" , "UInt64" , "Int16" ], - ["Int128" , "UInt64" , "Int8" ], - ["Int128" , "UInt32" , "Int128" ], - ["Int128" , "UInt16" , "Int128" ], - ["Int128" , "UInt8" , "Int128" ], - - ["Int64" , "Int64" , "Int32" ], - ["Int64" , "Int64" , "Int16" ], - ["Int64" , "Int64" , "Int8" ], - ["Int64" , "UInt32" , "Int64" ], - ["Int64" , "UInt32" , "Int32" ], - ["Int64" , "UInt32" , "Int16" ], - ["Int64" , "UInt32" , "Int8" ], - ["Int64" , "UInt16" , "Int64" ], - ["Int64" , "UInt8" , "Int64" ], - - ["Int32" , "Int32" , "Int16" ], - ["Int32" , "Int32" , "Int8" ], - ["Int32" , "UInt16" , "Int32" ], - ["Int32" , "UInt16" , "Int16" ], - ["Int32" , "UInt16" , "Int8" ], - ["Int32" , "UInt8" , "Int32" ], - - ["Int16" , "Int16" , "Int8" ], - ["Int16" , "UInt8" , "Int16" ], - ["Int16" , "UInt8" , "Int8" ], - - ["UInt64" , "UInt64" , "UInt32" ], - ["UInt64" , "UInt64" , "UInt16" ], - ["UInt64" , "UInt64" , "UInt8" ], - - ["UInt32" , "UInt32" , "UInt16" ], - ["UInt32" , "UInt32" , "UInt8" ], - - ["UInt16" , "UInt16" , "UInt8" ], + ["Int128", "Int128", "Int64"], + ["Int128", "Int128", "Int32"], + ["Int128", 
"Int128", "Int16"], + ["Int128", "Int128", "Int8"], + ["Int128", "UInt64", "Int128"], + ["Int128", "UInt64", "Int64"], + ["Int128", "UInt64", "Int32"], + ["Int128", "UInt64", "Int16"], + ["Int128", "UInt64", "Int8"], + ["Int128", "UInt32", "Int128"], + ["Int128", "UInt16", "Int128"], + ["Int128", "UInt8", "Int128"], + + ["Int64", "Int64", "Int32"], + ["Int64", "Int64", "Int16"], + ["Int64", "Int64", "Int8"], + ["Int64", "UInt32", "Int64"], + ["Int64", "UInt32", "Int32"], + ["Int64", "UInt32", "Int16"], + ["Int64", "UInt32", "Int8"], + ["Int64", "UInt16", "Int64"], + ["Int64", "UInt8", "Int64"], + + ["Int32", "Int32", "Int16"], + ["Int32", "Int32", "Int8"], + ["Int32", "UInt16", "Int32"], + ["Int32", "UInt16", "Int16"], + ["Int32", "UInt16", "Int8"], + ["Int32", "UInt8", "Int32"], + + ["Int16", "Int16", "Int8"], + ["Int16", "UInt8", "Int16"], + ["Int16", "UInt8", "Int8"], + + ["UInt64", "UInt64", "UInt32"], + ["UInt64", "UInt64", "UInt16"], + ["UInt64", "UInt64", "UInt8"], + + ["UInt32", "UInt32", "UInt16"], + ["UInt32", "UInt32", "UInt8"], + + ["UInt16", "UInt16", "UInt8"], ["Float64", "Float64", "Float32"], ],