From 767fc6fcc694fb1df2537dda6847d27d742c21f6 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 21 Feb 2024 13:40:38 -0800
Subject: [PATCH] feat: Add CometBroadcastExchangeExec to support broadcasting
 the result of Comet native operator

---
 .../apache/comet/serde/QueryPlanSerde.scala   |  2 +-
 .../apache/spark/sql/comet/operators.scala    | 35 ++++++++++---------
 .../apache/comet/exec/CometExecSuite.scala    | 27 +++++++++++++-
 .../org/apache/spark/sql/CometTestBase.scala  |  5 ++-
 4 files changed, 49 insertions(+), 20 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index a60b7e5b78..0a251d4484 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkP
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 12f086c765..d1d6f8f205 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -71,7 +71,7 @@ abstract class CometExec extends CometPlan {
   /**
    * Executes this Comet operator and serialized output ColumnarBatch into bytes.
    */
-  private def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
+  def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
     executeColumnar().mapPartitionsInternal { iter =>
       val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
       val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
@@ -85,28 +85,14 @@ abstract class CometExec extends CometPlan {
     }
   }
 
-  /**
-   * Decodes the byte arrays back to ColumnarBatches and put them into buffer.
-   */
-  private def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
-    if (bytes.size == 0) {
-      return Iterator.empty
-    }
-
-    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-    val cbbis = bytes.toInputStream()
-    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
-
-    new ArrowReaderIterator(Channels.newChannel(ins))
-  }
-
   /**
    * Executes the Comet operator and returns the result as an iterator of ColumnarBatch.
    */
   def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
     val countsAndBytes = getByteArrayRdd().collect()
     val total = countsAndBytes.map(_._1).sum
-    val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeBatches(countAndBytes._2))
+    val rows = countsAndBytes.iterator
+      .flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2))
     (total, rows)
   }
 }
@@ -133,6 +119,21 @@ object CometExec {
     val bytes = outputStream.toByteArray
     new CometExecIterator(newIterId, inputs, bytes, nativeMetrics)
   }
+
+  /**
+   * Decodes the byte arrays back to ColumnarBatches and put them into buffer.
+   */
+  def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
+    if (bytes.size == 0) {
+      return Iterator.empty
+    }
+
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val cbbis = bytes.toInputStream()
+    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+
+    new ArrowReaderIterator(Channels.newChannel(ins))
+  }
 }
 
 /**
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index e1f864249e..d3a1bd2c95 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
-import org.apache.spark.sql.comet.{CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -55,6 +55,31 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("CometBroadcastExchangeExec") {
+    withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
+      withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {
+        withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_b") {
+          val df = sql(
+            "SELECT tbl_a._1, tbl_b._2 FROM tbl_a JOIN tbl_b " +
+              "WHERE tbl_a._1 > tbl_a._2 LIMIT 2")
+
+          val nativeBroadcast = find(df.queryExecution.executedPlan) {
+            case _: CometBroadcastExchangeExec => true
+            case _ => false
+          }.get.asInstanceOf[CometBroadcastExchangeExec]
+
+          val numParts = nativeBroadcast.executeColumnar().getNumPartitions
+
+          val rows = nativeBroadcast.executeCollect().toSeq.sortBy(row => row.getInt(0))
+          val rowContents = rows.map(row => row.getInt(0))
+          val expected = (0 until numParts).flatMap(_ => (0 until 5).map(i => i + 1)).sorted
+
+          assert(rowContents === expected)
+        }
+      }
+    }
+  }
+
   test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch results") {
     withSQLConf(
       CometConf.COMET_EXEC_ENABLED.key -> "true",
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 4f2838cfba..2b37ce035f 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -35,7 +35,7 @@ import org.apache.parquet.hadoop.ParquetWriter
 import org.apache.parquet.hadoop.example.ExampleParquetWriter
 import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 import org.apache.spark._
-import org.apache.spark.sql.comet.{CometBatchScanExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder}
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -79,6 +79,8 @@ abstract class CometTestBase
       CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
       CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key -> "true",
       CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key -> "2g",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1g",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "1g",
       SQLConf.ANSI_ENABLED.key -> "false") {
       testFun
     }
@@ -157,6 +159,7 @@ abstract class CometTestBase
         case _: CometScanExec | _: CometBatchScanExec => true
         case _: CometSinkPlaceHolder | _: CometScanWrapper => false
         case _: CometExec | _: CometShuffleExchangeExec => true
+        case _: CometBroadcastExchangeExec => true
        case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter => true
        case op =>
          if (excludedClasses.exists(c => c.isAssignableFrom(op.getClass))) {
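
Illustrative sketch (not part of the patch): with getByteArrayRdd() made public and
decodeBatches moved to the CometExec companion object, a consumer such as a broadcast
exchange can round-trip serialized batches the same way executeColumnarCollectIterator
does. The helper name collectBatches and the `plan` parameter below are assumptions for
illustration only.

  import org.apache.spark.sql.comet.CometExec
  import org.apache.spark.sql.vectorized.ColumnarBatch

  // Assumed helper mirroring executeColumnarCollectIterator from this patch:
  // serialize each partition's output, then decode it back into ColumnarBatches.
  def collectBatches(plan: CometExec): Iterator[ColumnarBatch] = {
    // getByteArrayRdd() yields one (rowCount, ChunkedByteBuffer) pair per partition.
    val countsAndBytes = plan.getByteArrayRdd().collect()
    // decodeBatches reverses the compression and Arrow IPC encoding applied above.
    countsAndBytes.iterator.flatMap { case (_, bytes) => CometExec.decodeBatches(bytes) }
  }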