Commit

Fix

viirya committed Apr 3, 2024
1 parent 1b6d9c5 commit 9e16b5a
Showing 3 changed files with 37 additions and 87 deletions.
81 changes: 1 addition & 80 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -19,10 +19,6 @@

package org.apache.comet.vector

import java.io.OutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data}
@@ -32,84 +28,12 @@ import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.spark.SparkException
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.CometArrowStreamWriter

class NativeUtil {
private val allocator = new RootAllocator(Long.MaxValue)
private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
private val importer = new ArrowImporter(allocator)

/**
* Serializes a list of `ColumnarBatch` into an output stream.
*
* @param batches
* the output batches, each batch is a list of Arrow vectors wrapped in `CometVector`
* @param out
* the output stream
*/
def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = {
var writer: Option[CometArrowStreamWriter] = None
var rowCount = 0
var prevProvider: Option[DictionaryProvider] = None

batches.zipWithIndex.foreach { case (batch, idx) =>
// scalastyle:off println
println(s"serializeBatches (idx: $idx): batch.numCols: ${batch.numCols()}")
for (i <- 0 until batch.numCols()) {
batch.column(i) match {
case a: CometPlainVector =>
val valueVector = a.getValueVector
println(s"serializeBatches: valueVector: $valueVector")

case a: CometDictionaryVector =>
val indices = a.indices
val dictionary = a.values
println(s"serializeBatches: indices: ${indices.getValueVector}")
println(s"serializeBatches: dictionary: ${dictionary.getValueVector}")

val dictId = indices.getValueVector.getField.getDictionary.getId
println(s"serializeBatches: dictionary dictId: $dictId")
case _ =>

}
}

val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
val root = new VectorSchemaRoot(fieldVectors.asJava)
if (prevProvider.isDefined && prevProvider.get !=
batchProviderOpt.getOrElse(dictionaryProvider)) {
throw new SparkException(
"Comet execution only takes Arrow Arrays with the same dictionary provider")
} else {
prevProvider = batchProviderOpt
}
val provider = batchProviderOpt.getOrElse(dictionaryProvider)

// scalastyle:off println
println(s"serializeBatches (idx: $idx): provider: ${provider.getDictionaryIds}")

for (id <- provider.getDictionaryIds.asScala) {
val dictionary = provider.lookup(id)
val vector = dictionary.getVector()
println(s"serializeBatches (idx: $idx): dictionary id: $id, value: $vector")
}

if (writer.isEmpty) {
writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out)))
writer.get.start()
writer.get.writeBatch()
} else {
writer.get.writeMoreBatch(root)
}

root.clear()
rowCount += batch.numRows()
}

writer.map(_.end())

rowCount
}
def getDictionaryProvider: DictionaryProvider = dictionaryProvider

def getBatchFieldVectors(
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
@@ -119,9 +43,6 @@ class NativeUtil {
case a: CometVector =>
val valueVector = a.getValueVector
if (valueVector.getField.getDictionary != null) {
// scalastyle:off println
// println(s"Dictionary is not null: $valueVector")

if (provider.isEmpty) {
provider = Some(a.getDictionaryProvider)
} else {
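
With this commit, NativeUtil no longer writes Arrow streams itself; it only exposes getBatchFieldVectors plus the new getDictionaryProvider accessor. For orientation, here is a minimal sketch (the writeSingleBatch helper, its object wrapper, and the imports are illustrative, not part of the commit) of how a caller serializes one batch with Arrow's stock ArrowStreamWriter, the same pattern the commit adds to operators.scala below:

```scala
import java.io.OutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.comet.vector.NativeUtil
import org.apache.spark.sql.vectorized.ColumnarBatch

object NativeUtilUsageSketch {
  // Serialize one ColumnarBatch as a single-batch Arrow IPC stream, using only the
  // helpers that remain on NativeUtil after this commit.
  def writeSingleBatch(nativeUtil: NativeUtil, batch: ColumnarBatch, out: OutputStream): Long = {
    val (fieldVectors, providerOpt) = nativeUtil.getBatchFieldVectors(batch)
    val root = new VectorSchemaRoot(fieldVectors.asJava)
    // Fall back to NativeUtil's own dictionary provider when the batch carries none.
    val provider = providerOpt.getOrElse(nativeUtil.getDictionaryProvider)

    val writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out))
    writer.start()
    writer.writeBatch()
    root.clear()
    writer.end()

    batch.numRows().toLong
  }
}
```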
4 changes: 3 additions & 1 deletion spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -95,6 +95,8 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
@transient
private lazy val maxBroadcastRows = 512000000

private lazy val childRDD = child.asInstanceOf[CometExec].executeColumnar()

@transient
override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
@@ -191,7 +193,7 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()

new CometBatchRDD(sparkContext, broadcasted.value.length, broadcasted)
new CometBatchRDD(sparkContext, childRDD.getNumPartitions, broadcasted)
}

override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
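
The partition-count change lines up with the per-batch serialization introduced in operators.scala below: each serialized ChunkedByteBuffer now appears to correspond to one batch rather than one partition, so the length of the collected broadcast array is no longer a reliable partition count and the RDD takes it from childRDD.getNumPartitions instead. A toy illustration (made-up numbers, not from the commit):

```scala
// Toy numbers only: with per-batch serialization, the collected buffer count tracks
// batches, not partitions, so it is no longer a valid partition count.
val batchesPerPartition = Seq(2, 1, 3)             // a child RDD with 3 partitions
val collectedBufferCount = batchesPerPartition.sum // 6 buffers in broadcasted.value
val numPartitions = batchesPerPartition.length     // 3, i.e. childRDD.getNumPartitions
assert(collectedBufferCount != numPartitions)
```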
39 changes: 33 additions & 6 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -19,12 +19,15 @@

package org.apache.spark.sql.comet

import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream, OutputStream}
import java.nio.ByteBuffer
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd.RDD
@@ -77,20 +80,44 @@ abstract class CometExec extends CometPlan {
*/
def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
executeColumnar().mapPartitionsInternal { iter =>
serializeBatches(iter)
}
}

/**
* Serializes each `ColumnarBatch` in the given iterator into its own compressed Arrow IPC stream.
*
* @param batches
* the output batches, each batch is a list of Arrow vectors wrapped in `CometVector`
* @return
* an iterator of (row count, serialized bytes) pairs, one per input batch
*/
def serializeBatches(batches: Iterator[ColumnarBatch]): Iterator[(Long, ChunkedByteBuffer)] = {
val nativeUtil = new NativeUtil()

batches.map { batch =>
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
val out = new DataOutputStream(codec.compressedOutputStream(cbbos))

// scalastyle:off println
println(s"getByteArrayRdd: $this")
val count = new NativeUtil().serializeBatches(iter, out)
val (fieldVectors, batchProviderOpt) = nativeUtil.getBatchFieldVectors(batch)
val root = new VectorSchemaRoot(fieldVectors.asJava)
val provider = batchProviderOpt.getOrElse(nativeUtil.getDictionaryProvider)

val writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out))
writer.start()
writer.writeBatch()

root.clear()
writer.end()

out.flush()
out.close()

if (out.size() > 0) {
Iterator((count, cbbos.toChunkedByteBuffer))
(batch.numRows(), cbbos.toChunkedByteBuffer)
} else {
Iterator((count, new ChunkedByteBuffer(Array.empty[ByteBuffer])))
(batch.numRows(), new ChunkedByteBuffer(Array.empty[ByteBuffer]))
}
}
}
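
Not part of this commit, but for reference: a minimal sketch of decoding one element of the resulting RDD[(Long, ChunkedByteBuffer)], assuming Spark's CompressionCodec and Arrow's ArrowStreamReader APIs. The readBatch name and its object wrapper are hypothetical.

```scala
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.SparkEnv
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.io.ChunkedByteBuffer

object SerializedBatchReaderSketch {
  // Decode one (rowCount, buffer) element produced by serializeBatches: the buffer holds
  // a codec-compressed Arrow IPC stream containing a single record batch.
  def readBatch(rowCount: Long, buffer: ChunkedByteBuffer): Unit = {
    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
    val in = codec.compressedInputStream(buffer.toInputStream())
    val allocator = new RootAllocator(Long.MaxValue)
    val reader = new ArrowStreamReader(in, allocator)
    try {
      while (reader.loadNextBatch()) {
        val root = reader.getVectorSchemaRoot
        assert(root.getRowCount.toLong == rowCount, "row count recorded in the RDD should match the batch")
      }
    } finally {
      reader.close()
      allocator.close()
    }
  }
}
```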
