Skip to content

Commit

Permalink
feat: Add specific fuzz tests for cast and try_cast and fix NPE found…
Browse files Browse the repository at this point in the history
… during fuzz testing (#514)

* Various improvements to fuzz testing tool

* Fix NPE in QueryPlanSerde handling of trim expression

* format
  • Loading branch information
andygrove authored Jun 4, 2024
1 parent 6565229 commit a668a86
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 11 deletions.
5 changes: 3 additions & 2 deletions fuzz-testing/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ Comet Fuzz is inspired by the [SparkFuzz](https://ir.cwi.nl/pub/30222) paper fro

Planned areas of improvement:

- ANSI mode
- Support for all data types, expressions, and operators supported by Comet
- Explicit casts
- Unary and binary arithmetic expressions
- IF and CASE WHEN expressions
- Complex (nested) expressions
Expand Down Expand Up @@ -91,7 +91,8 @@ $SPARK_HOME/bin/spark-submit \
--conf spark.comet.exec.shuffle.enabled=true \
--conf spark.comet.exec.shuffle.mode=auto \
--jars $COMET_JAR \
--driver-class-path $COMET_JAR \
--conf spark.driver.extraClassPath=$COMET_JAR \
--conf spark.executor.extraClassPath=$COMET_JAR \
--class org.apache.comet.fuzz.Main \
target/comet-fuzz-spark3.4_2.12-0.1.0-SNAPSHOT-jar-with-dependencies.jar \
run --num-files=2 --filename=queries.sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ object DataGen {

// generate schema using random data types
val fields = Range(0, numColumns)
.map(i => StructField(s"c$i", Utils.randomWeightedChoice(Meta.dataTypes), nullable = true))
.map(i =>
StructField(s"c$i", Utils.randomWeightedChoice(Meta.dataTypes, r), nullable = true))
val schema = StructType(fields)

// generate columnar data
Expand Down
20 changes: 18 additions & 2 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ object QueryGen {
val uniqueQueries = mutable.HashSet[String]()

for (_ <- 0 until numQueries) {
val sql = r.nextInt().abs % 3 match {
val sql = r.nextInt().abs % 4 match {
case 0 => generateJoin(r, spark, numFiles)
case 1 => generateAggregate(r, spark, numFiles)
case 2 => generateScalar(r, spark, numFiles)
case 3 => generateCast(r, spark, numFiles)
}
if (!uniqueQueries.contains(sql)) {
uniqueQueries += sql
Expand Down Expand Up @@ -91,6 +92,21 @@ object QueryGen {
s"ORDER BY ${args.mkString(", ")};"
}

private def generateCast(r: Random, spark: SparkSession, numFiles: Int): String = {
  // Pick one of the generated test tables at random.
  val sourceTable = s"test${r.nextInt(numFiles)}"
  val columns = spark.table(sourceTable).columns

  // Choose a random target type (weighted) and a random source column to cast.
  val targetType = Utils.randomWeightedChoice(Meta.dataTypes, r).sql
  val column = Utils.randomChoice(columns, r)

  // Exercise both `cast` and `try_cast` in the same query to cover the LEGACY and
  // TRY eval modes. Running Comet Fuzz with ANSI mode enabled is currently not
  // recommended.
  // Example: SELECT c0, cast(c0 as float), try_cast(c0 as float) FROM test0
  Seq(
    s"SELECT $column, cast($column as $targetType), try_cast($column as $targetType) ",
    s"FROM $sourceTable ",
    s"ORDER BY $column;").mkString
}

private def generateJoin(r: Random, spark: SparkSession, numFiles: Int): String = {
val leftTableName = s"test${r.nextInt(numFiles)}"
val rightTableName = s"test${r.nextInt(numFiles)}"
Expand All @@ -101,7 +117,7 @@ object QueryGen {
val rightCol = Utils.randomChoice(rightTable.columns, r)

val joinTypes = Seq(("INNER", 0.4), ("LEFT", 0.3), ("RIGHT", 0.3))
val joinType = Utils.randomWeightedChoice(joinTypes)
val joinType = Utils.randomWeightedChoice(joinTypes, r)

val leftColProjection = leftTable.columns.map(c => s"l.$c").mkString(", ")
val rightColProjection = rightTable.columns.map(c => s"r.$c").mkString(", ")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

package org.apache.comet.fuzz

import java.io.{BufferedWriter, FileWriter}
import java.io.{BufferedWriter, FileWriter, PrintWriter}

import scala.io.Source

Expand Down Expand Up @@ -109,7 +109,12 @@ object QueryRunner {
case e: Exception =>
// the query worked in Spark but failed in Comet, so this is likely a bug in Comet
showSQL(w, sql)
w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}\n")
w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n")
w.write("```\n")
val p = new PrintWriter(w)
e.printStackTrace(p)
p.close()
w.write("```\n")
}

// flush after every query so that results are saved in the event of the driver crashing
Expand All @@ -134,6 +139,7 @@ object QueryRunner {
private def formatRow(row: Row): String = {
row.toSeq
.map {
case null => "NULL"
case v: Array[Byte] => v.mkString
case other => other.toString
}
Expand Down
4 changes: 2 additions & 2 deletions fuzz-testing/src/main/scala/org/apache/comet/fuzz/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ object Utils {
list(r.nextInt(list.length))
}

def randomWeightedChoice[T](valuesWithWeights: Seq[(T, Double)]): T = {
def randomWeightedChoice[T](valuesWithWeights: Seq[(T, Double)], r: Random): T = {
val totalWeight = valuesWithWeights.map(_._2).sum
val randomValue = Random.nextDouble() * totalWeight
val randomValue = r.nextDouble() * totalWeight
var cumulativeWeight = 0.0

for ((value, weight) <- valuesWithWeights) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2169,10 +2169,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
val trimCast = Cast(trimStr.get, StringType)
val trimExpr = exprToProtoInternal(trimCast, inputs)
val optExpr = scalarExprToProto(trimType, srcExpr, trimExpr)
optExprWithInfo(optExpr, expr, null, srcCast, trimCast)
optExprWithInfo(optExpr, expr, srcCast, trimCast)
} else {
val optExpr = scalarExprToProto(trimType, srcExpr)
optExprWithInfo(optExpr, expr, null, srcCast)
optExprWithInfo(optExpr, expr, srcCast)
}
}

Expand Down

0 comments on commit a668a86

Please sign in to comment.