feat: Only allow incompatible cast expressions to run in comet if a config is enabled #362

Merged May 3, 2024 (39 commits)
Commits (39):
d592417  Only allow supported cast operations to be converted to native (andygrove, May 1, 2024)
4641540  refactor for readability (andygrove, May 1, 2024)
39ac177  Remove unused import (andygrove, May 1, 2024)
c7775a6  formatting (andygrove, May 1, 2024)
4e803cc  avoid duplicate strings in withInfo (andygrove, May 1, 2024)
d255879  always cast between same type (andygrove, May 1, 2024)
b2d3d2d  save progress (andygrove, May 2, 2024)
889c754  skip try_cast testing prior to Spark 3.4 (andygrove, May 2, 2024)
35bcbf9  add config to allow incompatible casts (andygrove, May 2, 2024)
5be20c6  remove unused imports (andygrove, May 2, 2024)
a27f846  improve fallback reason reporting (andygrove, May 2, 2024)
23847cd  revert some changes that are no longer needed (andygrove, May 2, 2024)
4bf6c16  save progress (andygrove, May 2, 2024)
ed07913  checkpoint (andygrove, May 2, 2024)
5255d6c  save (andygrove, May 2, 2024)
0c8da56  code cleanup (andygrove, May 2, 2024)
22c6394  fix bug (andygrove, May 2, 2024)
7c39b05  fix bug (andygrove, May 2, 2024)
787d17b  improve reporting (andygrove, May 2, 2024)
1b833e2  add documentation (andygrove, May 2, 2024)
7ad2a43  improve docs, fix code format (andygrove, May 2, 2024)
ee77b14  fix lint error (andygrove, May 2, 2024)
6d903c3  automate docs (andygrove, May 2, 2024)
4c24e41  automate docs (andygrove, May 2, 2024)
b8df40f  remove unused imports (andygrove, May 2, 2024)
4b98bc2  make format (andygrove, May 2, 2024)
6675c72  prettier (andygrove, May 2, 2024)
b67d571  change default to false (andygrove, May 2, 2024)
83a70b7  formatting (andygrove, May 2, 2024)
e5226f0  set COMET_CAST_ALLOW_INCOMPATIBLE=true in some tests (andygrove, May 2, 2024)
6738544  revert a change (andygrove, May 2, 2024)
a1bfdee  spark 3.2 does not support cast timestamp to string (andygrove, May 3, 2024)
15c4989  spotless (andygrove, May 3, 2024)
8c40e17  add links to issues (andygrove, May 3, 2024)
be94495  regenerate config docs (andygrove, May 3, 2024)
581d92e  prettier (andygrove, May 3, 2024)
4374fc0  revert prettier (andygrove, May 3, 2024)
2e5e33b  revert a change (andygrove, May 3, 2024)
128e763  revert some redundant test changes (andygrove, May 3, 2024)
spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

@@ -1042,10 +1042,10 @@ object CometSparkSessionExtensions extends Logging {
    * The node with information (if any) attached
    */
   def withInfo[T <: TreeNode[_]](node: T, info: String, exprs: T*): T = {
-    val exprInfo = exprs
-      .flatMap { e => Seq(e.getTagValue(CometExplainInfo.EXTENSION_INFO)) }
-      .flatten
-      .mkString("\n")
+    val exprInfo =
+      (Seq(node.getTagValue(CometExplainInfo.EXTENSION_INFO)) ++ exprs
+        .flatMap(e => Seq(e.getTagValue(CometExplainInfo.EXTENSION_INFO)))).flatten
+        .mkString("\n")
     if (info != null && info.nonEmpty && exprInfo.nonEmpty) {
       node.setTagValue(CometExplainInfo.EXTENSION_INFO, Seq(exprInfo, info).mkString("\n"))
     } else if (exprInfo.nonEmpty) {
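
The change above makes withInfo also read any explanation already attached to `node` itself, so successive calls accumulate fallback reasons instead of silently overwriting earlier ones. A minimal sketch of the intended behavior (not a test from this PR; it assumes the truncated branches of the method set the tag from `info` or `exprInfo` alone when only one of them is non-empty):

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, Literal}

// Attach two fallback reasons to the same expression node.
val expr = Add(Literal(1), Literal(2))
withInfo(expr, "first reason")
withInfo(expr, "second reason")

// With the new implementation, the second call sees the first reason via
// node.getTagValue and appends to it rather than replacing it:
//   expr.getTagValue(CometExplainInfo.EXTENSION_INFO)
//     == Some("first reason\nsecond reason")
```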
136 changes: 136 additions & 0 deletions spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.expressions

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType}

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo

object CometCast {

  def isSupported(
      cast: Cast,
      fromType: DataType,
      toType: DataType,
      timeZoneId: Option[String],
      evalMode: String): Boolean = {
    (fromType, toType) match {
      case (DataTypes.StringType, _) =>
        castFromStringSupported(cast, toType, timeZoneId, evalMode)
      case (_, DataTypes.StringType) =>
        castToStringSupported(cast, fromType, timeZoneId, evalMode)
      case (DataTypes.TimestampType, _) =>
        castFromTimestampSupported(cast, toType, timeZoneId, evalMode)
      case (_: DecimalType, DataTypes.FloatType | DataTypes.DoubleType) => true
      case (
            DataTypes.BooleanType,
            DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
            DataTypes.LongType | DataTypes.FloatType | DataTypes.DoubleType) =>
        true
      case (
            DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType,
            DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
            DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
            DataTypes.DoubleType) =>
        true
      case (
            DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType,
            _: DecimalType) =>
        true
      case (DataTypes.FloatType, DataTypes.BooleanType | DataTypes.DoubleType) => true
      case (DataTypes.DoubleType, DataTypes.BooleanType | DataTypes.FloatType) => true
      case _ => false
    }
  }

  private def castFromStringSupported(
      cast: Cast,
      toType: DataType,
      timeZoneId: Option[String],
      evalMode: String): Boolean = {
    toType match {
      case DataTypes.BooleanType =>
        true
      case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
          DataTypes.LongType =>
        true
      case DataTypes.FloatType | DataTypes.DoubleType =>
        // https://github.com/apache/datafusion-comet/issues/326
        false
      case _: DecimalType =>
        // https://github.com/apache/datafusion-comet/issues/325
        false
      case DataTypes.DateType =>
        // https://github.com/apache/datafusion-comet/issues/327
        false
      case DataTypes.TimestampType =>
        val enabled = CometConf.COMET_CAST_STRING_TO_TIMESTAMP.get()
        if (!enabled) {
          // https://github.com/apache/datafusion-comet/issues/328
          withInfo(cast, s"${CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key} is disabled")
        }
        enabled
      case _ =>
        false
    }
  }

  private def castToStringSupported(
      cast: Cast,
      fromType: DataType,
      timeZoneId: Option[String],
      evalMode: String): Boolean = {
    fromType match {
      case DataTypes.BooleanType => true
      case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
          DataTypes.LongType =>
        true
      case DataTypes.DateType => true
      case DataTypes.TimestampType => true
      case DataTypes.FloatType | DataTypes.DoubleType =>
        // https://github.com/apache/datafusion-comet/issues/326
        false
      case _ => false
    }
  }

  private def castFromTimestampSupported(
      cast: Cast,
      toType: DataType,
      timeZoneId: Option[String],
      evalMode: String): Boolean = {
    toType match {
      case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
          DataTypes.IntegerType =>
        // https://github.com/apache/datafusion-comet/issues/352
        // this seems like an edge case that isn't important for us to support
        false
      case DataTypes.LongType =>
        // https://github.com/apache/datafusion-comet/issues/352
        false
      case DataTypes.StringType => true
      case DataTypes.DateType => true
      case _ => false
    }
  }

}
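
For orientation, here is a sketch of how a caller can consult the new API before handing a cast to the native engine. The Cast construction is simplified (Spark supplies extra arguments in practice), and in this PR the real call site is in QueryPlanSerde, shown next:

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.DataTypes

import org.apache.comet.expressions.CometCast

// Ask whether a string-to-integer cast may run natively in LEGACY eval mode.
val cast = Cast(Literal("123"), DataTypes.IntegerType)
val supported = CometCast.isSupported(
  cast,
  DataTypes.StringType,
  DataTypes.IntegerType,
  timeZoneId = None,
  evalMode = "LEGACY")
// supported == true here, since string-to-integral casts are on the
// compatible list in castFromStringSupported.
```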
18 changes: 6 additions & 12 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -43,6 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus, withInfo}
+import org.apache.comet.expressions.CometCast
 import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
 import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
 import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
@@ -575,7 +576,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         val value = cast.eval()
         exprToProtoInternal(Literal(value, dataType), inputs)
 
-      case Cast(child, dt, timeZoneId, evalMode) =>
+      case cast @ Cast(child, dt, timeZoneId, evalMode) =>
         val childExpr = exprToProtoInternal(child, inputs)
         if (childExpr.isDefined) {
           val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
@@ -585,19 +586,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
             // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
             evalMode.toString
           }
-          val supportedCast = (child.dataType, dt) match {
-            case (DataTypes.StringType, DataTypes.TimestampType)
-                if !CometConf.COMET_CAST_STRING_TO_TIMESTAMP.get() =>
-              // https://github.com/apache/datafusion-comet/issues/328
-              withInfo(expr, s"${CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key} is disabled")
-              false
-            case _ => true
-          }
-          if (supportedCast) {
+          if (CometCast.isSupported(cast, child.dataType, dt, timeZoneId, evalModeStr)) {
             castToProto(timeZoneId, dt, childExpr, evalModeStr)
           } else {
-            // no need to call withInfo here since it was called when determining
-            // the value for `supportedCast`
+            withInfo(
+              expr,
+              s"Unsupported cast from ${child.dataType} to $dt with timezone $timeZoneId and evalMode $evalModeStr")
             None
           }
         } else {
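
The new else branch reports exactly why a cast stayed on the Spark side, instead of relying on a reason recorded elsewhere. A hedged illustration of triggering it (table and column names are made up; the message wording comes from the withInfo call above):

```scala
// String-to-double casts are currently unsupported (issue 326), so with the
// default configuration this query falls back to Spark for the cast, and the
// expression is tagged with a reason along the lines of:
//   "Unsupported cast from StringType to DoubleType with timezone None
//    and evalMode LEGACY"
val df = spark.sql("SELECT CAST(s AS DOUBLE) FROM tbl")
```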
6 changes: 3 additions & 3 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -771,12 +771,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
       // cast() should return null for invalid inputs when ansi mode is disabled
       val df = spark.sql(s"select a, cast(a as ${toType.sql}) from t order by a")
-      checkSparkAnswer(df)
+      checkSparkAnswerAndOperator(df)

> andygrove (author), on this line: We need this so that we can be sure the plan really ran with Comet.

 
       // try_cast() should always return null for invalid inputs
       val df2 =
         spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
-      checkSparkAnswer(df2)
+      checkSparkAnswerAndOperator(df2)
     }
 
     // with ANSI enabled, we should produce the same exception as Spark
@@ -818,7 +818,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       // try_cast() should always return null for invalid inputs
       val df2 =
         spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
-      checkSparkAnswer(df2)
+      checkSparkAnswerAndOperator(df2)
     }
   }
 }
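
The switch from checkSparkAnswer to checkSparkAnswerAndOperator tightens the assertion: the former only checks that results match Spark, while the latter also verifies that Comet operators actually appear in the executed plan, which is what the author's comment above is getting at. A sketch of the same pattern combined with the new opt-in config (the helper names come from the surrounding test code; the query is illustrative):

```scala
withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
  // Results must match Spark, and the executed plan must contain Comet
  // operators; a plain checkSparkAnswer would pass even on full fallback.
  val df = spark.sql("select a, try_cast(a as int) from t order by a")
  checkSparkAnswerAndOperator(df)
}
```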