feat: Support ANSI mode in CAST from String to Bool #290

Merged · 19 commits · Apr 22, 2024
Changes from all commits
8 changes: 8 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -357,6 +357,14 @@ object CometConf {
.toSequence
.createWithDefault(Seq("Range,InMemoryTableScan"))

val COMET_ANSI_MODE_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.ansi.enabled")
.doc(
"Comet does not respect ANSI mode in most cases and by default will not accelerate " +
"queries when ansi mode is enabled. Enable this setting to test Comet's experimental " +
"support for ANSI mode. This should not be used in production.")
.booleanConf
.createWithDefault(false)

}

object ConfigHelpers {
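For context, a minimal usage sketch (not part of this diff): Spark's own ANSI flag and this new Comet flag must both be set before Comet will attempt to accelerate ANSI queries. The session setup below is illustrative and assumes the Comet extension is already on the classpath.

import org.apache.spark.sql.SparkSession

// Hedged sketch: opt in to Comet's experimental ANSI support.
val spark = SparkSession.builder()
  .appName("comet-ansi-demo")
  .config("spark.sql.ansi.enabled", "true")   // Spark's built-in ANSI mode
  .config("spark.comet.ansi.enabled", "true") // the flag added in this PR
  .getOrCreate()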
16 changes: 16 additions & 0 deletions core/src/errors.rs
@@ -60,6 +60,18 @@ pub enum CometError {
#[error("Comet Internal Error: {0}")]
Internal(String),

// Note that this message format is based on Spark 3.4 and is more detailed than the message
// returned by Spark 3.2 or 3.3
#[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
because it is malformed. Correct the value as per the syntax, or change its target type. \
Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
CastInvalidValue {
value: String,
from_type: String,
to_type: String,
},

#[error(transparent)]
Arrow {
#[from]
@@ -183,6 +195,10 @@ impl jni::errors::ToException for CometError {
class: "java/lang/NullPointerException".to_string(),
msg: self.to_string(),
},
CometError::CastInvalidValue { .. } => Exception {
class: "org/apache/spark/SparkException".to_string(),
msg: self.to_string(),
},
CometError::NumberIntFormat { source: s } => Exception {
class: "java/lang/NumberFormatException".to_string(),
msg: s.to_string(),
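On the JVM side, CastInvalidValue surfaces as an org.apache.spark.SparkException via the ToException mapping above, so a query that hits it fails rather than silently returning NULL. A hedged sketch of what a caller would observe, reusing the illustrative session from the previous example:

import org.apache.spark.SparkException

try {
  spark.sql("SELECT CAST('invalid' AS BOOLEAN)").collect()
} catch {
  case e: SparkException =>
    // Expected to include:
    // [CAST_INVALID_INPUT] The value 'invalid' of the type "STRING" cannot be
    // cast to "BOOLEAN" because it is malformed. ...
    println(e.getMessage)
}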
74 changes: 55 additions & 19 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -22,6 +22,7 @@ use std::{
sync::Arc,
};

use crate::errors::{CometError, CometResult};
use arrow::{
compute::{cast_with_options, CastOptions},
record_batch::RecordBatch,
@@ -30,7 +31,7 @@ use arrow::{
use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
- use datafusion_common::{Result as DataFusionResult, ScalarValue};
+ use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;

use crate::execution::datafusion::expressions::utils::{
@@ -45,30 +46,49 @@ static CAST_OPTIONS: CastOptions = CastOptions {
.with_timestamp_format(TIMESTAMP_FORMAT),
};

#[derive(Debug, Hash, PartialEq, Clone, Copy)]
pub enum EvalMode {
Legacy,
Ansi,
Try,
}

#[derive(Debug, Hash)]
pub struct Cast {
pub child: Arc<dyn PhysicalExpr>,
pub data_type: DataType,
pub eval_mode: EvalMode,

/// When casting from/to timezone-related types we need a timezone, which will be
/// resolved with the session-local timezone by Spark's analyzer.
pub timezone: String,
}

impl Cast {
- pub fn new(child: Arc<dyn PhysicalExpr>, data_type: DataType, timezone: String) -> Self {
+ pub fn new(
+ child: Arc<dyn PhysicalExpr>,
+ data_type: DataType,
+ eval_mode: EvalMode,
+ timezone: String,
+ ) -> Self {
Self {
child,
data_type,
timezone,
eval_mode,
}
}

- pub fn new_without_timezone(child: Arc<dyn PhysicalExpr>, data_type: DataType) -> Self {
+ pub fn new_without_timezone(
+ child: Arc<dyn PhysicalExpr>,
+ data_type: DataType,
+ eval_mode: EvalMode,
+ ) -> Self {
Self {
child,
data_type,
timezone: "".to_string(),
eval_mode,
}
}

@@ -77,17 +97,22 @@ impl Cast {
let array = array_with_timezone(array, self.timezone.clone(), Some(to_type));
let from_type = array.data_type();
let cast_result = match (from_type, to_type) {
- (DataType::Utf8, DataType::Boolean) => Self::spark_cast_utf8_to_boolean::<i32>(&array),
+ (DataType::Utf8, DataType::Boolean) => {
+ Self::spark_cast_utf8_to_boolean::<i32>(&array, self.eval_mode)?
+ }
(DataType::LargeUtf8, DataType::Boolean) => {
- Self::spark_cast_utf8_to_boolean::<i64>(&array)
+ Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
}
_ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
};
let result = spark_cast(cast_result, from_type, to_type);
Ok(result)
}

- fn spark_cast_utf8_to_boolean<OffsetSize>(from: &dyn Array) -> ArrayRef
+ fn spark_cast_utf8_to_boolean<OffsetSize>(
+ from: &dyn Array,
+ eval_mode: EvalMode,
+ ) -> CometResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
{
@@ -100,24 +125,29 @@ impl Cast {
.iter()
.map(|value| match value {
Some(value) => match value.to_ascii_lowercase().trim() {
"t" | "true" | "y" | "yes" | "1" => Some(true),
"f" | "false" | "n" | "no" | "0" => Some(false),
_ => None,
"t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
"f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
_ if eval_mode == EvalMode::Ansi => Err(CometError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "BOOLEAN".to_string(),
}),
_ => Ok(None),
},
- _ => None,
+ _ => Ok(None),
})
- .collect::<BooleanArray>();
+ .collect::<Result<BooleanArray, _>>()?;

- Arc::new(output_array)
+ Ok(Arc::new(output_array))
}
}

impl Display for Cast {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Cast [data_type: {}, timezone: {}, child: {}]",
self.data_type, self.timezone, self.child
"Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]",
self.data_type, self.timezone, self.child, &self.eval_mode
)
}
}
@@ -130,6 +160,7 @@ impl PartialEq<dyn Any> for Cast {
self.child.eq(&x.child)
&& self.timezone.eq(&x.timezone)
&& self.data_type.eq(&x.data_type)
&& self.eval_mode.eq(&x.eval_mode)
})
.unwrap_or(false)
}
@@ -171,18 +202,23 @@ impl PhysicalExpr for Cast {
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
- Ok(Arc::new(Cast::new(
- children[0].clone(),
- self.data_type.clone(),
- self.timezone.clone(),
- )))
+ match children.len() {
+ 1 => Ok(Arc::new(Cast::new(
+ children[0].clone(),
+ self.data_type.clone(),
+ self.eval_mode,
+ self.timezone.clone(),
+ ))),
+ _ => internal_err!("Cast should have exactly one child"),
+ }
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.child.hash(&mut s);
self.data_type.hash(&mut s);
self.timezone.hash(&mut s);
self.eval_mode.hash(&mut s);
self.hash(&mut s);
}
}
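The token sets above match Spark's string-to-boolean cast semantics: trim whitespace, lowercase, then compare against the true/false token lists. A rough Scala reference of the same per-value logic, for illustration only (the exception type is a stand-in for the CastInvalidValue error defined in errors.rs):

// Hedged reference of the kernel's per-value behavior.
def castStringToBoolean(value: String, ansiMode: Boolean): Option[Boolean] =
  value.trim.toLowerCase match {
    case "t" | "true" | "y" | "yes" | "1" => Some(true)
    case "f" | "false" | "n" | "no" | "0" => Some(false)
    case _ if ansiMode =>
      throw new IllegalArgumentException( // stand-in exception type
        s"""[CAST_INVALID_INPUT] The value '$value' of the type "STRING" """ +
          """cannot be cast to "BOOLEAN" because it is malformed.""")
    case _ => None // LEGACY and TRY: malformed input becomes NULL
  }

assert(castStringToBoolean(" Yes ", ansiMode = false).contains(true))
assert(castStringToBoolean("nope", ansiMode = false).isEmpty)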
22 changes: 19 additions & 3 deletions core/src/execution/datafusion/planner.rs
@@ -65,7 +65,7 @@ use crate::{
avg_decimal::AvgDecimal,
bitwise_not::BitwiseNotExpr,
bloom_filter_might_contain::BloomFilterMightContain,
- cast::Cast,
+ cast::{Cast, EvalMode},
checkoverflow::CheckOverflow,
if_expr::IfExpr,
scalar_funcs::create_comet_physical_fun,
@@ -343,7 +343,17 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let timezone = expr.timezone.clone();
- Ok(Arc::new(Cast::new(child, datatype, timezone)))
+ let eval_mode = match expr.eval_mode.as_str() {
+ "ANSI" => EvalMode::Ansi,
+ "TRY" => EvalMode::Try,
+ "LEGACY" => EvalMode::Legacy,
+ other => {
+ return Err(ExecutionError::GeneralError(format!(
+ "Invalid Cast EvalMode: \"{other}\""
+ )))
+ }
+ };
+ Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone)))
}
ExprStruct::Hour(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
@@ -638,13 +648,19 @@
let left = Arc::new(Cast::new_without_timezone(
left,
DataType::Decimal256(p1, s1),
EvalMode::Legacy,
));
let right = Arc::new(Cast::new_without_timezone(
right,
DataType::Decimal256(p2, s2),
EvalMode::Legacy,
));
let child = Arc::new(BinaryExpr::new(left, op, right));
- Ok(Arc::new(Cast::new_without_timezone(child, data_type)))
+ Ok(Arc::new(Cast::new_without_timezone(
+ child,
+ data_type,
+ EvalMode::Legacy,
+ )))
}
(
DataFusionOperator::Divide,
2 changes: 2 additions & 0 deletions core/src/execution/proto/expr.proto
@@ -208,6 +208,8 @@ message Cast {
Expr child = 1;
DataType datatype = 2;
string timezone = 3;
// LEGACY, ANSI, or TRY
string eval_mode = 4;
}

message Equal {
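Because the new field is a protobuf string rather than an enum, the Scala serde layer must send exactly one of these three values; the Rust planner (see planner.rs above) rejects anything else with a GeneralError. A minimal sketch of populating the field through the generated builder (child and datatype setters omitted for brevity):

// Sketch: the generated builder exposes the new field as setEvalMode.
val castBuilder = ExprOuterClass.Cast.newBuilder()
castBuilder.setEvalMode("ANSI") // must be exactly "LEGACY", "ANSI", or "TRY"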
@@ -568,8 +568,12 @@ class CometSparkSessionExtensions
// DataFusion doesn't have ANSI mode. For now we just disable CometExec if ANSI mode is
// enabled.
if (isANSIEnabled(conf)) {
- logInfo("Comet extension disabled for ANSI mode")
- return plan
+ if (COMET_ANSI_MODE_ENABLED.get()) {
+ logWarning("Using Comet's experimental support for ANSI mode.")
+ } else {
+ logInfo("Comet extension disabled for ANSI mode")
+ return plan
+ }
}

// We shouldn't transform Spark query plan if Comet is disabled.
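The gating rule distills to a single predicate: when Spark's ANSI mode is on, Comet participates only if the experimental flag is also on. A sketch (the helper name is illustrative, not an actual method in this file):

// Hedged distillation of the branch above.
def cometApplies(ansiEnabled: Boolean, cometAnsiEnabled: Boolean): Boolean =
  !ansiEnabled || cometAnsiEnabled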
18 changes: 14 additions & 4 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -414,13 +414,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
def castToProto(
timeZoneId: Option[String],
dt: DataType,
- childExpr: Option[Expr]): Option[Expr] = {
+ childExpr: Option[Expr],
+ evalMode: String): Option[Expr] = {
val dataType = serializeDataType(dt)

if (childExpr.isDefined && dataType.isDefined) {
val castBuilder = ExprOuterClass.Cast.newBuilder()
castBuilder.setChild(childExpr.get)
castBuilder.setDatatype(dataType.get)
castBuilder.setEvalMode(evalMode)

val timeZone = timeZoneId.getOrElse("UTC")
castBuilder.setTimezone(timeZone)
@@ -446,9 +448,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
val value = cast.eval()
exprToProtoInternal(Literal(value, dataType), inputs)

- case Cast(child, dt, timeZoneId, _) =>
+ case Cast(child, dt, timeZoneId, evalMode) =>
val childExpr = exprToProtoInternal(child, inputs)
- castToProto(timeZoneId, dt, childExpr)
+ val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
+ // Spark 3.2 & 3.3 have an ansiEnabled boolean
+ if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
+ } else {
+ // Spark 3.4+ has an EvalMode enum with values LEGACY, ANSI, and TRY
+ evalMode.toString
+ }
+ castToProto(timeZoneId, dt, childExpr, evalModeStr)

case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
val leftExpr = exprToProtoInternal(left, inputs)
@@ -991,6 +1000,7 @@
.newBuilder()
.setChild(e)
.setDatatype(serializeDataType(IntegerType).get)
.setEvalMode("LEGACY") // year is not affected by ANSI mode
.build())
.build()
})
@@ -1565,7 +1575,7 @@
val childExpr = scalarExprToProto("coalesce", exprChildren: _*)
// TODO: Remove this once we have new DataFusion release which includes
// the fix: https://github.com/apache/arrow-datafusion/pull/9459
- castToProto(None, a.dataType, childExpr)
+ castToProto(None, a.dataType, childExpr, "LEGACY")

// With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
// char types. Use rpad to achieve the behavior.
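The boolean-versus-enum handling above exists because Spark 3.2 and 3.3 model ANSI as an ansiEnabled: Boolean on the Cast node, while Spark 3.4+ passes an EvalMode enum; both collapse to one serialized string. Distilled into a standalone helper, as a sketch mirroring the match above:

// Hedged sketch: normalize either representation to the protobuf string.
def evalModeString(evalMode: Any): String = evalMode match {
  case b: Boolean => if (b) "ANSI" else "LEGACY" // Spark 3.2 / 3.3
  case m          => m.toString                  // Spark 3.4+: LEGACY, ANSI, or TRY
}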