From f5f1e2be578ad40daafe25c6cc1b09bb4f8bb71a Mon Sep 17 00:00:00 2001 From: Nong Li Date: Thu, 3 Mar 2016 20:40:14 -0800 Subject: [PATCH] Fix batching. --- .../spark/sql/execution/ExistingRDD.scala | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index c9ab466dda796..50672beb0b90e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -139,9 +139,14 @@ private[sql] case class PhysicalRDD( // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen // never requires UnsafeRow as input. override protected def doProduce(ctx: CodegenContext): String = { + val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" val input = ctx.freshName("input") + val idx = ctx.freshName("batchIdx") + val batch = ctx.freshName("batch") // PhysicalRDD always just has one input ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") + ctx.addMutableState("int", idx, s"$idx = 0;") val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") @@ -156,27 +161,28 @@ private[sql] case class PhysicalRDD( // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know // here which path to use. Fix this. - val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" val scanBatches = ctx.freshName("processBatches") ctx.addNewFunction(scanBatches, s""" - | private void $scanBatches($columnarBatchClz batch) throws java.io.IOException { - | int batchIdx = 0; + | private void $scanBatches() throws java.io.IOException { | while (true) { - | int numRows = batch.numRows(); - | if (batchIdx == 0) $numOutputRows.add(numRows); + | int numRows = $batch.numRows(); + | if ($idx == 0) $numOutputRows.add(numRows); | - | while (batchIdx < numRows) { - | InternalRow $row = batch.getRow(batchIdx++); + | while ($idx < numRows) { + | InternalRow $row = $batch.getRow($idx++); | ${columns.map(_.code).mkString("\n").trim} | ${consume(ctx, columns).trim} | if (shouldStop()) return; | } | - | if (!$input.hasNext()) break; - | batch = ($columnarBatchClz)$input.next(); - | batchIdx = 0; + | if (!$input.hasNext()) { + | $batch = null; + | break; + | } + | $batch = ($columnarBatchClz)$input.next(); + | $idx = 0; | } | }""".stripMargin) @@ -195,12 +201,17 @@ private[sql] case class PhysicalRDD( | }""".stripMargin) s""" - | if ($input.hasNext()) { - | Object firstValue = $input.next(); - | if (firstValue instanceof $columnarBatchClz) { - | $scanBatches(($columnarBatchClz)firstValue); + | if ($batch != null || $input.hasNext()) { + | if ($batch == null) { + | Object value = $input.next(); + | if (value instanceof $columnarBatchClz) { + | $batch = ($columnarBatchClz)value; + | $scanBatches(); + | } else { + | $scanRows((InternalRow)value); + | } | } else { - | $scanRows((InternalRow)firstValue); + | $scanBatches(); | } | } """.stripMargin