Skip to content

Commit

Permalink
support decimal data type for the optimizer rule of PreCastLitInCompa…
Browse files Browse the repository at this point in the history
…risonExpressions (#3245)

* support decimal for the PreCastLitInComparisonExpressions rule

* address comments

* fix the lint

* add comments
  • Loading branch information
liukun4515 authored Aug 27, 2022
1 parent 90a0e7c commit b1db5ff
Showing 1 changed file with 158 additions and 79 deletions.
237 changes: 158 additions & 79 deletions datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
//! Pre-cast literal binary comparison rule can be only used to the binary comparison expr.
//! It can reduce adding the `Expr::Cast` to the expr instead of adding the `Expr::Cast` to literal expr.
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use arrow::datatypes::{
DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::utils::from_plan;
Expand Down Expand Up @@ -99,7 +101,6 @@ impl ExprRewriter for PreCastLitExprRewriter {
}

fn mutate(&mut self, expr: Expr) -> Result<Expr> {
// traverse the expr by dfs
match &expr {
Expr::BinaryExpr { left, op, right } => {
let left = left.as_ref().clone();
Expand All @@ -121,32 +122,19 @@ impl ExprRewriter for PreCastLitExprRewriter {
(Expr::Literal(_), Expr::Literal(_)) => {
// do nothing
}
(Expr::Literal(left_lit_value), _)
if can_integer_literal_cast_to_type(
left_lit_value,
&right_type,
)? =>
{
// cast the left literal to the right type
return Ok(binary_expr(
cast_to_other_scalar_expr(left_lit_value, &right_type)?,
*op,
right,
));
(Expr::Literal(left_lit_value), _) => {
let casted_scalar_value =
try_cast_literal_to_type(left_lit_value, &right_type)?;
if let Some(value) = casted_scalar_value {
return Ok(binary_expr(lit(value), *op, right));
}
}
(_, Expr::Literal(right_lit_value))
if can_integer_literal_cast_to_type(
right_lit_value,
&left_type,
)
.unwrap() =>
{
// cast the right literal to the left type
return Ok(binary_expr(
left,
*op,
cast_to_other_scalar_expr(right_lit_value, &left_type)?,
));
(_, Expr::Literal(right_lit_value)) => {
let casted_scalar_value =
try_cast_literal_to_type(right_lit_value, &left_type)?;
if let Some(value) = casted_scalar_value {
return Ok(binary_expr(left, *op, lit(value)));
}
}
(_, _) => {
// do nothing
Expand All @@ -164,43 +152,6 @@ impl ExprRewriter for PreCastLitExprRewriter {
}
}

fn cast_to_other_scalar_expr(
origin_value: &ScalarValue,
target_type: &DataType,
) -> Result<Expr> {
// null case
if origin_value.is_null() {
// if the origin value is null, just convert to another type of null value
// The target type must be satisfied `is_support_data_type` method, we can unwrap safely
return Ok(lit(ScalarValue::try_from(target_type).unwrap()));
}
// no null case
let value: i64 = match origin_value {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
ScalarValue::Int32(Some(v)) => *v as i64,
ScalarValue::Int64(Some(v)) => *v as i64,
other_value => {
return Err(DataFusionError::Internal(format!(
"Invalid type and value {}",
other_value
)))
}
};
Ok(lit(match target_type {
DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
DataType::Int64 => ScalarValue::Int64(Some(value)),
other_type => {
return Err(DataFusionError::Internal(format!(
"Invalid target data type {:?}",
other_type
)))
}
}))
}

fn is_comparison_op(op: &Operator) -> bool {
matches!(
op,
Expand All @@ -214,47 +165,112 @@ fn is_comparison_op(op: &Operator) -> bool {
}

fn is_support_data_type(data_type: &DataType) -> bool {
// TODO support decimal with other data type
matches!(
data_type,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Decimal128(_, _)
)
}

fn can_integer_literal_cast_to_type(
integer_lit_value: &ScalarValue,
fn try_cast_literal_to_type(
lit_value: &ScalarValue,
target_type: &DataType,
) -> Result<bool> {
if integer_lit_value.is_null() {
) -> Result<Option<ScalarValue>> {
if lit_value.is_null() {
// null value can be cast to any type of null value
return Ok(true);
return Ok(Some(ScalarValue::try_from(target_type)?));
}
let mul = match target_type {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1_i128,
DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
other_type => {
return Err(DataFusionError::Internal(format!(
"Error target data type {:?}",
other_type
)));
}
};
let (target_min, target_max) = match target_type {
DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
DataType::Decimal128(precision, _) => (
// Different precision for decimal128 can store different range of value.
// For example, the precision is 3, the max of value is `999` and the min
// value is `-999`
MIN_DECIMAL_FOR_EACH_PRECISION[*precision - 1],
MAX_DECIMAL_FOR_EACH_PRECISION[*precision - 1],
),
other_type => {
return Err(DataFusionError::Internal(format!(
"Error target data type {:?}",
other_type
)))
)));
}
};
let lit_value = match integer_lit_value {
ScalarValue::Int8(Some(v)) => *v as i128,
ScalarValue::Int16(Some(v)) => *v as i128,
ScalarValue::Int32(Some(v)) => *v as i128,
ScalarValue::Int64(Some(v)) => *v as i128,
let lit_value_target_type = match lit_value {
ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Decimal128(Some(v), _, scale) => {
let lit_scale_mul = 10_i128.pow(*scale as u32);
if mul >= lit_scale_mul {
// Example:
// lit is decimal(123,3,2)
// target type is decimal(5,3)
// the lit can be converted to the decimal(1230,5,3)
(*v).checked_mul(mul / lit_scale_mul)
} else if (*v) % (lit_scale_mul / mul) == 0 {
// Example:
// lit is decimal(123000,10,3)
// target type is int32: the lit can be converted to INT32(123)
// target type is decimal(10,2): the lit can be converted to decimal(12300,10,2)
Some(*v / (lit_scale_mul / mul))
} else {
// can't convert the lit decimal to the target data type
None
}
}
other_value => {
return Err(DataFusionError::Internal(format!(
"Invalid literal value {:?}",
other_value
)))
)));
}
};

Ok(lit_value >= target_min && lit_value <= target_max)
match lit_value_target_type {
None => Ok(None),
Some(value) => {
if value >= target_min && value <= target_max {
// the value casted from lit to the target type is in the range of target type.
// return the target type of scalar value
let result_scalar = match target_type {
DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
DataType::Decimal128(p, s) => {
ScalarValue::Decimal128(Some(value), *p, *s)
}
other_type => {
return Err(DataFusionError::Internal(format!(
"Error target data type {:?}",
other_type
)));
}
};
Ok(Some(result_scalar))
} else {
Ok(None)
}
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -307,6 +323,67 @@ mod tests {
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
}

#[test]
fn test_not_cast_with_decimal_lit_comparison() {
let schema = expr_test_schema();
// integer to decimal: value is out of the bounds of the decimal
// c3 = INT64(100000000000000000)
let expr_eq = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000))));
let expected = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000))));
assert_eq!(optimize_test(expr_eq, &schema), expected);
// c4 = INT64(1000) will overflow the i128
let expr_eq = col("c4").eq(lit(ScalarValue::Int64(Some(1000))));
let expected = col("c4").eq(lit(ScalarValue::Int64(Some(1000))));
assert_eq!(optimize_test(expr_eq, &schema), expected);

// decimal to decimal: value will lose the scale when convert to the target data type
// c3 = DECIMAL(12340,20,4)
let expr_eq = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4)));
let expected = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4)));
assert_eq!(optimize_test(expr_eq, &schema), expected);

