Skip to content

Commit

Permalink
fix: coalesce function should return correct data type (#9459)
Browse files Browse the repository at this point in the history
* fix: Remove supported coalesce types

* Use comparison_coercion

* Fix test

* Fix

* Add comment

* More

* fix
  • Loading branch information
viirya authored Mar 7, 2024
1 parent 8d58b03 commit 37b7375
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 35 deletions.
20 changes: 13 additions & 7 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ use std::sync::{Arc, OnceLock};
use crate::signature::TIMEZONE_WILDCARD;
use crate::type_coercion::binary::get_wider_type;
use crate::type_coercion::functions::data_types;
use crate::{
conditional_expressions, FuncMonotonicity, Signature, TypeSignature, Volatility,
};
use crate::{FuncMonotonicity, Signature, TypeSignature, Volatility};

use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
use datafusion_common::{exec_err, plan_err, DataFusionError, Result};
Expand Down Expand Up @@ -899,10 +897,9 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::ConcatWithSeparator => {
Signature::variadic(vec![Utf8], self.volatility())
}
BuiltinScalarFunction::Coalesce => Signature::variadic(
conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(),
self.volatility(),
),
BuiltinScalarFunction::Coalesce => {
Signature::variadic_equal(self.volatility())
}
BuiltinScalarFunction::SHA224
| BuiltinScalarFunction::SHA256
| BuiltinScalarFunction::SHA384
Expand Down Expand Up @@ -1575,4 +1572,13 @@ mod tests {
assert_eq!(func_from_str, *func_original);
}
}

#[test]
fn test_coalesce_return_types() {
let coalesce = BuiltinScalarFunction::Coalesce;
let return_type = coalesce
.return_type(&[DataType::Date32, DataType::Date32])
.unwrap();
assert_eq!(return_type, DataType::Date32);
}
}
19 changes: 0 additions & 19 deletions datafusion/expr/src/conditional_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,6 @@ use arrow::datatypes::DataType;
use datafusion_common::{plan_err, DFSchema, Result};
use std::collections::HashSet;

/// Currently supported types by the coalesce function.
/// The order of these types correspond to the order on which coercion applies
/// This should thus be from least informative to most informative
pub static SUPPORTED_COALESCE_TYPES: &[DataType] = &[
DataType::Boolean,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
DataType::Utf8,
DataType::LargeUtf8,
];

/// Helper struct for building [Expr::Case]
pub struct CaseBuilder {
expr: Option<Box<Expr>>,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ fn string_temporal_coercion(

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
/// where one both are numeric
fn comparison_binary_numeric_coercion(
pub(crate) fn comparison_binary_numeric_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
Expand Down
27 changes: 21 additions & 6 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use arrow::{
use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims};
use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result};

use super::binary::comparison_coercion;
use super::binary::{comparison_binary_numeric_coercion, comparison_coercion};

/// Performs type coercion for function arguments.
///
Expand Down Expand Up @@ -187,6 +187,10 @@ fn get_valid_types(
let new_type = current_types.iter().skip(1).try_fold(
current_types.first().unwrap().clone(),
|acc, x| {
// The coerced types found by `comparison_coercion` are not guaranteed to be
// coercible for the arguments. `comparison_coercion` returns more loose
// types that can be coerced to both `acc` and `x` for comparison purpose.
// See `maybe_data_types` for the actual coercion.
let coerced_type = comparison_coercion(&acc, x);
if let Some(coerced_type) = coerced_type {
Ok(coerced_type)
Expand Down Expand Up @@ -276,9 +280,9 @@ fn maybe_data_types(
if current_type == valid_type {
new_type.push(current_type.clone())
} else {
// attempt to coerce
if let Some(valid_type) = coerced_from(valid_type, current_type) {
new_type.push(valid_type)
// attempt to coerce.
if let Some(coerced_type) = coerced_from(valid_type, current_type) {
new_type.push(coerced_type)
} else {
// not possible
return None;
Expand Down Expand Up @@ -427,8 +431,19 @@ fn coerced_from<'a>(
Some(type_into.clone())
}

// cannot coerce
_ => None,
// More coerce rules.
// Note that not all rules in `comparison_coercion` can be reused here.
// For example, all numeric types can be coerced into Utf8 for comparison,
// but not for function arguments.
_ => comparison_binary_numeric_coercion(type_into, type_from).and_then(
|coerced_type| {
if *type_into == coerced_type {
Some(coerced_type)
} else {
None
}
},
),
}
}

Expand Down
4 changes: 3 additions & 1 deletion datafusion/physical-expr/src/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::array::{BooleanArray, Float32Array, Float64Array, Int64Array};
use arrow::datatypes::DataType;
use arrow_array::Array;
use rand::{thread_rng, Rng};

use datafusion_common::ScalarValue::{Float32, Int64};
Expand Down Expand Up @@ -92,8 +93,9 @@ macro_rules! downcast_arg {
($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
$ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
DataFusionError::Internal(format!(
"could not cast {} to {}",
"could not cast {} from {} to {}",
$NAME,
$ARG.data_type(),
type_name::<$ARRAY_TYPE>()
))
})?
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1782,7 +1782,7 @@ AS VALUES
('BB', 6, 1),
('BB', 6, 1);

query TIR
query TII
select col1, col2, coalesce(sum_col3, 0) as sum_col3
from (select distinct col2 from tbl) AS q1
cross join (select distinct col1 from tbl) AS q2
Expand Down
45 changes: 45 additions & 0 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1841,6 +1841,51 @@ SELECT COALESCE(c1 * c2, 0) FROM test
statement ok
drop table test

# coalesce date32

statement ok
CREATE TABLE test(
d1_date DATE,
d2_date DATE,
d3_date DATE
) as VALUES
('2022-12-12','2022-12-12','2022-12-12'),
(NULL,'2022-12-11','2022-12-12'),
('2022-12-12','2022-12-10','2022-12-12'),
('2022-12-12',NULL,'2022-12-12'),
('2022-12-12','2022-12-8','2022-12-12'),
('2022-12-12','2022-12-7',NULL),
('2022-12-12',NULL,'2022-12-12'),
(NULL,'2022-12-5','2022-12-12')
;

query D
SELECT COALESCE(d1_date, d2_date, d3_date) FROM test
----
2022-12-12
2022-12-11
2022-12-12
2022-12-12
2022-12-12
2022-12-12
2022-12-12
2022-12-05

query T
SELECT arrow_typeof(COALESCE(d1_date, d2_date, d3_date)) FROM test
----
Date32
Date32
Date32
Date32
Date32
Date32
Date32
Date32

statement ok
drop table test

statement ok
CREATE TABLE test(
i32 INT,
Expand Down

0 comments on commit 37b7375

Please sign in to comment.