feat: Add CometBroadcastExchangeExec to support broadcasting the result of Comet native operator
viirya committed Feb 21, 2024
1 parent 937b42a commit 767fc6f
Showing 4 changed files with 49 additions and 20 deletions.
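For context, a minimal sketch (not part of this commit) of how the new operator is expected to surface: with Comet's session extensions loaded and native broadcast enabled, the broadcast side of a join should be planned as CometBroadcastExchangeExec instead of Spark's BroadcastExchangeExec. The CometConf import path, the toy DataFrames, and the helper object name are assumptions; other Comet configs are elided.

import org.apache.comet.CometConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.comet.CometBroadcastExchangeExec

object CometBroadcastSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("comet-broadcast-sketch").getOrCreate()
    spark.conf.set(CometConf.COMET_EXEC_BROADCAST_ENABLED.key, "true")

    val small = spark.range(0, 5).toDF("id")
    val large = spark.range(0, 1000).toDF("id")
    val joined = large.join(small, "id")
    joined.collect() // materialize so the final plan is available

    // Under adaptive query execution the operator may only appear in the final adaptive
    // plan; the test added below uses AdaptiveSparkPlanHelper's find() to locate it reliably.
    val cometBroadcast = joined.queryExecution.executedPlan.collectFirst {
      case b: CometBroadcastExchangeExec => b
    }
    println(s"Comet broadcast planned: ${cometBroadcast.isDefined}")
    spark.stop()
  }
}
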
@@ -32,7 +32,7 @@ import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkP
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
35 changes: 18 additions & 17 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -71,7 +71,7 @@ abstract class CometExec extends CometPlan {
/**
* Executes this Comet operator and serializes the output ColumnarBatches into bytes.
*/
private def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
executeColumnar().mapPartitionsInternal { iter =>
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
@@ -85,28 +85,14 @@ abstract class CometExec extends CometPlan {
}
}

/**
* Decodes the byte arrays back to ColumnarBatches and puts them into a buffer.
*/
private def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
if (bytes.size == 0) {
return Iterator.empty
}

val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbis = bytes.toInputStream()
val ins = new DataInputStream(codec.compressedInputStream(cbbis))

new ArrowReaderIterator(Channels.newChannel(ins))
}

/**
* Executes the Comet operator and returns the result as an iterator of ColumnarBatch.
*/
def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
val countsAndBytes = getByteArrayRdd().collect()
val total = countsAndBytes.map(_._1).sum
val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeBatches(countAndBytes._2))
val rows = countsAndBytes.iterator
.flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2))
(total, rows)
}
}
@@ -133,6 +119,21 @@ object CometExec {
val bytes = outputStream.toByteArray
new CometExecIterator(newIterId, inputs, bytes, nativeMetrics)
}

/**
* Decodes the byte arrays back to ColumnarBatches and puts them into a buffer.
*/
def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
if (bytes.size == 0) {
return Iterator.empty
}

val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbis = bytes.toInputStream()
val ins = new DataInputStream(codec.compressedInputStream(cbbis))

new ArrowReaderIterator(Channels.newChannel(ins))
}
}

/**
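The changes above drop the private modifier from getByteArrayRdd() and move decodeBatches into the CometExec companion object, so the serialize/collect/decode round trip is reusable by other operators (such as the new broadcast exchange) rather than only inside executeColumnarCollectIterator. A minimal sketch of that reuse, not taken from the commit; the helper object and method names are hypothetical:

import org.apache.spark.sql.comet.CometExec
import org.apache.spark.sql.vectorized.ColumnarBatch

object CollectAndDecodeSketch {
  // Hypothetical helper mirroring executeColumnarCollectIterator: collect each partition's
  // (row count, compressed Arrow bytes) pair on the driver, then decode back to batches.
  def collectAndDecode(plan: CometExec): (Long, Iterator[ColumnarBatch]) = {
    val countsAndBytes = plan.getByteArrayRdd().collect()
    val totalRows = countsAndBytes.map(_._1).sum
    val batches = countsAndBytes.iterator.flatMap { case (_, bytes) =>
      CometExec.decodeBatches(bytes)
    }
    (totalRows, batches)
  }
}
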
27 changes: 26 additions & 1 deletion spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.comet.{CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -55,6 +55,31 @@ class CometExecSuite extends CometTestBase {
}
}

test("CometBroadcastExchangeExec") {
withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_b") {
val df = sql(
"SELECT tbl_a._1, tbl_b._2 FROM tbl_a JOIN tbl_b " +
"WHERE tbl_a._1 > tbl_a._2 LIMIT 2")

val nativeBroadcast = find(df.queryExecution.executedPlan) {
case _: CometBroadcastExchangeExec => true
case _ => false
}.get.asInstanceOf[CometBroadcastExchangeExec]

val numParts = nativeBroadcast.executeColumnar().getNumPartitions

val rows = nativeBroadcast.executeCollect().toSeq.sortBy(row => row.getInt(0))
val rowContents = rows.map(row => row.getInt(0))
val expected = (0 until numParts).flatMap(_ => (0 until 5).map(i => i + 1)).sorted

assert(rowContents === expected)
}
}
}
}

test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch results") {
withSQLConf(
CometConf.COMET_EXEC_ENABLED.key -> "true",
@@ -35,7 +35,7 @@ import org.apache.parquet.hadoop.ParquetWriter
import org.apache.parquet.hadoop.example.ExampleParquetWriter
import org.apache.parquet.schema.{MessageType, MessageTypeParser}
import org.apache.spark._
import org.apache.spark.sql.comet.{CometBatchScanExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder}
import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -79,6 +79,8 @@ abstract class CometTestBase
CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key -> "2g",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1g",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "1g",
SQLConf.ANSI_ENABLED.key -> "false") {
testFun
}
@@ -157,6 +159,7 @@ abstract class CometTestBase
case _: CometScanExec | _: CometBatchScanExec => true
case _: CometSinkPlaceHolder | _: CometScanWrapper => false
case _: CometExec | _: CometShuffleExchangeExec => true
case _: CometBroadcastExchangeExec => true
case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter => true
case op =>
if (excludedClasses.exists(c => c.isAssignableFrom(op.getClass))) {