// decimal to integer
// c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type
let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1)));
let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1)));
assert_eq!(optimize_test(expr_eq, &schema), expected);
// c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type
let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2)));
let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2)));
assert_eq!(optimize_test(expr_eq, &schema), expected);
}

#[test]
fn test_pre_cast_with_decimal_lit_comparison() {
let schema = expr_test_schema();
// integer to decimal
// c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2));
let expr_lt = col("c3").lt(lit(ScalarValue::Int64(Some(16))));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600), 18, 2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);

// c3 < INT64(NULL)
let c1_lt_lit_null = col("c3").lt(lit(ScalarValue::Int64(None)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 2)));
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);

// decimal to decimal
// c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2)
let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 10, 0)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300), 18, 2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);
// c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2)
let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18, 2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);

// decimal to integer
// c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123)
let expr_lt = col("c1").lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2)));
let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123))));
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

#[test]
fn aliased() {
let schema = expr_test_schema();
Expand Down Expand Up @@ -344,6 +421,8 @@ mod tests {
vec![
DFField::new(None, "c1", DataType::Int32, false),
DFField::new(None, "c2", DataType::Int64, false),
DFField::new(None, "c3", DataType::Decimal128(18, 2), false),
DFField::new(None, "c4", DataType::Decimal128(38, 37), false),
],
HashMap::new(),
)
Expand Down

0 comments on commit b1db5ff

Please sign in to comment.