Skip to content

Commit

Permalink
Implement comparisons on nested data types such that distinct/except …
Browse files Browse the repository at this point in the history
…would work

This relies on newer functionality in arrow 52 and allows
DataFrame.except() to properly work on schemas with structs and lists

Closes apache#10749
  • Loading branch information
rtyler committed Jun 25, 2024
1 parent 8b244ee commit b028bbe
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
76 changes: 76 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3445,6 +3445,82 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_except_nested_struct() -> Result<()> {
use arrow::array::StructArray;

let nested_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, true),
Field::new("lat", DataType::Int32, true),
Field::new("long", DataType::Int32, true),
]));
let schema = Arc::new(Schema::new(vec![
Field::new("value", DataType::Int32, true),
Field::new(
"nested",
DataType::Struct(nested_schema.fields.clone()),
true,
),
]));
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new("id", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("lat", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("long", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
])),
],
)
.unwrap();

let updated_batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![Some(1), Some(12), Some(3)])),
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new("id", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("lat", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("long", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
])),
],
)
.unwrap();

let ctx = SessionContext::new();
let before = ctx.read_batch(batch).expect("Failed to make DataFrame");
let after = ctx
.read_batch(updated_batch)
.expect("Failed to make DataFrame");

let diff = before
.except(after)
.expect("Failed to except")
.collect()
.await?;
assert_eq!(diff.len(), 1);
Ok(())
}

#[tokio::test]
async fn nested_explain_should_fail() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
14 changes: 13 additions & 1 deletion datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ use arrow::array::{
Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array,
UInt64Array,
};
use arrow::buffer::NullBuffer;
use arrow::compute::kernels::cmp::{eq, not_distinct};
use arrow::compute::{and, concat_batches, take, FilterBuilder};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util;
use arrow_array::cast::downcast_array;
use arrow_schema::ArrowError;
use arrow_ord::ord::make_comparator;
use arrow_schema::{ArrowError, SortOptions};
use datafusion_common::utils::memory::estimate_memory_size;
use datafusion_common::{
internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError,
Expand Down Expand Up @@ -1210,6 +1212,16 @@ fn eq_dyn_null(
right: &dyn Array,
null_equals_null: bool,
) -> Result<BooleanArray, ArrowError> {
// Nested datatypes cannot use the underlying not_distinct function and must use a special
// implementation
// <https://github.com/apache/datafusion/issues/10749>
if left.data_type().is_nested() && null_equals_null {
let cmp = make_comparator(left, right, SortOptions::default())?;
let len = left.len().min(right.len());
let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
let nulls = NullBuffer::union(left.nulls(), right.nulls());
return Ok(BooleanArray::new(values, nulls));
}
match (left.data_type(), right.data_type()) {
_ if null_equals_null => not_distinct(&left, &right),
_ => eq(&left, &right),
Expand Down

0 comments on commit b028bbe

Please sign in to comment.