Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hash join for nested types #11232

Merged
merged 7 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,7 @@ pub fn can_hash(data_type: &DataType) -> bool {
DataType::List(_) => true,
DataType::LargeList(_) => true,
DataType::FixedSizeList(_, _) => true,
DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
_ => false,
}
}
Expand Down
111 changes: 108 additions & 3 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1212,11 +1212,16 @@ fn eq_dyn_null(
right: &dyn Array,
null_equals_null: bool,
) -> Result<BooleanArray, ArrowError> {
// Nested datatypes cannot use the underlying not_distinct function and must use a special
// Nested datatypes cannot use the underlying not_distinct/eq function and must use a special
// implementation
// <https://github.com/apache/datafusion/issues/10749>
if left.data_type().is_nested() && null_equals_null {
return Ok(compare_op_for_nested(&Operator::Eq, &left, &right)?);
if left.data_type().is_nested() {
let op = if null_equals_null {
Operator::IsNotDistinctFrom
} else {
Operator::Eq
};
return Ok(compare_op_for_nested(&op, &left, &right)?);
}
match (left.data_type(), right.data_type()) {
_ if null_equals_null => not_distinct(&left, &right),
Expand Down Expand Up @@ -1546,6 +1551,8 @@ mod tests {

use arrow::array::{Date32Array, Int32Array, UInt32Builder, UInt64Builder};
use arrow::datatypes::{DataType, Field};
use arrow_array::StructArray;
use arrow_buffer::NullBuffer;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err,
ScalarValue,
Expand Down Expand Up @@ -3844,6 +3851,104 @@ mod tests {
Ok(())
}

fn build_table_struct(
struct_name: &str,
field_name_and_values: (&str, &Vec<Option<i32>>),
nulls: Option<NullBuffer>,
) -> Arc<dyn ExecutionPlan> {
let (field_name, values) = field_name_and_values;
let inner_fields = vec![Field::new(field_name, DataType::Int32, true)];
let schema = Schema::new(vec![Field::new(
struct_name,
DataType::Struct(inner_fields.clone().into()),
nulls.is_some(),
)]);

let batch = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(StructArray::new(
inner_fields.into(),
vec![Arc::new(Int32Array::from(values.clone()))],
nulls,
))],
)
.unwrap();
let schema_ref = batch.schema();
Arc::new(MemoryExec::try_new(&[vec![batch]], schema_ref, None).unwrap())
}

#[tokio::test]
async fn join_on_struct() -> Result<()> {
eejbyfeldt marked this conversation as resolved.
Show resolved Hide resolved
let task_ctx = Arc::new(TaskContext::default());
let left =
build_table_struct("n1", ("a", &vec![None, Some(1), Some(2), Some(3)]), None);
let right =
build_table_struct("n2", ("a", &vec![None, Some(1), Some(2), Some(4)]), None);
let on = vec![(
Arc::new(Column::new_with_schema("n1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("n2", &right.schema())?) as _,
)];

let (columns, batches) =
join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;

assert_eq!(columns, vec!["n1", "n2"]);

let expected = [
"+--------+--------+",
"| n1 | n2 |",
"+--------+--------+",
"| {a: } | {a: } |",
"| {a: 1} | {a: 1} |",
"| {a: 2} | {a: 2} |",
"+--------+--------+",
];
assert_batches_eq!(expected, &batches);

Ok(())
}

#[tokio::test]
async fn join_on_struct_with_nulls() -> Result<()> {
eejbyfeldt marked this conversation as resolved.
Show resolved Hide resolved
let task_ctx = Arc::new(TaskContext::default());
let left =
build_table_struct("n1", ("a", &vec![None]), Some(NullBuffer::new_null(1)));
let right =
build_table_struct("n2", ("a", &vec![None]), Some(NullBuffer::new_null(1)));
let on = vec![(
Arc::new(Column::new_with_schema("n1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("n2", &right.schema())?) as _,
)];

let (_, batches_null_eq) = join_collect(
left.clone(),
right.clone(),
on.clone(),
&JoinType::Inner,
true,
task_ctx.clone(),
)
.await?;

let expected_null_eq = [
"+----+----+",
"| n1 | n2 |",
"+----+----+",
"| | |",
"+----+----+",
];
assert_batches_eq!(expected_null_eq, &batches_null_eq);

let (_, batches_null_neq) =
join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;

let expected_null_neq =
["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"];
assert_batches_eq!(expected_null_neq, &batches_null_neq);

Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down
51 changes: 51 additions & 0 deletions datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ AS VALUES
(44, 'x', 3),
(55, 'w', 3);

statement ok
CREATE TABLE join_t3(s3 struct<id INT>)
AS VALUES
(NULL),
(struct(1)),
(struct(2));

statement ok
CREATE TABLE join_t4(s4 struct<id INT>)
AS VALUES
(NULL),
(struct(2)),
(struct(3));

# Left semi anti join

statement ok
Expand Down Expand Up @@ -1336,6 +1350,43 @@ physical_plan
10)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
11)------------MemoryExec: partitions=1, partition_sizes=[1]

# Join on struct
query TT
explain select join_t3.s3, join_t4.s4
from join_t3
inner join join_t4 on join_t3.s3 = join_t4.s4
----
logical_plan
01)Inner Join: join_t3.s3 = join_t4.s4
02)--TableScan: join_t3 projection=[s3]
03)--TableScan: join_t4 projection=[s4]
physical_plan
01)CoalesceBatchesExec: target_batch_size=2
02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s3@0, s4@0)]
03)----CoalesceBatchesExec: target_batch_size=2
04)------RepartitionExec: partitioning=Hash([s3@0], 2), input_partitions=2
05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
06)----------MemoryExec: partitions=1, partition_sizes=[1]
07)----CoalesceBatchesExec: target_batch_size=2
08)------RepartitionExec: partitioning=Hash([s4@0], 2), input_partitions=2
09)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
10)----------MemoryExec: partitions=1, partition_sizes=[1]

query ??
select join_t3.s3, join_t4.s4
from join_t3
inner join join_t4 on join_t3.s3 = join_t4.s4
----
{id: 2} {id: 2}

# join with struct key and nulls
eejbyfeldt marked this conversation as resolved.
Show resolved Hide resolved
query ?
SELECT * FROM join_t3
EXCEPT
SELECT * FROM join_t4
----
{id: 1}

query TT
EXPLAIN
select count(*)
Expand Down