diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index c1d63299e..c57c00279 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -74,10 +74,38 @@ class CometSparkSessionExtensions
   override def apply(extensions: SparkSessionExtensions): Unit = {
     extensions.injectColumnar { session => CometScanColumnar(session) }
     extensions.injectColumnar { session => CometExecColumnar(session) }
+    extensions.injectColumnar { session => CometColumnarOverrideRules(session) }
     extensions.injectQueryStagePrepRule { session => CometScanRule(session) }
     extensions.injectQueryStagePrepRule { session => CometExecRule(session) }
   }
 
+  case class CometColumnarOverrideRules(session: SparkSession) extends ColumnarRule {
+    override def preColumnarTransitions: Rule[SparkPlan] = CometPreColumnarTransitions()
+    override def postColumnarTransitions: Rule[SparkPlan] = CometPostColumnarTransitions()
+  }
+
+  case class CometPreColumnarTransitions() extends Rule[SparkPlan] {
+    override def apply(sparkPlan: SparkPlan): SparkPlan = {
+      sparkPlan
+    }
+  }
+
+  /** Replace ColumnarToRowExec with CometColumnarToRowExec for CometExec inputs */
+  case class CometPostColumnarTransitions() extends Rule[SparkPlan] {
+    override def apply(sparkPlan: SparkPlan): SparkPlan = {
+      sparkPlan.transformUp {
+        case ColumnarToRowExec(child: CometScanExec) =>
+          CometColumnarToRowExec(child)
+        case ColumnarToRowExec(InputAdapter(child: CometScanExec)) =>
+          CometColumnarToRowExec(InputAdapter(child))
+        case ColumnarToRowExec(child: CometExec) =>
+          CometColumnarToRowExec(child)
+        case ColumnarToRowExec(InputAdapter(child: CometExec)) =>
+          CometColumnarToRowExec(InputAdapter(child))
+      }
+    }
+  }
+
   case class CometScanColumnar(session: SparkSession) extends ColumnarRule {
     override def preColumnarTransitions: Rule[SparkPlan] = CometScanRule(session)
   }
@@ -1072,7 +1100,12 @@ class CometSparkSessionExtensions
   case class EliminateRedundantTransitions(session: SparkSession) extends Rule[SparkPlan] {
     override def apply(plan: SparkPlan): SparkPlan = {
       val eliminatedPlan = plan transformUp {
-        case ColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) => sparkToColumnar.child
+        case CometColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec)
+            if !sparkToColumnar.child.supportsColumnar =>
+          sparkToColumnar.child
+        case ColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec)
+            if !sparkToColumnar.child.supportsColumnar =>
+          sparkToColumnar.child
         case CometSparkToColumnarExec(child: CometSparkToColumnarExec) => child
         // Spark adds `RowToColumnar` under Comet columnar shuffle. But it's redundant as the
         // shuffle takes row-based input.
