diff --git a/datafusion/src/physical_plan/coercion_rule/binary_rule.rs b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs index 205ee2e23ca9..e616afb11ab1 100644 --- a/datafusion/src/physical_plan/coercion_rule/binary_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs @@ -21,8 +21,8 @@ use crate::arrow::datatypes::DataType; use crate::error::{DataFusionError, Result}; use crate::logical_plan::Operator; use crate::physical_plan::expressions::coercion::{ - dictionary_coercion, eq_coercion, is_numeric, like_coercion, string_coercion, - temporal_coercion, + dictionary_coercion, eq_coercion, is_dictionary, is_numeric, like_coercion, + string_coercion, temporal_coercion, }; use crate::scalar::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128}; @@ -77,7 +77,9 @@ pub(crate) fn coerce_types( } fn comparison_eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { - if lhs_type == rhs_type { + // can't compare dictionaries directly due to + // https://github.com/apache/arrow-rs/issues/1201 + if lhs_type == rhs_type && !is_dictionary(lhs_type) { // same type => equality is possible return Some(lhs_type.clone()); } @@ -90,7 +92,9 @@ fn comparison_order_coercion( lhs_type: &DataType, rhs_type: &DataType, ) -> Option { - if lhs_type == rhs_type { + // can't compare dictionaries directly due to + // https://github.com/apache/arrow-rs/issues/1201 + if lhs_type == rhs_type && !is_dictionary(lhs_type) { // same type => all good return Some(lhs_type.clone()); } diff --git a/datafusion/src/physical_plan/expressions/coercion.rs b/datafusion/src/physical_plan/expressions/coercion.rs index a303e8b3ef3f..7bfb567cbc7b 100644 --- a/datafusion/src/physical_plan/expressions/coercion.rs +++ b/datafusion/src/physical_plan/expressions/coercion.rs @@ -147,6 +147,10 @@ pub fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option bool { + matches!(t, DataType::Dictionary(_, _)) +} + /// Coercion rule for numerical types: The type that both lhs and rhs /// can be casted to for numerical calculation, while maintaining /// maximum precision @@ -158,8 +162,10 @@ pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option all good - if lhs_type == rhs_type { + // can't compare dictionaries directly due to + // https://github.com/apache/arrow-rs/issues/1201 + if lhs_type == rhs_type && !is_dictionary(lhs_type) { + // same type => all good return Some(lhs_type.clone()); } @@ -182,7 +188,9 @@ pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { - if lhs_type == rhs_type { + // can't compare dictionaries directly due to + // https://github.com/apache/arrow-rs/issues/1201 + if lhs_type == rhs_type && !is_dictionary(lhs_type) { // same type => equality is possible return Some(lhs_type.clone()); } diff --git a/datafusion/tests/sql/select.rs b/datafusion/tests/sql/select.rs index 6264c18d3ecd..759a45c9fca9 100644 --- a/datafusion/tests/sql/select.rs +++ b/datafusion/tests/sql/select.rs @@ -655,19 +655,28 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { async fn query_on_string_dictionary() -> Result<()> { // Test to ensure DataFusion can operate on dictionary types // Use StringDictionary (32 bit indexes = keys) - let array = vec![Some("one"), None, Some("three")] + let d1: DictionaryArray = + vec![Some("one"), None, Some("three")].into_iter().collect(); + + let d2: DictionaryArray = vec![Some("blarg"), None, Some("three")] .into_iter() - .collect::>(); + .collect(); + + let d3: StringArray = vec![Some("XYZ"), None, Some("three")].into_iter().collect(); - let batch = - RecordBatch::try_from_iter(vec![("d1", Arc::new(array) as ArrayRef)]).unwrap(); + let batch = RecordBatch::try_from_iter(vec![ + ("d1", Arc::new(d1) as ArrayRef), + ("d2", Arc::new(d2) as ArrayRef), + ("d3", Arc::new(d3) as ArrayRef), + ]) + .unwrap(); let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; let mut ctx = ExecutionContext::new(); ctx.register_table("test", Arc::new(table))?; // Basic SELECT - let sql = "SELECT * FROM test"; + let sql = "SELECT d1 FROM test"; let actual = execute_to_batches(&mut ctx, sql).await; let expected = vec![ "+-------+", @@ -681,7 +690,7 @@ async fn query_on_string_dictionary() -> Result<()> { assert_batches_eq!(expected, &actual); // basic filtering - let sql = "SELECT * FROM test WHERE d1 IS NOT NULL"; + let sql = "SELECT d1 FROM test WHERE d1 IS NOT NULL"; let actual = execute_to_batches(&mut ctx, sql).await; let expected = vec![ "+-------+", @@ -693,8 +702,56 @@ async fn query_on_string_dictionary() -> Result<()> { ]; assert_batches_eq!(expected, &actual); + // comparison with constant + let sql = "SELECT d1 FROM test WHERE d1 = 'three'"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| three |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + // comparison with another dictionary column + let sql = "SELECT d1 FROM test WHERE d1 = d2"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| three |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + // order comparison with another dictionary column + let sql = "SELECT d1 FROM test WHERE d1 <= d2"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| three |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + + // comparison with a non dictionary column + let sql = "SELECT d1 FROM test WHERE d1 = d3"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------+", + "| d1 |", + "+-------+", + "| three |", + "+-------+", + ]; + assert_batches_eq!(expected, &actual); + // filtering with constant - let sql = "SELECT * FROM test WHERE d1 = 'three'"; + let sql = "SELECT d1 FROM test WHERE d1 = 'three'"; let actual = execute_to_batches(&mut ctx, sql).await; let expected = vec![ "+-------+", @@ -719,6 +776,20 @@ async fn query_on_string_dictionary() -> Result<()> { ]; assert_batches_eq!(expected, &actual); + // Expression evaluation with two dictionaries + let sql = "SELECT concat(d1, d2) FROM test"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+-------------------------+", + "| concat(test.d1,test.d2) |", + "+-------------------------+", + "| oneblarg |", + "| |", + "| threethree |", + "+-------------------------+", + ]; + assert_batches_eq!(expected, &actual); + // aggregation let sql = "SELECT COUNT(d1) FROM test"; let actual = execute_to_batches(&mut ctx, sql).await;