Skip to content

Commit

Permalink
fix: bitwise shift with different left/right types
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Feb 29, 2024
1 parent e2a6aca commit a8fa0d4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
16 changes: 14 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1376,7 +1376,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {

case ShiftRight(left, right) =>
val leftExpr = exprToProtoInternal(left, inputs)
val rightExpr = exprToProtoInternal(right, inputs)
val rightExpr = if (left.dataType == LongType) {
// DataFusion bitwise shift right expression requires
// same data type between left and right side
exprToProtoInternal(Cast(right, LongType), inputs)
} else {
exprToProtoInternal(right, inputs)
}

if (leftExpr.isDefined && rightExpr.isDefined) {
val builder = ExprOuterClass.BitwiseShiftRight.newBuilder()
Expand All @@ -1394,7 +1400,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {

case ShiftLeft(left, right) =>
val leftExpr = exprToProtoInternal(left, inputs)
val rightExpr = exprToProtoInternal(right, inputs)
val rightExpr = if (left.dataType == LongType) {
// DataFusion bitwise shift left expression requires
// same data type between left and right side
exprToProtoInternal(Cast(right, LongType), inputs)
} else {
exprToProtoInternal(right, inputs)
}

if (leftExpr.isDefined && rightExpr.isDefined) {
val builder = ExprOuterClass.BitwiseShiftLeft.newBuilder()
Expand Down
20 changes: 20 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@ import org.apache.comet.CometSparkSessionExtensions.{isSpark32, isSpark34Plus}
class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._

test("bitwise shift with different left/right types") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
val table = "test"
withTable(table) {
sql(s"create table $table(col1 long, col2 int) using parquet")
sql(s"insert into $table values(1111, 2)")
sql(s"insert into $table values(1111, 2)")
sql(s"insert into $table values(3333, 4)")
sql(s"insert into $table values(5555, 6)")

checkSparkAnswerAndOperator(
s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table")
checkSparkAnswerAndOperator(
s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table")
}
}
}
}

test("basic data type support") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
Expand Down

0 comments on commit a8fa0d4

Please sign in to comment.