@@ -1087,6 +1120,8 @@ class CometSparkSessionExtensions
       }
 
       eliminatedPlan match {
+        case CometColumnarToRowExec(child: CometCollectLimitExec) =>
+          child
         case ColumnarToRowExec(child: CometCollectLimitExec) =>
           child
         case other =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBatchScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBatchScanExec.scala
index dc1f2db8b..8fae122b8 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBatchScanExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBatchScanExec.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, DynamicPruningExpre
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.connector.read._
-import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.execution.metric._
 import org.apache.spark.sql.vectorized._
@@ -73,11 +72,11 @@ case class CometBatchScanExec(wrapped: BatchScanExec, runtimeFilters: Seq[Expres
 
   // `ReusedSubqueryExec` in Spark only call non-columnar execute.
   override def doExecute(): RDD[InternalRow] = {
-    ColumnarToRowExec(this).doExecute()
+    CometColumnarToRowExec(this).doExecute()
   }
 
   override def executeCollect(): Array[InternalRow] = {
-    ColumnarToRowExec(this).executeCollect()
+    CometColumnarToRowExec(this).executeCollect()
   }
 
   override def readerFactory: PartitionReaderFactory = wrappedScan.readerFactory
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
new file mode 100644
index 000000000..31f7f6ba7
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.util.Utils
+
+/**
+ * This is currently an identical copy of Spark's ColumnarToRowExec except for removing the
+ * code-gen features.
+ *
+ * This is moved into the Comet repo as the first step towards refactoring this to make the
+ * interactions with CometVector more efficient to avoid some JNI overhead.
+ */
+case class CometColumnarToRowExec(child: SparkPlan)
+    extends CometExec
+    with ColumnarToRowTransition
+    with CodegenSupport {
+  // supportsColumnar requires to be only called on driver side, see also SPARK-37779.
+  assert(Utils.isInRunningSparkTask || child.supportsColumnar)
+
+  override def supportsColumnar: Boolean = false
+
+  override def originalPlan: SparkPlan = child
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  // `ColumnarToRowExec` processes the input RDD directly, which is kind of a leaf node in the
+  // codegen stage and needs to do the limit check.
+  protected override def canCheckLimitNotReached: Boolean = true
+
+  private val prefetchTime: SQLMetric =
+    SQLMetrics.createNanoTimingMetric(sparkContext, "time to prefetch vectors")
+
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"),
+    "prefetchTime" -> prefetchTime)
+
+  override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+    val numInputBatches = longMetric("numInputBatches")
+    // This avoids calling `output` in the RDD closure, so that we don't need to include the entire
+    // plan (this) in the closure.
+    val localOutput = this.output
+    child.executeColumnar().mapPartitionsInternal { batches =>
+      val toUnsafe = UnsafeProjection.create(localOutput, localOutput)
+      batches.flatMap { batch =>
+        numInputBatches += 1
+        numOutputRows += batch.numRows()
+
+        // This is the original Spark code that creates an iterator over `ColumnarBatch`
+        // to provide `Iterator[InternalRow]`. The implementation uses a `ColumnarBatchRow`
+        // instance that contains an array of `ColumnVector` which will be instances of
+        // `CometVector`, which in turn is a wrapper around Arrow's `ValueVector`.
+        batch.rowIterator().asScala.map(toUnsafe)
+      }
+    }
+  }
+
+  /**
+   * Generate [[ColumnVector]] expressions for our parent to consume as rows. This is called once
+   * per [[ColumnVector]] in the batch.
+   */
+  private def genCodeColumnVector(
+      ctx: CodegenContext,
+      columnVar: String,
+      ordinal: String,
+      dataType: DataType,
+      nullable: Boolean): ExprCode = {
+    val javaType = CodeGenerator.javaType(dataType)
+    val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal)
+    val isNullVar = if (nullable) {
+      JavaCode.isNullVariable(ctx.freshName("isNull"))
+    } else {
+      FalseLiteral
+    }
+    val valueVar = ctx.freshName("value")
+    val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
+    val code = code"${ctx.registerComment(str)}" + (if (nullable) {
+      code"""
+        boolean $isNullVar = $columnVar.isNullAt($ordinal);
+        $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
+      """
+    } else {
+      code"$javaType $valueVar = $value;"
+    })
+    ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType))
+  }
+
+  /**
+   * Produce code to process the input iterator as [[ColumnarBatch]]es. This produces an
+   * [[org.apache.spark.sql.catalyst.expressions.UnsafeRow]] for each row in each batch.
+   */
+  override protected def doProduce(ctx: CodegenContext): String = {
+    // PhysicalRDD always just has one input
+    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
+
+    // metrics
+    val numOutputRows = metricTerm(ctx, "numOutputRows")
+    val numInputBatches = metricTerm(ctx, "numInputBatches")
+
+    val columnarBatchClz = classOf[ColumnarBatch].getName
+    val batch = ctx.addMutableState(columnarBatchClz, "batch")
+
+    val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0
+    val columnVectorClzs =
+      child.vectorTypes.getOrElse(Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
+    val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
+      case (columnVectorClz, i) =>
+        val colVarName = s"colInstance$i"
+        val name = ctx.addMutableState(columnVectorClz, colVarName)
+        (name, s"$name = ($columnVectorClz) $batch.column($i);")
+    }.unzip
+
+    val nextBatch = ctx.freshName("nextBatch")
+    val nextBatchFuncName = ctx.addNewFunction(
+      nextBatch,
+      s"""
+         |private void $nextBatch() throws java.io.IOException {
+         |  if ($input.hasNext()) {
+         |    $batch = ($columnarBatchClz)$input.next();
+         |    $numInputBatches.add(1);
+         |    $numOutputRows.add($batch.numRows());
+         |    $idx = 0;
+         |    ${columnAssigns.mkString("", "\n", "\n")}
+         |  }
+         |}""".stripMargin)
+
+    ctx.currentVars = null
+    val rowidx = ctx.freshName("rowIdx")
+    val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
+      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
+    }
+    val localIdx = ctx.freshName("localIdx")
+    val localEnd = ctx.freshName("localEnd")
+    val numRows = ctx.freshName("numRows")
+    val shouldStop = if (parent.needStopCheck) {
+      s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
+    } else {
+      "// shouldStop check is eliminated"
+    }
+    s"""
+       |if ($batch == null) {
+       |  $nextBatchFuncName();
+       |}
+       |while ($limitNotReachedCond $batch != null) {
+       |  int $numRows = $batch.numRows();
+       |  int $localEnd = $numRows - $idx;
+       |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+       |    int $rowidx = $idx + $localIdx;
+       |    ${consume(ctx, columnsBatchInput).trim}
+       |    $shouldStop
+       |  }
+       |  $idx = $numRows;
+       |  $batch = null;
+       |  $nextBatchFuncName();
+       |}
+     """.stripMargin
+  }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    Seq(child.executeColumnar().asInstanceOf[RDD[InternalRow]]) // Hack because of type erasure
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CometColumnarToRowExec =
+    copy(child = newChild)
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
index 49f7694bc..79c611027 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
@@ -235,7 +235,7 @@ case class CometScanExec(
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
-    ColumnarToRowExec(this).doExecute()
+    CometColumnarToRowExec(this).doExecute()
   }
 
   protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
