diff --git a/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java b/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java new file mode 100644 index 0000000000..a492ce887d --- /dev/null +++ b/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; + +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.compression.NoCompressionCodec; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; + +/** + * A custom `ArrowStreamWriter` that allows writing batches from different root to the same stream. + * Arrow `ArrowStreamWriter` cannot change the root after initialization. + */ +public class CometArrowStreamWriter extends ArrowStreamWriter { + public CometArrowStreamWriter( + VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + super(root, provider, out); + } + + public void writeMoreBatch(VectorSchemaRoot root) throws IOException { + VectorUnloader unloader = + new VectorUnloader( + root, /*includeNullCount*/ true, NoCompressionCodec.INSTANCE, /*alignBuffers*/ true); + + try (ArrowRecordBatch batch = unloader.getRecordBatch()) { + writeRecordBatch(batch); + } + } +} diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 1682295f72..cc726e3e8b 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -20,6 +20,7 @@ package org.apache.comet.vector import java.io.OutputStream +import java.nio.channels.Channels import scala.collection.JavaConverters._ import scala.collection.mutable @@ -28,10 +29,11 @@ import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictiona import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector._ import org.apache.arrow.vector.dictionary.DictionaryProvider -import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.SparkException import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.comet.CometArrowStreamWriter + class NativeUtil { private val allocator = new RootAllocator(Long.MaxValue) private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider @@ -46,29 +48,27 @@ class NativeUtil { * the output stream */ def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = { - var schemaRoot: Option[VectorSchemaRoot] = None - var writer: Option[ArrowStreamWriter] = None + var writer: Option[CometArrowStreamWriter] = None var rowCount = 0 batches.foreach { batch => val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch) - val root = schemaRoot.getOrElse(new VectorSchemaRoot(fieldVectors.asJava)) + val root = new VectorSchemaRoot(fieldVectors.asJava) val provider = batchProviderOpt.getOrElse(dictionaryProvider) if (writer.isEmpty) { - writer = Some(new ArrowStreamWriter(root, provider, out)) + writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out))) writer.get.start() + writer.get.writeBatch() + } else { + writer.get.writeMoreBatch(root) } - writer.get.writeBatch() root.clear() - schemaRoot = Some(root) - rowCount += batch.numRows() } writer.map(_.end()) - schemaRoot.map(_.close()) rowCount } diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala index e8dba93e70..3922912443 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala @@ -23,7 +23,7 @@ import java.nio.channels.ReadableByteChannel import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.vector.StreamReader +import org.apache.comet.vector._ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] { @@ -50,13 +50,6 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna val nextBatch = batch.get - // Release the previous batch. - // If it is not released, when closing the reader, arrow library will complain about - // memory leak. - if (currentBatch != null) { - currentBatch.close() - } - currentBatch = nextBatch batch = None currentBatch diff --git a/core/src/execution/operators/scan.rs b/core/src/execution/operators/scan.rs index e31230c583..9581ac4054 100644 --- a/core/src/execution/operators/scan.rs +++ b/core/src/execution/operators/scan.rs @@ -26,7 +26,7 @@ use futures::Stream; use itertools::Itertools; use arrow::compute::{cast_with_options, CastOptions}; -use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions}; +use arrow_array::{make_array, Array, ArrayRef, RecordBatch, RecordBatchOptions}; use arrow_data::ArrayData; use arrow_schema::{DataType, Field, Schema, SchemaRef}; diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 7eca995579..09d92edd5c 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -31,14 +31,14 @@ import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -359,6 +359,27 @@ class CometSparkSessionExtensions op } + case op: BroadcastHashJoinExec + if isCometOperatorEnabled(conf, "broadcast_hash_join") && + op.children.forall(isCometNative(_)) => + val newOp = transform1(op) + newOp match { + case Some(nativeOp) => + CometBroadcastHashJoinExec( + nativeOp, + op, + op.leftKeys, + op.rightKeys, + op.joinType, + op.condition, + op.buildSide, + op.left, + op.right, + SerializedPlan(None)) + case None => + op + } + case c @ CoalesceExec(numPartitions, child) if isCometOperatorEnabled(conf, "coalesce") && isCometNative(child) => @@ -394,6 +415,16 @@ class CometSparkSessionExtensions u } + // For AQE broadcast stage on a Comet broadcast exchange + case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => + val newOp = transform1(s) + newOp match { + case Some(nativeOp) => + CometSinkPlaceHolder(nativeOp, s, s) + case None => + s + } + case b: BroadcastExchangeExec if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") && isCometBroadCastEnabled(conf) => diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b0692ad8e2..e1a348d83e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -29,14 +29,14 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils -import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometSinkPlaceHolder, DecimalPrecision} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1840,7 +1840,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } } - case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf, "hash_join") => + case join: HashJoin => + // `HashJoin` has only two implementations in Spark, but we check the type of the join to + // make sure we are handling the correct join type. + if (!(isCometOperatorEnabled(op.conf, "hash_join") && + join.isInstanceOf[ShuffledHashJoinExec]) && + !(isCometOperatorEnabled(op.conf, "broadcast_hash_join") && + join.isInstanceOf[BroadcastHashJoinExec])) { + return None + } + if (join.buildSide == BuildRight) { // DataFusion HashJoin assumes build side is always left. // TODO: support BuildRight @@ -1932,6 +1941,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true case _: TakeOrderedAndProjectExec => true + case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true case _: BroadcastExchangeExec => true case _ => false } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index f115b2ad92..e4e3278867 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -31,6 +31,7 @@ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike @@ -43,6 +44,7 @@ import org.apache.spark.util.io.ChunkedByteBuffer import com.google.common.base.Objects import org.apache.comet.shims.ShimCometBroadcastExchangeExec +import org.apache.comet.vector._ /** * A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a @@ -257,9 +259,28 @@ class CometBatchRDD( } override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { - val partition = split.asInstanceOf[CometBatchPartition] + new Iterator[ColumnarBatch] { + val partition = split.asInstanceOf[CometBatchPartition] + val batchesIter = partition.value.value.map(CometExec.decodeBatches(_)).toIterator + var iter: Iterator[ColumnarBatch] = null + + override def hasNext: Boolean = { + if (iter != null) { + if (iter.hasNext) { + return true + } + } + if (batchesIter.hasNext) { + iter = batchesIter.next() + return iter.hasNext + } + false + } - partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator + override def next(): ColumnarBatch = { + iter.next() + } + } } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 2d03b2e65d..4cf55d4acb 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf @@ -269,7 +269,8 @@ abstract class CometNativeExec extends CometExec { plan match { case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec | _: AQEShuffleReadExec | _: CometShuffleExchangeExec | _: CometUnionExec | - _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec => + _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec | + _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec => func(plan) case _: CometPlan => // Other Comet operators, continue to traverse the tree. @@ -625,6 +626,43 @@ case class CometHashJoinExec( Objects.hashCode(leftKeys, rightKeys, condition, left, right) } +case class CometBroadcastHashJoinExec( + override val nativeOp: Operator, + override val originalPlan: SparkPlan, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + buildSide: BuildSide, + override val left: SparkPlan, + override val right: SparkPlan, + override val serializedPlanOpt: SerializedPlan) + extends CometBinaryExec { + override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = + this.copy(left = newLeft, right = newRight) + + override def stringArgs: Iterator[Any] = + Iterator(leftKeys, rightKeys, joinType, condition, left, right) + + override def equals(obj: Any): Boolean = { + obj match { + case other: CometBroadcastHashJoinExec => + this.leftKeys == other.leftKeys && + this.rightKeys == other.rightKeys && + this.condition == other.condition && + this.buildSide == other.buildSide && + this.left == other.left && + this.right == other.right && + this.serializedPlanOpt == other.serializedPlanOpt + case _ => + false + } + } + + override def hashCode(): Int = + Objects.hashCode(leftKeys, rightKeys, condition, left, right) +} + case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan) extends CometNativeExec with LeafExecNode { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 72e67bf5ce..6660a8e150 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Hex import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec} import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec @@ -58,6 +58,66 @@ class CometExecSuite extends CometTestBase { } } + test("Broadcast HashJoin without join filter") { + withSQLConf( + CometConf.COMET_BATCH_SIZE.key -> "100", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + "spark.comet.exec.broadcast.enabled" -> "true", + "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") { + // Inner join: build left + val df1 = + sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator( + df1, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + + // Right join: build left + val df2 = + sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator( + df2, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + } + } + } + } + + test("Broadcast HashJoin with join filter") { + withSQLConf( + CometConf.COMET_BATCH_SIZE.key -> "100", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + "spark.comet.exec.broadcast.enabled" -> "true", + "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") { + // Inner join: build left + val df1 = + sql( + "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator( + df1, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + + // Right join: build left + val df2 = + sql( + "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator( + df2, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + } + } + } + } + test("HashJoin without join filter") { withSQLConf( SQLConf.PREFER_SORTMERGEJOIN.key -> "false",