diff --git a/fuzz-testing/README.md b/fuzz-testing/README.md index 56af359f2..076ff6aea 100644 --- a/fuzz-testing/README.md +++ b/fuzz-testing/README.md @@ -30,8 +30,8 @@ Comet Fuzz is inspired by the [SparkFuzz](https://ir.cwi.nl/pub/30222) paper fro Planned areas of improvement: +- ANSI mode - Support for all data types, expressions, and operators supported by Comet -- Explicit casts - Unary and binary arithmetic expressions - IF and CASE WHEN expressions - Complex (nested) expressions @@ -91,7 +91,8 @@ $SPARK_HOME/bin/spark-submit \ --conf spark.comet.exec.shuffle.enabled=true \ --conf spark.comet.exec.shuffle.mode=auto \ --jars $COMET_JAR \ - --driver-class-path $COMET_JAR \ + --conf spark.driver.extraClassPath=$COMET_JAR \ + --conf spark.executor.extraClassPath=$COMET_JAR \ --class org.apache.comet.fuzz.Main \ target/comet-fuzz-spark3.4_2.12-0.1.0-SNAPSHOT-jar-with-dependencies.jar \ run --num-files=2 --filename=queries.sql diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala index 47a6bd879..9f9f772b7 100644 --- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/DataGen.scala @@ -50,7 +50,8 @@ object DataGen { // generate schema using random data types val fields = Range(0, numColumns) - .map(i => StructField(s"c$i", Utils.randomWeightedChoice(Meta.dataTypes), nullable = true)) + .map(i => + StructField(s"c$i", Utils.randomWeightedChoice(Meta.dataTypes, r), nullable = true)) val schema = StructType(fields) // generate columnar data diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala index 1daa26200..7584e76ce 100644 --- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala @@ -42,10 +42,11 @@ object QueryGen { val uniqueQueries = mutable.HashSet[String]() for (_ <- 0 until numQueries) { - val sql = r.nextInt().abs % 3 match { + val sql = r.nextInt().abs % 4 match { case 0 => generateJoin(r, spark, numFiles) case 1 => generateAggregate(r, spark, numFiles) case 2 => generateScalar(r, spark, numFiles) + case 3 => generateCast(r, spark, numFiles) } if (!uniqueQueries.contains(sql)) { uniqueQueries += sql @@ -91,6 +92,21 @@ object QueryGen { s"ORDER BY ${args.mkString(", ")};" } + private def generateCast(r: Random, spark: SparkSession, numFiles: Int): String = { + val tableName = s"test${r.nextInt(numFiles)}" + val table = spark.table(tableName) + + val toType = Utils.randomWeightedChoice(Meta.dataTypes, r).sql + val arg = Utils.randomChoice(table.columns, r) + + // We test both `cast` and `try_cast` to cover LEGACY and TRY eval modes. It is not + // recommended to run Comet Fuzz with ANSI enabled currently. + // Example SELECT c0, cast(c0 as float), try_cast(c0 as float) FROM test0 + s"SELECT $arg, cast($arg as $toType), try_cast($arg as $toType) " + + s"FROM $tableName " + + s"ORDER BY $arg;" + } + private def generateJoin(r: Random, spark: SparkSession, numFiles: Int): String = { val leftTableName = s"test${r.nextInt(numFiles)}" val rightTableName = s"test${r.nextInt(numFiles)}" @@ -101,7 +117,7 @@ object QueryGen { val rightCol = Utils.randomChoice(rightTable.columns, r) val joinTypes = Seq(("INNER", 0.4), ("LEFT", 0.3), ("RIGHT", 0.3)) - val joinType = Utils.randomWeightedChoice(joinTypes) + val joinType = Utils.randomWeightedChoice(joinTypes, r) val leftColProjection = leftTable.columns.map(c => s"l.$c").mkString(", ") val rightColProjection = rightTable.columns.map(c => s"r.$c").mkString(", ") diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala index 49f9fc3bd..b2ceae9d0 100644 --- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala @@ -19,7 +19,7 @@ package org.apache.comet.fuzz -import java.io.{BufferedWriter, FileWriter} +import java.io.{BufferedWriter, FileWriter, PrintWriter} import scala.io.Source @@ -109,7 +109,12 @@ object QueryRunner { case e: Exception => // the query worked in Spark but failed in Comet, so this is likely a bug in Comet showSQL(w, sql) - w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}\n") + w.write(s"[ERROR] Query failed in Comet: ${e.getMessage}:\n") + w.write("```\n") + val p = new PrintWriter(w) + e.printStackTrace(p) + p.close() + w.write("```\n") } // flush after every query so that results are saved in the event of the driver crashing @@ -134,6 +139,7 @@ object QueryRunner { private def formatRow(row: Row): String = { row.toSeq .map { + case null => "NULL" case v: Array[Byte] => v.mkString case other => other.toString } diff --git a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Utils.scala b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Utils.scala index 19f9695a9..4d51c60e5 100644 --- a/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Utils.scala +++ b/fuzz-testing/src/main/scala/org/apache/comet/fuzz/Utils.scala @@ -27,9 +27,9 @@ object Utils { list(r.nextInt(list.length)) } - def randomWeightedChoice[T](valuesWithWeights: Seq[(T, Double)]): T = { + def randomWeightedChoice[T](valuesWithWeights: Seq[(T, Double)], r: Random): T = { val totalWeight = valuesWithWeights.map(_._2).sum - val randomValue = Random.nextDouble() * totalWeight + val randomValue = r.nextDouble() * totalWeight var cumulativeWeight = 0.0 for ((value, weight) <- valuesWithWeights) { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 439ec4ebb..8d81b57c4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2169,10 +2169,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim val trimCast = Cast(trimStr.get, StringType) val trimExpr = exprToProtoInternal(trimCast, inputs) val optExpr = scalarExprToProto(trimType, srcExpr, trimExpr) - optExprWithInfo(optExpr, expr, null, srcCast, trimCast) + optExprWithInfo(optExpr, expr, srcCast, trimCast) } else { val optExpr = scalarExprToProto(trimType, srcExpr) - optExprWithInfo(optExpr, expr, null, srcCast) + optExprWithInfo(optExpr, expr, srcCast) } }