Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enable joins between compatible differing numeric key columns #20332

Merged
merged 8 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,14 +598,15 @@ impl DataType {
}

/// Check if this [`DataType`] is a basic floating point type (excludes Decimal).
/// Note, this also includes `Unknown(UnknownKind::Float)`.
pub fn is_float(&self) -> bool {
matches!(
self,
DataType::Float32 | DataType::Float64 | DataType::Unknown(UnknownKind::Float)
)
}

/// Check if this [`DataType`] is an integer.
/// Check if this [`DataType`] is an integer. Note, this also includes `Unknown(UnknownKind::Int)`.
pub fn is_integer(&self) -> bool {
matches!(
self,
Expand Down
70 changes: 70 additions & 0 deletions crates/polars-core/src/utils/supertype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,76 @@ pub fn try_get_supertype(l: &DataType, r: &DataType) -> PolarsResult<DataType> {
)
}

/// Finds a numeric dtype that both `l` and `r` can be upcast to without losing values.
///
/// Returns `None` when the two dtypes are equal, when either side is
/// `Unknown(_)`, or when the pair is not float/float or integer/integer.
///
/// * float + float: `Float64`.
/// * signed + signed, or unsigned + unsigned: the wider of the two.
/// * signed + unsigned: a signed type wide enough for both, i.e. the unsigned
///   side is widened to the next larger signed width (`UInt64` needs `Int128`).
///
/// # Panics
///
/// In builds with `debug_assertions` enabled, panics if a pair passes one of
/// the category checks but is not listed in the corresponding table. In other
/// builds that situation returns `None`.
pub fn get_numeric_upcast_supertype_lossless(l: &DataType, r: &DataType) -> Option<DataType> {
    use DataType::*;

    // Shared fallback for pairs that passed a category check but are missing
    // from the table for that category (e.g. a newly added dtype).
    let unhandled = |pair: (&DataType, &DataType)| -> Option<DataType> {
        if cfg!(debug_assertions) {
            panic!("{:?}", pair)
        } else {
            None
        }
    };

    if l == r || matches!((l, r), (Unknown(_), _) | (_, Unknown(_))) {
        return None;
    }

    let pair = (l, r);

    if l.is_float() && r.is_float() {
        return match pair {
            (Float64, _) | (_, Float64) => Some(Float64),
            // Did we add a new float type?
            other => unhandled(other),
        };
    }

    if l.is_signed_integer() && r.is_signed_integer() {
        // Widest first: the first arm that matches either side wins.
        return match pair {
            (Int128, _) | (_, Int128) => Some(Int128),
            (Int64, _) | (_, Int64) => Some(Int64),
            (Int32, _) | (_, Int32) => Some(Int32),
            (Int16, _) | (_, Int16) => Some(Int16),
            (Int8, _) | (_, Int8) => Some(Int8),
            other => unhandled(other),
        };
    }

    if l.is_unsigned_integer() && r.is_unsigned_integer() {
        return match pair {
            (UInt64, _) | (_, UInt64) => Some(UInt64),
            (UInt32, _) | (_, UInt32) => Some(UInt32),
            (UInt16, _) | (_, UInt16) => Some(UInt16),
            (UInt8, _) | (_, UInt8) => Some(UInt8),
            other => unhandled(other),
        };
    }

    if !(l.is_integer() && r.is_integer()) {
        return None;
    }

    // One side is signed and the other unsigned. Each row pairs an unsigned
    // type with the signed type one width step above it, so the unsigned side
    // always fits in the result.
    match pair {
        (UInt64, _) | (_, UInt64) | (Int128, _) | (_, Int128) => Some(Int128),
        (UInt32, _) | (_, UInt32) | (Int64, _) | (_, Int64) => Some(Int64),
        (UInt16, _) | (_, UInt16) | (Int32, _) | (_, Int32) => Some(Int32),
        (UInt8, _) | (_, UInt8) | (Int16, _) | (_, Int16) => Some(Int16),
        // One side is unsigned, so one of the UInt patterns above should
        // already have matched.
        other => unhandled(other),
    }
}

bitflags! {
#[repr(transparent)]
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
Expand Down
6 changes: 5 additions & 1 deletion crates/polars-ops/src/frame/join/hash_join/sort_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,17 @@ pub(super) fn par_sorted_merge_left(
DataType::Int64 => {
par_sorted_merge_left_impl(s_left.i64().unwrap(), s_right.i64().unwrap())
},
#[cfg(feature = "dtype-i128")]
DataType::Int128 => {
par_sorted_merge_left_impl(s_left.i128().unwrap(), s_right.i128().unwrap())
},
DataType::Float32 => {
par_sorted_merge_left_impl(s_left.f32().unwrap(), s_right.f32().unwrap())
},
DataType::Float64 => {
par_sorted_merge_left_impl(s_left.f64().unwrap(), s_right.f64().unwrap())
},
_ => unreachable!(),
dt => panic!("{:?}", dt),
}
}
#[cfg(feature = "performant")]
Expand Down
110 changes: 98 additions & 12 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use arrow::legacy::error::PolarsResult;
use either::Either;
use polars_core::chunked_array::cast::CastOptions;
use polars_core::error::feature_gated;
use polars_core::utils::get_numeric_upcast_supertype_lossless;

use super::*;
use crate::dsl::Expr;
Expand Down Expand Up @@ -74,10 +76,10 @@ pub fn resolve_join(
);
}

