feat: Enable Comet broadcast by default (apache#213)
* feat: Remove COMET_EXEC_BROADCAST_ENABLED

* Fix

* Fix

* Update plan stability

* Fix

* Remove unused import and class

* Fix

* Remove unused imports

* Fix

* Fix scala style

* fix

* Fix

* Update diff
viirya authored and Steve Vaughan Jr committed Apr 9, 2024
1 parent cd091f8 commit 6b9c38c
Showing 31 changed files with 2,467 additions and 2,192 deletions.
51 changes: 0 additions & 51 deletions common/src/main/java/org/apache/comet/CometArrowStreamWriter.java

This file was deleted.

@@ -48,7 +48,7 @@ protected CometDecodedVector(ValueVector vector, Field valueField, boolean useDe
   }
 
   @Override
-  ValueVector getValueVector() {
+  public ValueVector getValueVector() {
     return valueVector;
   }
 
@@ -153,7 +153,7 @@ public ColumnVector getChild(int i) {
   }
 
   @Override
-  ValueVector getValueVector() {
+  public ValueVector getValueVector() {
     return delegate.getValueVector();
   }
 
@@ -163,7 +163,7 @@ public CometVector slice(int offset, int length) {
   }
 
   @Override
-  DictionaryProvider getDictionaryProvider() {
+  public DictionaryProvider getDictionaryProvider() {
     return delegate.getDictionaryProvider();
   }
 }
@@ -58,7 +58,7 @@ public CometDictionaryVector(
   }
 
   @Override
-  DictionaryProvider getDictionaryProvider() {
+  public DictionaryProvider getDictionaryProvider() {
     return this.provider;
   }
 
@@ -143,7 +143,7 @@ public byte[] getBinary(int rowId) {
   }
 
   @Override
-  CDataDictionaryProvider getDictionaryProvider() {
+  public CDataDictionaryProvider getDictionaryProvider() {
     return null;
   }
 
4 changes: 2 additions & 2 deletions common/src/main/java/org/apache/comet/vector/CometVector.java
@@ -189,11 +189,11 @@ public void close() {
     getValueVector().close();
   }
 
-  DictionaryProvider getDictionaryProvider() {
+  public DictionaryProvider getDictionaryProvider() {
     throw new UnsupportedOperationException("Not implemented");
   }
 
-  abstract ValueVector getValueVector();
+  public abstract ValueVector getValueVector();
 
   /**
    * Returns a zero-copying new vector that contains the values from [offset, offset + length).
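Note: these two accessors were previously package-private. Making them public matters because the Arrow serialization helpers now live outside org.apache.comet.vector (see the Utils changes below). A minimal sketch of the kind of cross-package call this enables; the object and method names are made up for illustration:

import org.apache.comet.vector.CometVector
import org.apache.spark.sql.vectorized.ColumnarBatch

object VectorAccessExample {
  // Collect the underlying Arrow ValueVectors of a batch. From a package other than
  // org.apache.comet.vector this only compiles because getValueVector is now public.
  def arrowVectors(batch: ColumnarBatch) =
    (0 until batch.numCols()).map(i => batch.column(i).asInstanceOf[CometVector].getValueVector)
}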
9 changes: 5 additions & 4 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -139,12 +139,13 @@ object CometConf {
       .booleanConf
       .createWithDefault(false)
 
-  val COMET_EXEC_BROADCAST_ENABLED: ConfigEntry[Boolean] =
+  val COMET_EXEC_BROADCAST_FORCE_ENABLED: ConfigEntry[Boolean] =
     conf(s"$COMET_EXEC_CONFIG_PREFIX.broadcast.enabled")
       .doc(
-        "Whether to enable broadcasting for Comet native operators. By default, " +
-          "this config is false. Note that this feature is not fully supported yet " +
-          "and only enabled for test purpose.")
+        "Whether to force enabling broadcasting for Comet native operators. By default, " +
+          "this config is false. Comet broadcast feature will be enabled automatically by " +
+          "Comet extension. But for unit tests, we need this feature to force enabling it " +
+          "for invalid cases. So this config is only used for unit test.")
       .booleanConf
       .createWithDefault(false)
 
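As a rough illustration only (assuming COMET_EXEC_CONFIG_PREFIX resolves to spark.comet.exec, which this hunk does not show), a unit test body or spark-shell session could force-enable Comet broadcast like this:

import org.apache.spark.sql.SparkSession

// Sketch only: a local test session with Comet broadcast force-enabled. Adjust the key if
// the config prefix differs; normal jobs rely on the Comet extension enabling broadcast
// automatically instead.
val spark = SparkSession.builder()
  .master("local[1]")
  .config("spark.comet.exec.broadcast.enabled", "true")
  .getOrCreate()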
85 changes: 3 additions & 82 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -19,89 +19,21 @@

package org.apache.comet.vector

import java.io.OutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data}
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.spark.SparkException
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.CometArrowStreamWriter

class NativeUtil {
import Utils._

private val allocator = new RootAllocator(Long.MaxValue)
private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
private val importer = new ArrowImporter(allocator)

/**
* Serializes a list of `ColumnarBatch` into an output stream.
*
* @param batches
* the output batches, each batch is a list of Arrow vectors wrapped in `CometVector`
* @param out
* the output stream
*/
def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = {
var writer: Option[CometArrowStreamWriter] = None
var rowCount = 0

batches.foreach { batch =>
val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
val root = new VectorSchemaRoot(fieldVectors.asJava)
val provider = batchProviderOpt.getOrElse(dictionaryProvider)

if (writer.isEmpty) {
writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out)))
writer.get.start()
writer.get.writeBatch()
} else {
writer.get.writeMoreBatch(root)
}

root.clear()
rowCount += batch.numRows()
}

writer.map(_.end())

rowCount
}

def getBatchFieldVectors(
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
var provider: Option[DictionaryProvider] = None
val fieldVectors = (0 until batch.numCols()).map { index =>
batch.column(index) match {
case a: CometVector =>
val valueVector = a.getValueVector
if (valueVector.getField.getDictionary != null) {
if (provider.isEmpty) {
provider = Some(a.getDictionaryProvider)
} else {
if (provider.get != a.getDictionaryProvider) {
throw new SparkException(
"Comet execution only takes Arrow Arrays with the same dictionary provider")
}
}
}

getFieldVector(valueVector)

case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}
}
(fieldVectors, provider)
}

/**
* Exports a Comet `ColumnarBatch` into a list of memory addresses that can be consumed by the
* native execution.
@@ -199,15 +131,4 @@ class NativeUtil {

new ColumnarBatch(arrayVectors.toArray, maxNumRows)
}

private def getFieldVector(valueVector: ValueVector): FieldVector = {
valueVector match {
case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector |
_: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector |
_: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector |
_: FixedSizeBinaryVector | _: TimeStampMicroVector) =>
v.asInstanceOf[FieldVector]
case _ => throw new SparkException(s"Unsupported Arrow Vector: ${valueVector.getClass}")
}
}
}
@@ -51,6 +51,7 @@ case class StreamReader(channel: ReadableByteChannel) extends AutoCloseable {
       // Native shuffle always uses decimal128.
       CometVector.getVector(vector, true, arrowReader).asInstanceOf[ColumnVector]
     }.toArray
+
     val batch = new ColumnarBatch(columns)
     batch.setNumRows(root.getRowCount)
     batch
@@ -19,14 +19,26 @@
 
 package org.apache.spark.sql.comet.util
 
-import java.io.File
+import java.io.{DataOutputStream, File}
+import java.nio.ByteBuffer
+import java.nio.channels.Channels
 
 import scala.collection.JavaConverters._
 
+import org.apache.arrow.c.CDataDictionaryProvider
+import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
 import org.apache.arrow.vector.complex.MapVector
+import org.apache.arrow.vector.dictionary.DictionaryProvider
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.arrow.vector.types._
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.spark.{SparkEnv, SparkException}
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
+
+import org.apache.comet.vector.CometVector
 
 object Utils {
   def getConfPath(confFileName: String): String = {
@@ -161,4 +173,79 @@ object Utils {
       toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
     }.asJava)
   }
+
+  /**
+   * Serializes a list of `ColumnarBatch` into an output stream. This method must be in `spark`
+   * package because `ChunkedByteBufferOutputStream` is spark private class. As it uses Arrow
+   * classes, it must be in `common` module.
+   *
+   * @param batches
+   *   the output batches, each batch is a list of Arrow vectors wrapped in `CometVector`
+   * @param out
+   *   the output stream
+   */
+  def serializeBatches(batches: Iterator[ColumnarBatch]): Iterator[(Long, ChunkedByteBuffer)] = {
+    batches.map { batch =>
+      val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
+
+      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+      val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
+      val out = new DataOutputStream(codec.compressedOutputStream(cbbos))
+
+      val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
+      val root = new VectorSchemaRoot(fieldVectors.asJava)
+      val provider = batchProviderOpt.getOrElse(dictionaryProvider)
+
+      val writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out))
+      writer.start()
+      writer.writeBatch()
+
+      root.clear()
+      writer.end()
+
+      out.flush()
+      out.close()
+
+      if (out.size() > 0) {
+        (batch.numRows(), cbbos.toChunkedByteBuffer)
+      } else {
+        (batch.numRows(), new ChunkedByteBuffer(Array.empty[ByteBuffer]))
+      }
+    }
+  }
+
+  def getBatchFieldVectors(
+      batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
+    var provider: Option[DictionaryProvider] = None
+    val fieldVectors = (0 until batch.numCols()).map { index =>
+      batch.column(index) match {
+        case a: CometVector =>
+          val valueVector = a.getValueVector
+          if (valueVector.getField.getDictionary != null) {
+            if (provider.isEmpty) {
+              provider = Some(a.getDictionaryProvider)
+            }
+          }
+
+          getFieldVector(valueVector)
+
+        case c =>
+          throw new SparkException(
+            "Comet execution only takes Arrow Arrays, but got " +
+              s"${c.getClass}")
+      }
+    }
+    (fieldVectors, provider)
+  }
+
+  def getFieldVector(valueVector: ValueVector): FieldVector = {
+    valueVector match {
+      case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector |
+            _: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector |
+            _: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector |
+            _: FixedSizeBinaryVector | _: TimeStampMicroVector) =>
+        v.asInstanceOf[FieldVector]
+      case _ => throw new SparkException(s"Unsupported Arrow Vector: ${valueVector.getClass}")
+    }
+  }
 }
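For orientation, a small usage sketch of the relocated helper (not part of this commit; the wrapper object and the origin of batches are hypothetical). Each returned element pairs a batch's row count with its compressed Arrow IPC bytes:

import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.ChunkedByteBuffer

object SerializeExample {
  // Requires a running SparkEnv, since serializeBatches picks the compression codec
  // from SparkEnv.get.conf.
  def broadcastPayloads(batches: Iterator[ColumnarBatch]): Iterator[(Long, ChunkedByteBuffer)] =
    Utils.serializeBatches(batches)
}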
23 changes: 21 additions & 2 deletions dev/diffs/3.4.2.diff
@@ -249,7 +249,7 @@ index 9ddb4abe98b..1bebe99f1cc 100644
sql(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index f33432ddb6f..fe9f74ff8f1 100644
index f33432ddb6f..6160c8d241a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen
@@ -270,7 +270,26 @@ index f33432ddb6f..fe9f74ff8f1 100644
case _ => Nil
}
}
@@ -1729,6 +1733,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
@@ -1238,7 +1242,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}

- test("Plan broadcast pruning only when the broadcast can be reused") {
+ test("Plan broadcast pruning only when the broadcast can be reused",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
Given("dynamic pruning filter on the build side")
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
val df = sql(
@@ -1485,7 +1490,7 @@ abstract class DynamicPartitionPruningSuiteBase
}

test("SPARK-38148: Do not add dynamic partition pruning if there exists static partition " +
- "pruning") {
+ "pruning", IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
Seq(
"f.store_id = 1" -> false,
@@ -1729,6 +1734,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
case s: BatchScanExec =>
// we use f1 col for v2 tables due to schema pruning
s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1")))
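The hunk above tags two Spark tests so they can be skipped when Comet is active. IgnoreComet is not defined in the hunks shown here; a hedged sketch of the general ScalaTest tagging pattern it builds on, with made-up names:

import org.scalatest.Tag
import org.scalatest.funsuite.AnyFunSuite

// Stand-in tag; the real IgnoreComet helper may differ in name and behavior.
object SkippedUnderComet extends Tag("SkippedUnderComet")

class ExampleSuite extends AnyFunSuite {
  test("broadcast pruning", SkippedUnderComet) {
    // excluded by running ScalaTest with: -l SkippedUnderComet
  }
}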