fix: Fix validation for inner and left join when join_nulls unflagged #19698

Merged 6 commits on Nov 11, 2024
Changes from 4 commits
9 changes: 8 additions & 1 deletion crates/polars-ops/src/frame/join/args.rs
@@ -237,6 +237,7 @@ impl JoinValidation {
s_left: &Series,
s_right: &Series,
build_shortest_table: bool,
join_nulls: bool,
) -> PolarsResult<()> {
// In default, probe is the left series.
//
@@ -253,7 +254,13 @@
// Only check the `build` side.
// The other side use `validate_build` to check
ManyToMany | ManyToOne => true,
OneToMany | OneToOne => probe.n_unique()? == probe.len(),
OneToMany | OneToOne => {
if !join_nulls && probe.null_count() > 0 {
probe.n_unique()? - 1 == probe.len() - probe.null_count()
} else {
probe.n_unique()? == probe.len()
}
},
};
polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);
Ok(())
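
The hunk above is the core of the fix: when `join_nulls` is false, null keys never match, so a probe column whose non-null keys are unique should still pass `1:1` / `1:m` validation, and `n_unique` (which, as the `- 1` above implies, counts the null group as one distinct value) has to be compared against the non-null length only. Below is a minimal standalone sketch of that rule; the helper name and types are hypothetical illustration, not the Polars internals.

```rust
use std::collections::HashSet;

// Hypothetical helper mirroring the null-aware probe check above (not Polars code).
fn probe_is_unique(probe: &[Option<i64>], join_nulls: bool) -> bool {
    let null_count = probe.iter().filter(|v| v.is_none()).count();
    // Like the `n_unique` in the diff, this counts the null group as one distinct value.
    let n_unique = probe.iter().collect::<HashSet<_>>().len();
    if !join_nulls && null_count > 0 {
        // Nulls never match, so only the non-null keys have to be unique.
        n_unique - 1 == probe.len() - null_count
    } else {
        n_unique == probe.len()
    }
}

fn main() {
    // [1, 2, null, null] is a valid "one" side when nulls are not join keys...
    assert!(probe_is_unique(&[Some(1), Some(2), None, None], false));
    // ...but fails 1:1 / 1:m validation when nulls are treated as equal keys.
    assert!(!probe_is_unique(&[Some(1), Some(2), None, None], true));
}
```
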
@@ -20,7 +20,7 @@ pub trait SeriesJoin: SeriesSealed + Sized {
) -> PolarsResult<LeftJoinIds> {
let s_self = self.as_series();
let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr());
validate.validate_probe(&lhs, &rhs, false)?;
validate.validate_probe(&lhs, &rhs, false, join_nulls)?;

let lhs_dtype = lhs.dtype();
let rhs_dtype = rhs.dtype();
@@ -35,7 +35,8 @@
let (lhs, rhs, _, _) = prepare_binary::<BinaryType>(lhs, rhs, false);
let lhs = lhs.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
let rhs = rhs.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls)
let build_null_count = other.null_count();
hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls, build_null_count)
},
T::BinaryOffset => {
let lhs = lhs.binary_offset().unwrap();
@@ -44,7 +45,8 @@
// Take slices so that vecs are not copied
let lhs = lhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
let rhs = rhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls)
let build_null_count = other.null_count();
hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls, build_null_count)
},
x if x.is_float() => {
with_match_physical_float_polars_type!(lhs.dtype(), |$T| {
@@ -168,7 +170,7 @@
) -> PolarsResult<(InnerJoinIds, bool)> {
let s_self = self.as_series();
let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr());
validate.validate_probe(&lhs, &rhs, true)?;
validate.validate_probe(&lhs, &rhs, true, join_nulls)?;

let lhs_dtype = lhs.dtype();
let rhs_dtype = rhs.dtype();
@@ -184,8 +186,20 @@
// Take slices so that vecs are not copied
let lhs = lhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
let rhs = rhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
let build_null_count = if swapped {
s_self.null_count()
} else {
other.null_count()
};
Ok((
hash_join_tuples_inner(lhs, rhs, swapped, validate, join_nulls)?,
hash_join_tuples_inner(
lhs,
rhs,
swapped,
validate,
join_nulls,
build_null_count,
)?,
!swapped,
))
},
@@ -196,8 +210,20 @@
// Take slices so that vecs are not copied
let lhs = lhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
let rhs = rhs.iter().map(|k| k.as_slice()).collect::<Vec<_>>();
let build_null_count = if swapped {
s_self.null_count()
} else {
other.null_count()
};
Ok((
hash_join_tuples_inner(lhs, rhs, swapped, validate, join_nulls)?,
hash_join_tuples_inner(
lhs,
rhs,
swapped,
validate,
join_nulls,
build_null_count,
)?,
!swapped,
))
},
@@ -244,7 +270,7 @@
) -> PolarsResult<(PrimitiveArray<IdxSize>, PrimitiveArray<IdxSize>)> {
let s_self = self.as_series();
let (lhs, rhs) = (s_self.to_physical_repr(), other.to_physical_repr());
validate.validate_probe(&lhs, &rhs, true)?;
validate.validate_probe(&lhs, &rhs, true, join_nulls)?;

let lhs_dtype = lhs.dtype();
let rhs_dtype = rhs.dtype();
@@ -352,20 +378,38 @@
.map(|arr| arr.as_slice().unwrap())
.collect::<Vec<_>>();
Ok((
hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate, join_nulls)?,
hash_join_tuples_inner(
splitted_a, splitted_b, swapped, validate, join_nulls, 0,
)?,
!swapped,
))
} else {
Ok((
hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate, join_nulls)?,
hash_join_tuples_inner(
splitted_a, splitted_b, swapped, validate, join_nulls, 0,
)?,
!swapped,
))
}
},
_ => Ok((
hash_join_tuples_inner(splitted_a, splitted_b, swapped, validate, join_nulls)?,
!swapped,
)),
_ => {
let build_null_count = if swapped {
left.null_count()
} else {
right.null_count()
};
Ok((
hash_join_tuples_inner(
splitted_a,
splitted_b,
swapped,
validate,
join_nulls,
build_null_count,
)?,
!swapped,
))
},
}
}

@@ -430,7 +474,7 @@
(0, 0, 1, 1) => {
let keys_a = chunks_as_slices(&splitted_a);
let keys_b = chunks_as_slices(&splitted_b);
hash_join_tuples_left(keys_a, keys_b, None, None, validate, join_nulls)
hash_join_tuples_left(keys_a, keys_b, None, None, validate, join_nulls, 0)
},
(0, 0, _, _) => {
let keys_a = chunks_as_slices(&splitted_a);
@@ -445,20 +489,23 @@
mapping_right.as_deref(),
validate,
join_nulls,
0,
)
},
_ => {
let keys_a = get_arrays(&splitted_a);
let keys_b = get_arrays(&splitted_b);
let (mapping_left, mapping_right) =
create_mappings(left.chunks(), right.chunks(), left.len(), right.len());
let build_null_count = right.null_count();
hash_join_tuples_left(
keys_a,
keys_b,
mapping_left.as_deref(),
mapping_right.as_deref(),
validate,
join_nulls,
build_null_count,
)
},
}
@@ -44,6 +44,8 @@ pub(super) fn hash_join_tuples_inner<T, I>(
swapped: bool,
validate: JoinValidation,
join_nulls: bool,
// We should know the number of nulls to avoid extra calculation
ritchie46 marked this conversation as resolved.
build_null_count: usize,
Contributor (PR author): I wasn't sure about adding this param, but I wanted to avoid performing another pass on all values to calculate it inside the function.

Member: Yes, that's worth it.

) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)>
where
I: IntoIterator<Item = T> + Send + Sync + Clone,
@@ -53,10 +55,13 @@ where
// NOTE: see the left join for more elaborate comments
// first we hash one relation
let hash_tbls = if validate.needs_checks() {
let expected_size = build
let mut expected_size = build
.iter()
.map(|v| v.clone().into_iter().size_hint().1.unwrap())
.sum();
if !join_nulls {
expected_size -= build_null_count;
}
let hash_tbls = build_tables(build, join_nulls);
let build_size = hash_tbls.iter().map(|m| m.len()).sum();
validate.validate_build(build_size, expected_size, swapped)?;
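
The same reasoning applies on the build side, both here in `hash_join_tuples_inner` and in `hash_join_tuples_left` below: with `join_nulls == false` the hash table is built without null keys, so the expected size compared by `validate_build` has to drop the nulls too, which is why the caller hands in `build_null_count` rather than re-scanning the keys (the resolved thread above). A rough self-contained sketch of that bookkeeping follows; it assumes, as the diff implies, that the build table skips null keys when nulls don't join, and the function and names are hypothetical, not the Polars API.

```rust
use std::collections::HashSet;

// Hypothetical sketch of the build-side m:1 check (not Polars code).
fn build_side_passes_m1(build: &[Option<u64>], join_nulls: bool) -> bool {
    let build_null_count = build.iter().filter(|k| k.is_none()).count();

    // Stand-in for the build table: null keys are skipped when nulls don't join.
    let mut table: HashSet<Option<u64>> = HashSet::new();
    for key in build {
        if !join_nulls && key.is_none() {
            continue;
        }
        table.insert(*key);
    }

    // Mirror of the `expected_size -= build_null_count` adjustment in the hunk above.
    let mut expected_size = build.len();
    if !join_nulls {
        expected_size -= build_null_count;
    }

    // validate_build-style comparison: distinct keys vs. rows that entered the table.
    table.len() == expected_size
}

fn main() {
    // Duplicate nulls on the build side no longer break m:1 validation when nulls don't join.
    assert!(build_side_passes_m1(&[Some(1), Some(2), None, None], false));
    assert!(!build_side_passes_m1(&[Some(1), Some(2), None, None], true));
}
```
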
@@ -112,6 +112,8 @@ pub(super) fn hash_join_tuples_left<T, I>(
chunk_mapping_right: Option<&[ChunkId]>,
validate: JoinValidation,
join_nulls: bool,
// We should know the number of nulls to avoid extra calculation
build_null_count: usize,
) -> PolarsResult<LeftJoinIds>
where
I: IntoIterator<Item = T>,
@@ -123,7 +125,10 @@
let build = build.into_iter().map(|i| i.into_iter()).collect::<Vec<_>>();
// first we hash one relation
let hash_tbls = if validate.needs_checks() {
let expected_size = build.iter().map(|v| v.size_hint().1.unwrap()).sum();
let mut expected_size = build.iter().map(|v| v.size_hint().1.unwrap()).sum();
if !join_nulls {
expected_size -= build_null_count;
}
let hash_tbls = build_tables(build, join_nulls);
let build_size = hash_tbls.iter().map(|m| m.len()).sum();
validate.validate_build(build_size, expected_size, false)?;
23 changes: 23 additions & 0 deletions py-polars/tests/unit/sql/test_joins.py
@@ -663,3 +663,26 @@ def test_nested_join(join_clause: str) -> None:
"Species": "Human",
},
]


def test_join_nulls_19624() -> None:
df1 = pl.DataFrame({"a": [1, 2, None, None]})
df2 = pl.DataFrame({"a": [1, 1, 2, 2, None], "b": [0, 1, 2, 3, 4]})

# left join
result_df = df1.join(df2, how="left", on="a", join_nulls=False, validate="1:m")
expected_df = pl.DataFrame(
{"a": [1, 1, 2, 2, None, None], "b": [0, 1, 2, 3, None, None]}
)
assert_frame_equal(result_df, expected_df)
result_df = df2.join(df1, how="left", on="a", join_nulls=False, validate="m:1")
expected_df = pl.DataFrame({"a": [1, 1, 2, 2, None], "b": [0, 1, 2, 3, 4]})
assert_frame_equal(result_df, expected_df)

# inner join
result_df = df1.join(df2, how="inner", on="a", join_nulls=False, validate="1:m")
expected_df = pl.DataFrame({"a": [1, 1, 2, 2], "b": [0, 1, 2, 3]})
assert_frame_equal(result_df, expected_df)
result_df = df2.join(df1, how="inner", on="a", join_nulls=False, validate="m:1")
expected_df = pl.DataFrame({"a": [1, 1, 2, 2], "b": [0, 1, 2, 3]})
assert_frame_equal(result_df, expected_df)