let input_left = input_left.map_right(Ok).right_or_else(|input| {
let mut input_left = input_left.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left)))
})?;
let input_right = input_right.map_right(Ok).right_or_else(|input| {
let mut input_right = input_right.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right)))
})?;

Expand All @@ -87,8 +89,8 @@ pub fn resolve_join(
let schema = det_join_schema(&schema_left, &schema_right, &left_on, &right_on, &options)
.map_err(|e| e.context(failed_here!(join schema resolving)))?;

let left_on = to_expr_irs_ignore_alias(left_on, ctxt.expr_arena)?;
let right_on = to_expr_irs_ignore_alias(right_on, ctxt.expr_arena)?;
let mut left_on = to_expr_irs_ignore_alias(left_on, ctxt.expr_arena)?;
let mut right_on = to_expr_irs_ignore_alias(right_on, ctxt.expr_arena)?;
let mut joined_on = PlHashSet::new();

#[cfg(feature = "iejoin")]
Expand Down Expand Up @@ -118,14 +120,58 @@ pub fn resolve_join(
.coerce_types(ctxt.expr_arena, ctxt.lp_arena, input_right)
.map_err(|e| e.context("'join' failed".into()))?;

let get_dtype = |expr: &ExprIR, schema: &SchemaRef| {
ctxt.expr_arena
.get(expr.node())
.get_type(schema, Context::Default, ctxt.expr_arena)
};
for (lnode, rnode) in left_on.iter().zip(right_on.iter()) {
let ltype = get_dtype(lnode, &schema_left)?;
let rtype = get_dtype(rnode, &schema_right)?;
// Not a closure to avoid borrow issues because we mutate expr_arena as well.
macro_rules! get_dtype {
($expr:expr, $schema:expr) => {
ctxt.expr_arena
.get($expr.node())
.get_type($schema, Context::Default, ctxt.expr_arena)
};
}

let mut to_cast_left = vec![];
let mut to_cast_right = vec![];
let mut to_cast_indices = vec![];

for (i, (lnode, rnode)) in left_on.iter().zip(right_on.iter()).enumerate() {
let ltype = get_dtype!(lnode, &schema_left)?;
let rtype = get_dtype!(rnode, &schema_right)?;

if let (AExpr::Column(lname), AExpr::Column(rname)) = (
ctxt.expr_arena.get(lnode.node()).clone(),
ctxt.expr_arena.get(rnode.node()).clone(),
) {
if let Some(dtype) = get_numeric_upcast_supertype_lossless(&ltype, &rtype) {
{
let expr = ctxt.expr_arena.add(AExpr::Column(lname.clone()));
to_cast_left.push(ExprIR::new(
ctxt.expr_arena.add(AExpr::Cast {
expr,
dtype: dtype.clone(),
options: CastOptions::Strict,
}),
OutputName::ColumnLhs(lname),
))
};

{
let expr = ctxt.expr_arena.add(AExpr::Column(rname.clone()));
to_cast_right.push(ExprIR::new(
ctxt.expr_arena.add(AExpr::Cast {
expr,
dtype: dtype.clone(),
options: CastOptions::Strict,
}),
OutputName::ColumnLhs(rname),
))
};

to_cast_indices.push(i);

continue;
}
}

polars_ensure!(
ltype == rtype,
SchemaMismatch: "datatypes of join keys don't match - `{}`: {} on left does not match `{}`: {} on right",
Expand All @@ -145,6 +191,46 @@ pub fn resolve_join(
InvalidOperation: "all join key expressions must be elementwise."
);

// These are Arc<Schema>, into_owned is free.
let schema_left = schema_left.into_owned();
let schema_right = schema_right.into_owned();

let key_cols_coalesced =
options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full);

if key_cols_coalesced {
Copy link
Collaborator Author

@nameexhaustion nameexhaustion Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made it so that we maintain the input column types in the output for all cases except for full-join with coalesce=True. Alternatively we could also just always use the supertype in the result.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that what you did now is correct. 👍

input_left = if to_cast_left.is_empty() {
input_left
} else {
ctxt.lp_arena.add(IR::HStack {
input: input_left,
exprs: to_cast_left,
schema: schema_left,
options: ProjectionOptions::default(),
})
};

input_right = if to_cast_right.is_empty() {
input_right
} else {
ctxt.lp_arena.add(IR::HStack {
input: input_right,
exprs: to_cast_right,
schema: schema_right,
options: ProjectionOptions::default(),
})
};
} else {
for ((i, ir_left), ir_right) in to_cast_indices
.into_iter()
.zip(to_cast_left)
.zip(to_cast_right)
{
left_on[i] = ir_left;
right_on[i] = ir_right;
}
}

let lp = IR::Join {
input_left,
input_right,
Expand Down
111 changes: 109 additions & 2 deletions py-polars/tests/unit/operations/test_join.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing
import warnings
from datetime import date, datetime
from typing import TYPE_CHECKING, Literal

Expand Down Expand Up @@ -976,7 +977,11 @@ def test_join_raise_on_redundant_keys() -> None:
def test_join_raise_on_repeated_expression_key_names(coalesce: bool) -> None:
left = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]})
right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]})
with pytest.raises(InvalidOperationError, match="already joined on"):
with ( # noqa: PT012
pytest.raises(InvalidOperationError, match="already joined on"),
warnings.catch_warnings(),
):
warnings.simplefilter(action="ignore", category=UserWarning)
left.join(
right, on=[pl.col("a"), pl.col("a") % 2], how="full", coalesce=coalesce
)
Expand Down Expand Up @@ -1232,7 +1237,6 @@ def test_join_preserve_order_full() -> None:
right = pl.LazyFrame({"a": [1, None, 2, 6], "b": [6, 7, 8, 9]})

