
Commit

Fix
viirya committed Apr 3, 2024
1 parent fd8e343 commit 6cb63ed
Showing 6 changed files with 42 additions and 48 deletions.
40 changes: 1 addition & 39 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -19,10 +19,6 @@

package org.apache.comet.vector

import java.io.OutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data}
@@ -32,46 +28,12 @@ import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.spark.SparkException
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.CometArrowStreamWriter

class NativeUtil {
private val allocator = new RootAllocator(Long.MaxValue)
private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
private val importer = new ArrowImporter(allocator)

/**
* Serializes a list of `ColumnarBatch` into an output stream.
*
* @param batches
* the output batches, each batch is a list of Arrow vectors wrapped in `CometVector`
* @param out
* the output stream
*/
def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = {
var writer: Option[CometArrowStreamWriter] = None
var rowCount = 0

batches.foreach { batch =>
val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
val root = new VectorSchemaRoot(fieldVectors.asJava)
val provider = batchProviderOpt.getOrElse(dictionaryProvider)

if (writer.isEmpty) {
writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out)))
writer.get.start()
writer.get.writeBatch()
} else {
writer.get.writeMoreBatch(root)
}

root.clear()
rowCount += batch.numRows()
}

writer.map(_.end())

rowCount
}
def getDictionaryProvider: DictionaryProvider = dictionaryProvider

def getBatchFieldVectors(
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
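The serialization logic removed here is moved into CometExec.serializeBatches in operators.scala below; NativeUtil now only exposes its dictionary provider alongside getBatchFieldVectors. A minimal sketch of how a caller combines the two, assuming an Arrow-backed ColumnarBatch (the method and value names are illustrative, not from the commit):

import org.apache.arrow.vector.FieldVector
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.comet.vector.NativeUtil
import org.apache.spark.sql.vectorized.ColumnarBatch

// Sketch: pull the Arrow field vectors out of a batch and choose a dictionary
// provider, falling back to NativeUtil's own CDataDictionaryProvider when the
// batch does not carry one (the same fallback the new operators.scala code uses).
def vectorsAndProvider(
    nativeUtil: NativeUtil,
    batch: ColumnarBatch): (Seq[FieldVector], DictionaryProvider) = {
  val (fieldVectors, providerOpt) = nativeUtil.getBatchFieldVectors(batch)
  (fieldVectors, providerOpt.getOrElse(nativeUtil.getDictionaryProvider))
}
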
@@ -51,6 +51,7 @@ case class StreamReader(channel: ReadableByteChannel) extends AutoCloseable {
// Native shuffle always uses decimal128.
CometVector.getVector(vector, true, arrowReader).asInstanceOf[ColumnVector]
}.toArray

val batch = new ColumnarBatch(columns)
batch.setNumRows(root.getRowCount)
batch
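
The read side is unchanged apart from spacing: StreamReader wraps the vectors of each decoded VectorSchemaRoot in Comet vectors and rebuilds a ColumnarBatch. A hedged sketch of that wrapping step in isolation, with `wrap` standing in for the CometVector.getVector(vector, true, arrowReader) call shown above:

import scala.collection.JavaConverters._
import org.apache.arrow.vector.{ValueVector, VectorSchemaRoot}
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

// Sketch: turn a loaded VectorSchemaRoot into a Spark ColumnarBatch; the real
// code wraps each vector with CometVector.getVector and forces decimal128.
def toBatch(root: VectorSchemaRoot, wrap: ValueVector => ColumnVector): ColumnarBatch = {
  val columns = root.getFieldVectors.asScala.map(wrap).toArray
  val batch = new ColumnarBatch(columns)
  batch.setNumRows(root.getRowCount)
  batch
}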
@@ -460,7 +460,7 @@ class CometSparkSessionExtensions
case other => other
}
if (!newChildren.exists(_.isInstanceOf[BroadcastExchangeExec])) {
val newPlan = transform(plan.withNewChildren(newChildren))
val newPlan = apply(plan.withNewChildren(newChildren))
if (isCometNative(newPlan) || isCometBroadCastForceEnabled(conf)) {
newPlan
} else {
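The broadcast handling above now calls apply, the rule's full entry point, on the rebuilt plan rather than the narrower transform step, so the plan with its new children goes through the complete conversion again. A simplified, hedged sketch of that relationship (not Comet's actual rule; the names are illustrative):

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Sketch: `transformOne` converts a single operator, while `apply` walks the
// whole tree; calling apply on a rebuilt plan re-applies every conversion.
case class SketchConversionRule() extends Rule[SparkPlan] {
  private def transformOne(op: SparkPlan): SparkPlan = op // placeholder for a per-operator conversion
  override def apply(plan: SparkPlan): SparkPlan =
    plan.transformUp { case op => transformOne(op) }
}
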
@@ -95,6 +95,8 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
@transient
private lazy val maxBroadcastRows = 512000000

private lazy val childRDD = child.asInstanceOf[CometExec].executeColumnar()

@transient
override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
@@ -191,7 +193,7 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()

new CometBatchRDD(sparkContext, broadcasted.value.length, broadcasted)
new CometBatchRDD(sparkContext, childRDD.getNumPartitions, broadcasted)
}

override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
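Consolidating the two hunks above: the broadcast exec now caches the child's columnar RDD and sizes the rebuilt CometBatchRDD by the child's partition count rather than by broadcasted.value.length, a count that can diverge from the child's partitioning now that serializeBatches emits one buffer per batch (see operators.scala below). A sketch of the fixed path assembled as one piece, as it sits inside CometBroadcastExchangeExec:

// Sketch assembled from the hunks above (not standalone code).
private lazy val childRDD = child.asInstanceOf[CometExec].executeColumnar()

override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
  val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()
  // Partition count follows the child plan, not the number of serialized buffers.
  new CometBatchRDD(sparkContext, childRDD.getNumPartitions, broadcasted)
}
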
37 changes: 33 additions & 4 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -19,12 +19,15 @@

package org.apache.spark.sql.comet

import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream, OutputStream}
import java.nio.ByteBuffer
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd.RDD
@@ -77,18 +80,44 @@ abstract class CometExec extends CometPlan {
*/
def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
executeColumnar().mapPartitionsInternal { iter =>
serializeBatches(iter)
}
}

/**
* Serializes a list of `ColumnarBatch` into an output stream.
*
* @param batches
* the output batches, each batch is a list of Arrow vectors wrapped in `CometVector`
* @param out
* the output stream
*/
def serializeBatches(batches: Iterator[ColumnarBatch]): Iterator[(Long, ChunkedByteBuffer)] = {
val nativeUtil = new NativeUtil()

batches.map { batch =>
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
val out = new DataOutputStream(codec.compressedOutputStream(cbbos))

val count = new NativeUtil().serializeBatches(iter, out)
val (fieldVectors, batchProviderOpt) = nativeUtil.getBatchFieldVectors(batch)
val root = new VectorSchemaRoot(fieldVectors.asJava)
val provider = batchProviderOpt.getOrElse(nativeUtil.getDictionaryProvider)

val writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out))
writer.start()
writer.writeBatch()

root.clear()
writer.end()

out.flush()
out.close()

if (out.size() > 0) {
Iterator((count, cbbos.toChunkedByteBuffer))
(batch.numRows(), cbbos.toChunkedByteBuffer)
} else {
Iterator((count, new ChunkedByteBuffer(Array.empty[ByteBuffer])))
(batch.numRows(), new ChunkedByteBuffer(Array.empty[ByteBuffer]))
}
}
}
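
The rewritten serializeBatches produces one self-contained, compressed Arrow IPC stream per batch and pairs it with that batch's row count, instead of streaming every batch in a partition into a single writer as the removed NativeUtil method did. A standalone sketch of the per-batch write path under the same assumptions as the code above (ChunkedByteBuffer and ChunkedByteBufferOutputStream are Spark-internal, so this only compiles from an org.apache.spark package; the method name is illustrative):

import java.io.DataOutputStream
import java.nio.ByteBuffer
import java.nio.channels.Channels

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.SparkEnv
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

// Sketch: write one record batch as its own compressed Arrow IPC stream,
// buffered in 1 MiB chunks, mirroring the per-batch path added above.
def writeOneBatch(root: VectorSchemaRoot, provider: DictionaryProvider): ChunkedByteBuffer = {
  val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
  val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
  val out = new DataOutputStream(codec.compressedOutputStream(cbbos))
  val writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out))
  writer.start()
  writer.writeBatch() // a single batch per stream
  writer.end()
  out.flush()
  out.close()
  cbbos.toChunkedByteBuffer
}

Because every buffer is a complete Arrow stream, the read side can decode each batch independently, at the cost of repeating the schema message in every buffer.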
@@ -138,9 +138,9 @@ class CometTPCDSQuerySuite
"q99")

// TODO: enable the 3 queries after fixing the issues #1358.
override val tpcdsQueries: Seq[String] =
tpcdsAllQueries.filterNot(excludedTpcdsQueries.contains)

override val tpcdsQueries: Seq[String] = Seq("q4")
// tpcdsAllQueries.filterNot(excludedTpcdsQueries.contains)
// Seq("q1", "q2", "q3", "q4")
}
with TPCDSQueryTestSuite {
override def sparkConf: SparkConf = {
