ARROW-7364: [Rust][DataFusion] Add cast options to cast kernel and TRY_CAST to DataFusion

@andygrove @alamb @nevi-me

This is my WIP implementation of adding `CastOptions` to the Rust cast kernel and changing the default `CastOptions` for DataFusion to `safe = false`.
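
To make the kernel change concrete, here is a minimal sketch (not part of this diff) of the two modes. The import paths mirror the ones used in `cast.rs` below; the sample values are just illustrative assumptions.

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int32Array, StringArray};
use arrow::compute::kernels::cast::cast_with_options;
use arrow::compute::CastOptions;
use arrow::datatypes::DataType;

fn main() -> arrow::error::Result<()> {
    // "9.1" cannot be parsed as an Int32
    let input: ArrayRef = Arc::new(StringArray::from(vec!["1", "9.1"]));

    // safe = true: values that fail to cast become null
    let lenient = cast_with_options(&input, &DataType::Int32, &CastOptions { safe: true })?;
    let lenient_values = lenient.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(lenient_values.value(0), 1);
    assert!(lenient_values.is_null(1));

    // safe = false (the new DataFusion default): the same cast is a runtime error
    let strict = cast_with_options(&input, &DataType::Int32, &CastOptions { safe: false });
    assert!(strict.is_err());

    Ok(())
}
```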

From here we have two options:
1. use some sort of feature flag (like Spark) to set the default (see `spark.sql.ansi.enabled` [here](https://spark.apache.org/docs/latest/configuration.html#runtime-sql-configuration)).
2. add a `safe_cast` `expr` to do the same operation but override the default (a rough sketch of the expression-level behaviour follows this list).
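
As a rough sketch of option 2 with what this diff actually wires up (a `try_cast` physical expression next to the stricter default `cast`): the `datafusion::` paths are my assumptions for use outside the crate, and the behaviour mirrors the new tests in `cast.rs` below.

```rust
use std::sync::Arc;

use arrow::array::{Array, Int32Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::physical_plan::expressions::{cast, col, try_cast};
use datafusion::physical_plan::{ColumnarValue, PhysicalExpr};

fn main() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
    let a = StringArray::from(vec!["1", "9.1"]);
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

    // CAST with the new DataFusion default (safe = false) fails at evaluation time
    let strict = cast(col("a"), &schema, DataType::Int32)?;
    assert!(strict.evaluate(&batch).is_err());

    // TRY_CAST turns the value that cannot be cast into a null instead
    let lenient = try_cast(col("a"), &schema, DataType::Int32)?;
    if let ColumnarValue::Array(array) = lenient.evaluate(&batch)? {
        let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(array.value(0), 1);
        assert!(array.is_null(1));
    }

    Ok(())
}
```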

Thoughts?

Closes #9682 from seddonm1/safe-cast

Lead-authored-by: Mike Seddon <seddonm1@gmail.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
Signed-off-by: Andrew Lamb <andrew@nerdnetworks.org>
seddonm1 and alamb committed Apr 6, 2021
1 parent e2c22a1 commit 81f6521
Showing 14 changed files with 910 additions and 210 deletions.
635 changes: 464 additions & 171 deletions rust/arrow/src/compute/kernels/cast.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion rust/datafusion/Cargo.toml
@@ -52,7 +52,7 @@ ahash = "0.7"
hashbrown = "0.11"
arrow = { path = "../arrow", version = "4.0.0-SNAPSHOT", features = ["prettyprint"] }
parquet = { path = "../parquet", version = "4.0.0-SNAPSHOT", features = ["arrow"] }
sqlparser = "0.8.0"
sqlparser = "0.9.0"
clap = "2.33"
rustyline = {version = "7.0", optional = true}
paste = "^1.0"
2 changes: 2 additions & 0 deletions rust/datafusion/README.md
@@ -160,6 +160,8 @@ DataFusion also includes a simple command-line interactive SQL utility. See the
- [x] Limit
- [x] Aggregate
- [x] Common math functions
- [x] cast
- [x] try_cast
- Postgres compatible String functions
- [x] ascii
- [x] bit_length
25 changes: 24 additions & 1 deletion rust/datafusion/src/logical_plan/expr.rs
@@ -142,13 +142,22 @@ pub enum Expr {
/// Optional "else" expression
else_expr: Option<Box<Expr>>,
},
/// Casts the expression to a given type. This expression is guaranteed to have a fixed type.
/// Casts the expression to a given type and will return a runtime error if the expression cannot be cast.
/// This expression is guaranteed to have a fixed type.
Cast {
/// The expression being cast
expr: Box<Expr>,
/// The `DataType` the expression will yield
data_type: DataType,
},
/// Casts the expression to a given type and will return a null value if the expression cannot be cast.
/// This expression is guaranteed to have a fixed type.
TryCast {
/// The expression being cast
expr: Box<Expr>,
/// The `DataType` the expression will yield
data_type: DataType,
},
/// A sort expression, that can be used to sort values.
Sort {
/// The expression to sort on
@@ -220,6 +229,7 @@ impl Expr {
Expr::Literal(l) => Ok(l.get_datatype()),
Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
Expr::Cast { data_type, .. } => Ok(data_type.clone()),
Expr::TryCast { data_type, .. } => Ok(data_type.clone()),
Expr::ScalarUDF { fun, args } => {
let data_types = args
.iter()
@@ -303,6 +313,7 @@ impl Expr {
}
}
Expr::Cast { expr, .. } => expr.nullable(input_schema),
Expr::TryCast { .. } => Ok(true),
Expr::ScalarFunction { .. } => Ok(true),
Expr::ScalarUDF { .. } => Ok(true),
Expr::AggregateFunction { .. } => Ok(true),
@@ -552,6 +563,7 @@ impl Expr {
}
}
Expr::Cast { expr, .. } => expr.accept(visitor),
Expr::TryCast { expr, .. } => expr.accept(visitor),
Expr::Sort { expr, .. } => expr.accept(visitor),
Expr::ScalarFunction { args, .. } => args
.iter()
@@ -671,6 +683,10 @@ impl Expr {
expr: rewrite_boxed(expr, rewriter)?,
data_type,
},
Expr::TryCast { expr, data_type } => Expr::TryCast {
expr: rewrite_boxed(expr, rewriter)?,
data_type,
},
Expr::Sort {
expr,
asc,
@@ -1197,6 +1213,9 @@ impl fmt::Debug for Expr {
Expr::Cast { expr, data_type } => {
write!(f, "CAST({:?} AS {:?})", expr, data_type)
}
Expr::TryCast { expr, data_type } => {
write!(f, "TRY_CAST({:?} AS {:?})", expr, data_type)
}
Expr::Not(expr) => write!(f, "NOT {:?}", expr),
Expr::Negative(expr) => write!(f, "(- {:?})", expr),
Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr),
@@ -1315,6 +1334,10 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
let expr = create_name(expr, input_schema)?;
Ok(format!("CAST({} AS {:?})", expr, data_type))
}
Expr::TryCast { expr, data_type } => {
let expr = create_name(expr, input_schema)?;
Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
}
Expr::Not(expr) => {
let expr = create_name(expr, input_schema)?;
Ok(format!("NOT {}", expr))
6 changes: 6 additions & 0 deletions rust/datafusion/src/optimizer/utils.rs
@@ -73,6 +73,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
Expr::Between { .. } => {}
Expr::Case { .. } => {}
Expr::Cast { .. } => {}
Expr::TryCast { .. } => {}
Expr::Sort { .. } => {}
Expr::ScalarFunction { .. } => {}
Expr::ScalarUDF { .. } => {}
@@ -261,6 +262,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
Ok(expr_list)
}
Expr::Cast { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
Expr::TryCast { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
Expr::Column(_) => Ok(vec![]),
Expr::Alias(expr, ..) => Ok(vec![expr.as_ref().to_owned()]),
Expr::Literal(_) => Ok(vec![]),
@@ -357,6 +359,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
expr: Box::new(expressions[0].clone()),
data_type: data_type.clone(),
}),
Expr::TryCast { data_type, .. } => Ok(Expr::TryCast {
expr: Box::new(expressions[0].clone()),
data_type: data_type.clone(),
}),
Expr::Alias(_, alias) => {
Ok(Expr::Alias(Box::new(expressions[0].clone()), alias.clone()))
}
6 changes: 3 additions & 3 deletions rust/datafusion/src/physical_plan/expressions/binary.rs
@@ -39,7 +39,7 @@ use arrow::record_batch::RecordBatch;

use crate::error::{DataFusionError, Result};
use crate::logical_plan::Operator;
use crate::physical_plan::expressions::cast;
use crate::physical_plan::expressions::try_cast;
use crate::physical_plan::{ColumnarValue, PhysicalExpr};
use crate::scalar::ScalarValue;

@@ -547,8 +547,8 @@ fn binary_cast(
let cast_type = common_binary_type(lhs_type, op, rhs_type)?;

Ok((
cast(lhs, input_schema, cast_type.clone())?,
cast(rhs, input_schema, cast_type)?,
try_cast(lhs, input_schema, cast_type.clone())?,
try_cast(rhs, input_schema, cast_type)?,
))
}

116 changes: 97 additions & 19 deletions rust/datafusion/src/physical_plan/expressions/cast.rs
@@ -25,23 +25,37 @@ use crate::physical_plan::PhysicalExpr;
use crate::scalar::ScalarValue;
use arrow::compute;
use arrow::compute::kernels;
use arrow::compute::CastOptions;
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use compute::can_cast_types;

/// CAST expression casts an expression to a specific data type
/// provide Datafusion default cast options
pub const DEFAULT_DATAFUSION_CAST_OPTIONS: CastOptions = CastOptions { safe: false };

/// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast
#[derive(Debug)]
pub struct CastExpr {
/// The expression to cast
expr: Arc<dyn PhysicalExpr>,
/// The data type to cast to
cast_type: DataType,
/// Cast options
cast_options: CastOptions,
}

impl CastExpr {
/// Create a new CastExpr
pub fn new(expr: Arc<dyn PhysicalExpr>, cast_type: DataType) -> Self {
Self { expr, cast_type }
pub fn new(
expr: Arc<dyn PhysicalExpr>,
cast_type: DataType,
cast_options: CastOptions,
) -> Self {
Self {
expr,
cast_type,
cast_options,
}
}

/// The expression to cast
@@ -78,13 +92,20 @@ impl PhysicalExpr for CastExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let value = self.expr.evaluate(batch)?;
match value {
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(kernels::cast::cast(
&array,
&self.cast_type,
)?)),
ColumnarValue::Array(array) => {
Ok(ColumnarValue::Array(kernels::cast::cast_with_options(
&array,
&self.cast_type,
&self.cast_options,
)?))
}
ColumnarValue::Scalar(scalar) => {
let scalar_array = scalar.to_array();
let cast_array = kernels::cast::cast(&scalar_array, &self.cast_type)?;
let cast_array = kernels::cast::cast_with_options(
&scalar_array,
&self.cast_type,
&self.cast_options,
)?;
let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?;
Ok(ColumnarValue::Scalar(cast_scalar))
}
@@ -96,16 +117,17 @@ impl PhysicalExpr for CastExpr {
/// `cast_type`, if any casting is needed.
///
/// Note that such casts may lose type information
pub fn cast(
pub fn cast_with_options(
expr: Arc<dyn PhysicalExpr>,
input_schema: &Schema,
cast_type: DataType,
cast_options: CastOptions,
) -> Result<Arc<dyn PhysicalExpr>> {
let expr_type = expr.data_type(input_schema)?;
if expr_type == cast_type {
Ok(expr.clone())
} else if can_cast_types(&expr_type, &cast_type) {
Ok(Arc::new(CastExpr::new(expr, cast_type)))
Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
} else {
Err(DataFusionError::Internal(format!(
"Unsupported CAST from {:?} to {:?}",
@@ -114,14 +136,31 @@ pub fn cast(
}
}

/// Return a PhysicalExpression representing `expr` casted to
/// `cast_type`, if any casting is needed.
///
/// Note that such casts may lose type information
pub fn cast(
expr: Arc<dyn PhysicalExpr>,
input_schema: &Schema,
cast_type: DataType,
) -> Result<Arc<dyn PhysicalExpr>> {
cast_with_options(
expr,
input_schema,
cast_type,
DEFAULT_DATAFUSION_CAST_OPTIONS,
)
}

#[cfg(test)]
mod tests {
use super::*;
use crate::error::Result;
use crate::physical_plan::expressions::col;
use arrow::array::{StringArray, Time64NanosecondArray};
use arrow::{
array::{Int32Array, Int64Array, TimestampNanosecondArray, UInt32Array},
array::{Array, Int32Array, Int64Array, TimestampNanosecondArray, UInt32Array},
datatypes::*,
};

@@ -132,14 +171,14 @@ mod tests {
// 4. verify that the resulting expression is of type B
// 5. verify that the resulting values are downcastable and correct
macro_rules! generic_test_cast {
($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{
($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $CAST_OPTIONS:expr) => {{
let schema = Schema::new(vec![Field::new("a", $A_TYPE, false)]);
let a = $A_ARRAY::from($A_VEC);
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

// verify that we can construct the expression
let expression = cast(col("a"), &schema, $TYPE)?;
let expression = cast_with_options(col("a"), &schema, $TYPE, $CAST_OPTIONS)?;

// verify that its display is correct
assert_eq!(format!("CAST(a AS {:?})", $TYPE), format!("{}", expression));
@@ -164,7 +203,10 @@

// verify that the result itself is correct
for (i, x) in $VEC.iter().enumerate() {
assert_eq!(result.value(i), *x);
match x {
Some(x) => assert_eq!(result.value(i), *x),
None => assert!(!result.is_valid(i)),
}
}
}};
}
@@ -177,7 +219,14 @@
vec![1, 2, 3, 4, 5],
UInt32Array,
DataType::UInt32,
vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]
vec![
Some(1_u32),
Some(2_u32),
Some(3_u32),
Some(4_u32),
Some(5_u32)
],
DEFAULT_DATAFUSION_CAST_OPTIONS
);
Ok(())
}
@@ -190,25 +239,28 @@
vec![1, 2, 3, 4, 5],
StringArray,
DataType::Utf8,
vec!["1", "2", "3", "4", "5"]
vec![Some("1"), Some("2"), Some("3"), Some("4"), Some("5")],
DEFAULT_DATAFUSION_CAST_OPTIONS
);
Ok(())
}

#[allow(clippy::redundant_clone)]
#[test]
fn test_cast_i64_t64() -> Result<()> {
let original = vec![1, 2, 3, 4, 5];
let expected: Vec<i64> = original
let expected: Vec<Option<i64>> = original
.iter()
.map(|i| Time64NanosecondArray::from(vec![*i]).value(0))
.map(|i| Some(Time64NanosecondArray::from(vec![*i]).value(0)))
.collect();
generic_test_cast!(
Int64Array,
DataType::Int64,
original.clone(),
TimestampNanosecondArray,
DataType::Timestamp(TimeUnit::Nanosecond, None),
expected
expected,
DEFAULT_DATAFUSION_CAST_OPTIONS
);
Ok(())
}
@@ -217,7 +269,33 @@
fn invalid_cast() {
// Ensure a useful error happens at plan time if invalid casts are used
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

let result = cast(col("a"), &schema, DataType::LargeBinary);
result.expect_err("expected Invalid CAST");
}

#[test]
fn invalid_cast_with_options_error() -> Result<()> {
// Ensure a useful error happens at plan time if invalid casts are used
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
let a = StringArray::from(vec!["9.1"]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
let expression = cast_with_options(
col("a"),
&schema,
DataType::Int32,
DEFAULT_DATAFUSION_CAST_OPTIONS,
)?;
let result = expression.evaluate(&batch);

match result {
Ok(_) => panic!("expected error"),
Err(e) => {
assert!(e.to_string().contains(
"Cast error: Cannot cast string '9.1' to value of arrow::datatypes::types::Int32Type type"
))
}
}
Ok(())
}
}
4 changes: 3 additions & 1 deletion rust/datafusion/src/physical_plan/expressions/mod.rs
@@ -42,11 +42,12 @@ mod negative;
mod not;
mod nullif;
mod sum;
mod try_cast;

pub use average::{avg_return_type, Avg, AvgAccumulator};
pub use binary::{binary, binary_operator_data_type, BinaryExpr};
pub use case::{case, CaseExpr};
pub use cast::{cast, CastExpr};
pub use cast::{cast, cast_with_options, CastExpr};
pub use column::{col, Column};
pub use count::Count;
pub use in_list::{in_list, InListExpr};
Expand All @@ -58,6 +59,7 @@ pub use negative::{negative, NegativeExpr};
pub use not::{not, NotExpr};
pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
pub use sum::{sum_return_type, Sum};
pub use try_cast::{try_cast, TryCastExpr};
/// returns the name of the state
pub fn format_state_name(name: &str, state_name: &str) -> String {
format!("{}[{}]", name, state_name)
(Diffs for the remaining changed files are not shown.)