full_left = left.join(right, on="a", how="full", maintain_order="left").collect()
print(full_left)
assert full_left.get_column("a").cast(pl.UInt32).to_list()[:5] == [
None,
2,
Expand Down Expand Up @@ -1274,3 +1278,106 @@ def test_join_preserve_order_full() -> None:
None,
5,
]


@pytest.mark.parametrize(
    "dtypes",
    [
        ["Int128" , "Int128" , "Int64"  ],
        ["Int128" , "Int128" , "Int32"  ],
        ["Int128" , "Int128" , "Int16"  ],
        ["Int128" , "Int128" , "Int8"   ],
        ["Int128" , "UInt64" , "Int128" ],
        ["Int128" , "UInt64" , "Int64"  ],
        ["Int128" , "UInt64" , "Int32"  ],
        ["Int128" , "UInt64" , "Int16"  ],
        ["Int128" , "UInt64" , "Int8"   ],
        ["Int128" , "UInt32" , "Int128" ],
        ["Int128" , "UInt16" , "Int128" ],
        ["Int128" , "UInt8"  , "Int128" ],

        ["Int64"  , "Int64"  , "Int32"  ],
        ["Int64"  , "Int64"  , "Int16"  ],
        ["Int64"  , "Int64"  , "Int8"   ],
        ["Int64"  , "UInt32" , "Int64"  ],
        ["Int64"  , "UInt32" , "Int32"  ],
        ["Int64"  , "UInt32" , "Int16"  ],
        ["Int64"  , "UInt32" , "Int8"   ],
        ["Int64"  , "UInt16" , "Int64"  ],
        ["Int64"  , "UInt8"  , "Int64"  ],

        ["Int32"  , "Int32"  , "Int16"  ],
        ["Int32"  , "Int32"  , "Int8"   ],
        ["Int32"  , "UInt16" , "Int32"  ],
        ["Int32"  , "UInt16" , "Int16"  ],
        ["Int32"  , "UInt16" , "Int8"   ],
        ["Int32"  , "UInt8"  , "Int32"  ],

        ["Int16"  , "Int16"  , "Int8"   ],
        ["Int16"  , "UInt8"  , "Int16"  ],
        ["Int16"  , "UInt8"  , "Int8"   ],

        ["UInt64" , "UInt64" , "UInt32" ],
        ["UInt64" , "UInt64" , "UInt16" ],
        ["UInt64" , "UInt64" , "UInt8"  ],

        ["UInt32" , "UInt32" , "UInt16" ],
        ["UInt32" , "UInt32" , "UInt8"  ],

        ["UInt16" , "UInt16" , "UInt8"  ],

        ["Float64", "Float64", "Float32"],
    ],
)  # fmt: skip
@pytest.mark.parametrize("swap", [True, False])
def test_join_numeric_type_upcast_15338(
    dtypes: tuple[str, str, str], swap: bool
) -> None:
    """Join on key columns whose numeric dtypes differ but share a supertype.

    Each ``dtypes`` entry is ``[supertype, left dtype, right dtype]``, given as
    attribute names on ``pl``; ``swap`` exchanges the left and right dtypes.

    Asserted here:
    - left and semi joins keep the left key dtype in the output;
    - a full join with default arguments keeps each side's own key dtype
      (``a`` and ``a_right``);
    - a full join with ``coalesce=True`` produces the key in the supertype.
    """
    supertype, ltype, rtype = [getattr(pl, name) for name in dtypes]
    if swap:
        ltype, rtype = rtype, ltype

    left = pl.select(pl.Series("a", [1, 1, 3]).cast(ltype)).lazy()
    right = pl.select(pl.Series("a", [1]).cast(rtype), b=pl.lit("A")).lazy()

    def keys(values: list[int | None], dtype: Any) -> pl.Series:
        # Expected key column: the given values cast to `dtype`.
        return pl.Series(values).cast(dtype)

    # Payload column from the right frame; null where the left key 3 has no match.
    b_expected = pl.Series(["A", "A", None])

    # Left join with default coalescing: a single key column in the left dtype.
    result = left.join(right, on="a", how="left").collect()
    assert_frame_equal(result, pl.select(a=keys([1, 1, 3], ltype), b=b_expected))

    # Left join without coalescing: same result once the right key is dropped.
    result = (
        left.join(right, on="a", how="left", coalesce=False).drop("a_right").collect()
    )
    assert_frame_equal(result, pl.select(a=keys([1, 1, 3], ltype), b=b_expected))

    # Full join with default arguments: both key columns keep their input dtypes.
    result = left.join(right, on="a", how="full").collect()
    assert_frame_equal(
        result,
        pl.select(
            a=keys([1, 1, 3], ltype),
            a_right=keys([1, 1, None], rtype),
            b=b_expected,
        ),
    )

    # Full join with coalesce=True: the merged key column uses the supertype.
    result = left.join(right, on="a", how="full", coalesce=True).collect()
    assert_frame_equal(
        result,
        pl.select(a=keys([1, 1, 3], supertype), b=b_expected),
    )

    # Semi join: only matching left rows, in the left dtype.
    result = left.join(right, on="a", how="semi").collect()
    assert_frame_equal(result, pl.select(a=keys([1, 1], ltype)))


def test_join_numeric_type_upcast_forbid_float_int() -> None:
    """A Float64 key joined against an Int32 key must raise a SchemaError."""
    left = pl.LazyFrame(schema={"a": pl.Float64})
    right = pl.LazyFrame(schema={"a": pl.Int32})

    with pytest.raises(SchemaError, match="datatypes of join keys don't match"):
        left.join(right, on="a", how="left").collect()
Loading
Loading