Skip to content

Commit

Permalink
Support dictionary / dictionary comparisons by unpacking them
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jan 18, 2022
1 parent c549d51 commit c4db8f3
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 13 deletions.
8 changes: 4 additions & 4 deletions datafusion/src/physical_plan/coercion_rule/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use crate::arrow::datatypes::DataType;
use crate::error::{DataFusionError, Result};
use crate::logical_plan::Operator;
use crate::physical_plan::expressions::coercion::{
dictionary_coercion, eq_coercion, is_numeric, like_coercion, string_coercion,
temporal_coercion,
dictionary_coercion, eq_coercion, is_dictionary, is_numeric, like_coercion,
string_coercion, temporal_coercion,
};
use crate::scalar::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128};

Expand Down Expand Up @@ -77,7 +77,7 @@ pub(crate) fn coerce_types(
}

fn comparison_eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
if lhs_type == rhs_type {
if lhs_type == rhs_type && !is_dictionary(lhs_type) {
// same type => equality is possible
return Some(lhs_type.clone());
}
Expand All @@ -90,7 +90,7 @@ fn comparison_order_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
if lhs_type == rhs_type {
if lhs_type == rhs_type && !is_dictionary(lhs_type) {
// same type => all good
return Some(lhs_type.clone());
}
Expand Down
8 changes: 6 additions & 2 deletions datafusion/src/physical_plan/expressions/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ pub fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Dat
}
}

/// Returns `true` when `t` is the `DataType::Dictionary` variant,
/// regardless of the types carried inside it, and `false` for every
/// other data type.
pub(crate) fn is_dictionary(t: &DataType) -> bool {
    matches!(t, DataType::Dictionary(..))
}

/// Coercion rule for numerical types: The type that both lhs and rhs
/// can be cast to for numerical calculation, while maintaining
/// maximum precision
Expand All @@ -159,7 +163,7 @@ pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
};

// same type => all good
if lhs_type == rhs_type {
if lhs_type == rhs_type && !is_dictionary(lhs_type) {
return Some(lhs_type.clone());
}

Expand All @@ -182,7 +186,7 @@ pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da

// coercion rules for equality operations. This is a superset of all numerical coercion rules.
pub fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
if lhs_type == rhs_type {
if lhs_type == rhs_type && !is_dictionary(lhs_type) {
// same type => equality is possible
return Some(lhs_type.clone());
}
Expand Down
85 changes: 78 additions & 7 deletions datafusion/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -655,19 +655,28 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
async fn query_on_string_dictionary() -> Result<()> {
// Test to ensure DataFusion can operate on dictionary types
// Use StringDictionary (32 bit indexes = keys)
let array = vec![Some("one"), None, Some("three")]
let d1: DictionaryArray<Int32Type> =
vec![Some("one"), None, Some("three")].into_iter().collect();

let d2: DictionaryArray<Int32Type> = vec![Some("blarg"), None, Some("three")]
.into_iter()
.collect::<DictionaryArray<Int32Type>>();
.collect();

let d3: StringArray = vec![Some("XYZ"), None, Some("three")].into_iter().collect();

let batch =
RecordBatch::try_from_iter(vec![("d1", Arc::new(array) as ArrayRef)]).unwrap();
let batch = RecordBatch::try_from_iter(vec![
("d1", Arc::new(d1) as ArrayRef),
("d2", Arc::new(d2) as ArrayRef),
("d3", Arc::new(d3) as ArrayRef),
])
.unwrap();

let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
let mut ctx = ExecutionContext::new();
ctx.register_table("test", Arc::new(table))?;

// Basic SELECT
let sql = "SELECT * FROM test";
let sql = "SELECT d1 FROM test";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+-------+",
Expand All @@ -681,7 +690,7 @@ async fn query_on_string_dictionary() -> Result<()> {
assert_batches_eq!(expected, &actual);

// basic filtering
let sql = "SELECT * FROM test WHERE d1 IS NOT NULL";
let sql = "SELECT d1 FROM test WHERE d1 IS NOT NULL";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+-------+",
Expand All @@ -693,8 +702,56 @@ async fn query_on_string_dictionary() -> Result<()> {
];
assert_batches_eq!(expected, &actual);

// comparison with constant
let sql = "SELECT d1 FROM test WHERE d1 = 'three'";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+-------+",
"| d1 |",
"+-------+",
"| three |",
"+-------+",
];
assert_batches_eq!(expected, &actual);

// comparison with another dictionary column
let sql = "SELECT d1 FROM test WHERE d1 = d2";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+-------+",
"| d1 |",
"+-------+",
"| three |",
"+-------+",
];
assert_batches_eq!(expected, &actual);

// order comparison with another dictionary column
let sql = "SELECT d1 FROM test WHERE d1 <= d2";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+-------+",
"| d1 |",
"+-------+",
"| three |",
"+-------+",
];
assert_batches_eq!(expected, &actual);

// comparison with a non dictionary column
let sql = "SELECT d1 FROM test WHERE d1 = d3";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+-------+",
"| d1 |",
"+-------+",
"| three |",
"+-------+",
];
assert_batches_eq!(expected, &actual);

// filtering with constant
let sql = "SELECT * FROM test WHERE d1 = 'three'";
let sql = "SELECT d1 FROM test WHERE d1 = 'three'";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+-------+",
Expand All @@ -719,6 +776,20 @@ async fn query_on_string_dictionary() -> Result<()> {
];
assert_batches_eq!(expected, &actual);

// Expression evaluation with two dictionaries
let sql = "SELECT concat(d1, d2) FROM test";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+-------------------------+",
"| concat(test.d1,test.d2) |",
"+-------------------------+",
"| oneblarg |",
"| |",
"| threethree |",
"+-------------------------+",
];
assert_batches_eq!(expected, &actual);

// aggregation
let sql = "SELECT COUNT(d1) FROM test";
let actual = execute_to_batches(&mut ctx, sql).await;
Expand Down

0 comments on commit c4db8f3

Please sign in to comment.