@@ -262,7 +262,7 @@ case class CometScanExec(
   }
 
   override def executeCollect(): Array[InternalRow] = {
-    ColumnarToRowExec(this).executeCollect()
+    CometColumnarToRowExec(this).executeCollect()
   }
 
   override val nodeName: String =
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index dd1526d82..76464710c 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition
 import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
 import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode
 import org.apache.spark.sql.comet.util.Utils
-import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{BinaryExecNode, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -65,10 +65,10 @@ abstract class CometExec extends CometPlan {
   override def output: Seq[Attribute] = originalPlan.output
 
   override def doExecute(): RDD[InternalRow] =
-    ColumnarToRowExec(this).doExecute()
+    CometColumnarToRowExec(this).doExecute()
 
   override def executeCollect(): Array[InternalRow] =
-    ColumnarToRowExec(this).executeCollect()
+    CometColumnarToRowExec(this).executeCollect()
 
   override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 0d00867d1..7b74efb53 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -28,8 +28,8 @@ import scala.util.Random
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
-import org.apache.spark.sql.comet.CometProjectExec
-import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, ProjectExec, WholeStageCodegenExec}
+import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec}
+import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -752,7 +752,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       val project = cometPlan
         .asInstanceOf[WholeStageCodegenExec]
         .child
-        .asInstanceOf[ColumnarToRowExec]
+        .asInstanceOf[CometColumnarToRowExec]
         .child
         .asInstanceOf[InputAdapter]
         .child
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 1709cce61..4ff5acfb6 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -36,7 +36,7 @@ import org.apache.parquet.hadoop.example.ExampleParquetWriter
 import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 import org.apache.spark._
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER}
-import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder, CometSparkToColumnarExec}
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometColumnarToRowExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder, CometSparkToColumnarExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{ColumnarToRowExec, ExtendedMode, InputAdapter, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -174,6 +174,7 @@ abstract class CometTestBase
     wrapped.foreach {
       case _: CometScanExec | _: CometBatchScanExec =>
       case _: CometSinkPlaceHolder | _: CometScanWrapper =>
+      case _: CometColumnarToRowExec =>
       case _: CometSparkToColumnarExec =>
       case _: CometExec | _: CometShuffleExchangeExec =>
       case _: CometBroadcastExchangeExec =>