[KYUUBI #4662] [ARROW] Arrow serialization should not introduce extra shuffle for outermost limit

### _Why are the changes needed?_

The core idea is to collect the result of an outermost limit the same way `CollectLimitExec.executeCollect()` does, so that no extra shuffle is introduced. A minimal sketch of the idea follows the example query below.

```sql
select * from parquet.`parquet/tpcds/sf1000/catalog_sales` limit 20;
```
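
In other words, when the limit is the outermost operator, the engine should take rows on the driver the way `CollectLimitExec.executeCollect()` does, instead of materializing the limited result through an RDD action, which plans a single-partition shuffle. A minimal sketch of the idea (not the actual patch, and assuming AQE is disabled so the planner's root is a plain `CollectLimitExec`):

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan}

// Hypothetical helper: collect an outermost `LIMIT n` on the driver, the way
// CollectLimitExec.executeCollect() does, instead of running an RDD action
// that adds an exchange.
def takeOnDriver(df: DataFrame, n: Int): Array[InternalRow] =
  df.queryExecution.executedPlan match {
    // For `select ... limit n` the planner puts CollectLimitExec at the root;
    // executeCollect() scans a few partitions at a time on the driver and
    // never shuffles.
    case limit: CollectLimitExec => limit.executeCollect()
    // Fallback for any other root plan: executeTake(n) behaves similarly.
    case plan: SparkPlan => plan.executeTake(n)
  }
```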

Before this PR:
![Screenshot 2023-04-04 15.20.34](https://user-images.githubusercontent.com/8537877/229717946-87c480c6-9550-4d00-bc96-14d59d7ce9f7.png)

![Screenshot 2023-04-04 15.20.54](https://user-images.githubusercontent.com/8537877/229717973-bf6da5af-74e7-422a-b9fa-8b7bebd43320.png)

After this PR:

![Screenshot 2023-04-04 15.17.05](https://user-images.githubusercontent.com/8537877/229718016-6218d019-b223-4deb-b596-6f0431d33d2a.png)

![Screenshot 2023-04-04 15.17.16](https://user-images.githubusercontent.com/8537877/229718046-ea07cd1f-5ffc-42ba-87d5-08085feb4b16.png)

### _How was this patch tested?_
- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #4662 from cfmcgrady/arrow-collect-limit-exec-2.

Closes #4662

82c912e [Fu Chen] close vector
130bcb1 [Fu Chen] finally close
facc13f [Fu Chen] exclude rule OptimizeLimitZero
3700839 [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
6064ab9 [Fu Chen] limit = 0 test case
6d596fc [Fu Chen] address comment
8280783 [Fu Chen] add `isStaticConfigKey` to adapt Spark-3.1.x
22cc70f [Fu Chen] add ut
b72bc6f [Fu Chen] add offset support to adapt Spark-3.4.x
9ffb44f [Fu Chen] make toBatchIterator private
c83cf3f [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
573a262 [Fu Chen] fix
4cef204 [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
d70aee3 [Fu Chen] SparkPlan.session -> SparkSession.active to adapt Spark-3.1.x
e3bf84c [Fu Chen] refactor
81886f0 [Fu Chen] address comment
2286afc [Fu Chen] reflective calla AdaptiveSparkPlanExec.finalPhysicalPlan
03d0747 [Fu Chen] address comment
25e4f05 [Fu Chen] add docs
885cf2c [Fu Chen] infer row size by schema.defaultSize
4e7ca54 [Fu Chen] unnecessarily changes
ee5a756 [Fu Chen] revert unnecessarily changes
6c5b1eb [Fu Chen] add ut
4212a89 [Fu Chen] refactor and add ut
ed8c692 [Fu Chen] refactor
0088671 [Fu Chen] refine
8593d85 [Fu Chen] driver slice last batch
a584943 [Fu Chen] arrow take

Authored-by: Fu Chen <cfmcgrady@gmail.com>
Signed-off-by: ulyssesyou <ulyssesyou@apache.org>
cfmcgrady authored and ulysses-you committed Apr 10, 2023
1 parent 5faebb1 commit 1a65125
Showing 7 changed files with 753 additions and 33 deletions.
7 changes: 7 additions & 0 deletions externals/kyuubi-spark-sql-engine/pom.xml
@@ -65,6 +65,13 @@
<scope>provided</scope>
</dependency>

+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-repl_${scala.binary.version}</artifactId>
@@ -21,10 +21,8 @@ import java.util.concurrent.RejectedExecutionException

import scala.collection.JavaConverters._

- import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
- import org.apache.spark.sql.execution.SQLExecution
- import org.apache.spark.sql.kyuubi.SparkDatasetHelper
+ import org.apache.spark.sql.kyuubi.SparkDatasetHelper._
import org.apache.spark.sql.types._

import org.apache.kyuubi.{KyuubiSQLException, Logging}
@@ -187,42 +185,22 @@ class ArrowBasedExecuteStatement(
  handle) {

  override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
-   collectAsArrow(convertComplexType(resultDF)) { rdd =>
-     rdd.toLocalIterator
-   }
+   toArrowBatchLocalIterator(convertComplexType(resultDF))
  }

  override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
-   collectAsArrow(convertComplexType(resultDF)) { rdd =>
-     rdd.collect()
-   }
+   executeCollect(convertComplexType(resultDF))
  }

  override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
-   // this will introduce shuffle and hurt performance
-   val limitedResult = resultDF.limit(maxRows)
-   collectAsArrow(convertComplexType(limitedResult)) { rdd =>
-     rdd.collect()
-   }
- }
-
- /**
-  * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
-  * operation, so that we can track the arrow-based queries on the UI tab.
-  */
- private def collectAsArrow[T](df: DataFrame)(action: RDD[Array[Byte]] => T): T = {
-   SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
-     df.queryExecution.executedPlan.resetMetrics()
-     action(SparkDatasetHelper.toArrowBatchRdd(df))
-   }
+   executeCollect(convertComplexType(resultDF.limit(maxRows)))
  }

  override protected def isArrowBasedOperation: Boolean = true

  override val resultFormat = "arrow"

  private def convertComplexType(df: DataFrame): DataFrame = {
-   SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
+   convertTopLevelComplexTypeToHiveString(df, timestampAsString)
  }

}
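
For context on the removed comment ("this will introduce shuffle and hurt performance"): collecting `resultDF.limit(maxRows)` through an RDD-based Arrow action runs `CollectLimitExec.doExecute()`, which shuffles each partition's locally limited rows into a single partition, while `CollectLimitExec.executeCollect()` takes rows incrementally on the driver. A rough, hypothetical illustration using plain Spark APIs (not the Kyuubi helpers above):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.CollectLimitExec

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark.range(0, 1000000).toDF("id")

df.limit(20).queryExecution.executedPlan match {
  case limit: CollectLimitExec =>
    // RDD path (what the old takeResult effectively triggered through
    // toArrowBatchRdd): doExecute() shuffles the locally limited rows of
    // every partition into a single partition.
    println(s"child partitions: ${limit.child.execute().getNumPartitions}")
    println(s"RDD path partitions: ${limit.execute().getNumPartitions}") // 1, after a shuffle

    // Driver path (what the new code mimics for Arrow batches):
    // executeCollect() scans partitions incrementally, no exchange is added.
    println(s"driver path rows: ${limit.executeCollect().length}")
  case other =>
    // With AQE enabled the root is AdaptiveSparkPlanExec instead.
    println(s"root plan: ${other.nodeName}")
}
```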
@@ -0,0 +1,321 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.arrow

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.arrow.vector._
import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.CollectLimitExec
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils

object KyuubiArrowConverters extends SQLConfHelper with Logging {

type Batch = (Array[Byte], Long)

/**
* Slices the input Arrow record batch byte array `bytes`, returning a new serialized batch
* that contains `length` rows starting at `start` (an illustrative example follows the
* method body).
*/
def slice(
schema: StructType,
timeZoneId: String,
bytes: Array[Byte],
start: Int,
length: Int): Array[Byte] = {
val in = new ByteArrayInputStream(bytes)
val out = new ByteArrayOutputStream(bytes.length)

var vectorSchemaRoot: VectorSchemaRoot = null
var slicedVectorSchemaRoot: VectorSchemaRoot = null

val sliceAllocator = ArrowUtils.rootAllocator.newChildAllocator(
"slice",
0,
Long.MaxValue)
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
vectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, sliceAllocator)
try {
val recordBatch = MessageSerializer.deserializeRecordBatch(
new ReadChannel(Channels.newChannel(in)),
sliceAllocator)
val vectorLoader = new VectorLoader(vectorSchemaRoot)
vectorLoader.load(recordBatch)
recordBatch.close()
slicedVectorSchemaRoot = vectorSchemaRoot.slice(start, length)

val unloader = new VectorUnloader(slicedVectorSchemaRoot)
val writeChannel = new WriteChannel(Channels.newChannel(out))
val batch = unloader.getRecordBatch()
MessageSerializer.serialize(writeChannel, batch)
batch.close()
out.toByteArray()
} finally {
in.close()
out.close()
if (vectorSchemaRoot != null) {
vectorSchemaRoot.getFieldVectors.asScala.foreach(_.close())
vectorSchemaRoot.close()
}
if (slicedVectorSchemaRoot != null) {
slicedVectorSchemaRoot.getFieldVectors.asScala.foreach(_.close())
slicedVectorSchemaRoot.close()
}
sliceAllocator.close()
}
}
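
// Illustrative use of `slice` (values are hypothetical): if a serialized batch holds
// 25 rows but only 20 more rows are needed to satisfy the limit,
// slice(schema, timeZoneId, bytes, 0, 20) returns a new serialized batch containing
// just the first 20 rows. This is how the driver can trim surplus rows off the last
// collected batch.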

/**
* Forked from `org.apache.spark.sql.execution.SparkPlan#executeTake()`, the algorithm can be
* summarized in the following steps:
* 1. If the limit specified in the CollectLimitExec object is 0, the function returns an empty
* array of batches.
* 2. Otherwise, execute the child query plan of the CollectLimitExec object to obtain an RDD of
* data to collect.
* 3. Use an iterative approach to collect data in batches until the specified limit is reached.
* In each iteration, it selects a subset of the partitions of the RDD to scan and tries to
* collect data from them.
* 4. For each partition subset, we use the runJob method of the Spark context to execute a
* closure that scans the partition data and converts it to Arrow batches.
* 5. Check if the collected data reaches the specified limit. If not, it selects another subset
* of partitions to scan and repeats the process until the limit is reached or all partitions
* have been scanned.
* 6. Return an array of all the collected Arrow batches.
*
* Note that:
* 1. the total row count of the returned Arrow batches can be greater than or equal to
* `limit` when the input DataFrame holds more than `limit` rows
* 2. the `takeFromEnd` logic is not implemented
* A worked example of the partition scan schedule follows the method body.
*
* @return the collected Arrow batches, each paired with its row count
*/
def takeAsArrowBatches(
collectLimitExec: CollectLimitExec,
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
timeZoneId: String): Array[Batch] = {
val n = collectLimitExec.limit
val schema = collectLimitExec.schema
if (n == 0) {
return new Array[Batch](0)
} else {
val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
// TODO: refactor and reuse the code from RDD's take()
val childRDD = collectLimitExec.child.execute()
val buf = new ArrayBuffer[Batch]
var bufferedRowSize = 0L
val totalParts = childRDD.partitions.length
var partsScanned = 0
while (bufferedRowSize < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = limitInitialNumPartitions
if (partsScanned > 0) {
// If we didn't find any rows after the previous iteration, multiply by
// limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
// to try, but overestimate it by 50%. We also cap the estimation in the end.
if (buf.isEmpty) {
numPartsToTry = partsScanned * limitScaleUpFactor
} else {
val left = n - bufferedRowSize
// As left > 0, numPartsToTry is always >= 1
numPartsToTry = Math.ceil(1.5 * left * partsScanned / bufferedRowSize).toInt
numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
}
}

val partsToScan =
partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)

// TODO: `SparkPlan.session` was introduced in SPARK-35798; replace `SparkSession.active` with
// `SparkPlan.session` once we drop Spark 3.1.x support.
val sc = SparkSession.active.sparkContext
val res = sc.runJob(
childRDD,
(it: Iterator[InternalRow]) => {
val batches = toBatchIterator(
it,
schema,
maxRecordsPerBatch,
maxEstimatedBatchSize,
n,
timeZoneId)
batches.map(b => b -> batches.rowCountInLastBatch).toArray
},
partsToScan)

var i = 0
while (bufferedRowSize < n && i < res.length) {
var j = 0
val batches = res(i)
while (j < batches.length && n > bufferedRowSize) {
val batch = batches(j)
val (_, batchSize) = batch
buf += batch
bufferedRowSize += batchSize
j += 1
}
i += 1
}
partsScanned += partsToScan.size
}

buf.toArray
}
}
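
// Worked example of the scan schedule above (numbers are illustrative, not from the
// patch): with limit n = 20, spark.sql.limit.initialNumPartitions = 1 and
// spark.sql.limit.scaleUpFactor = 4:
//   iteration 1: scan partition 0 only -> suppose 6 rows are buffered
//   iteration 2: numPartsToTry = ceil(1.5 * 14 * 1 / 6) = 4, capped at 1 * 4 = 4,
//                so partitions 1..4 are scanned -> suppose 25 rows are buffered
//   stop: bufferedRowSize (25) >= n (20); surplus rows can later be sliced off on
//   the driver (see `slice` above).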

/**
* Spark introduced the config `spark.sql.limit.initialNumPartitions` in 3.4.0, see SPARK-40211.
*/
private def limitInitialNumPartitions: Int = {
conf.getConfString("spark.sql.limit.initialNumPartitions", "1")
.toInt
}

/**
* Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]],
* the returned iterator exposes the row count of each emitted batch via `rowCountInLastBatch`.
*/
private def toBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
limit: Long,
timeZoneId: String): ArrowBatchIterator = {
new ArrowBatchIterator(
rowIter,
schema,
maxRecordsPerBatch,
maxEstimatedBatchSize,
limit,
timeZoneId,
TaskContext.get)
}

/**
* This class ArrowBatchIterator is derived from
* [[org.apache.spark.sql.execution.arrow.ArrowConverters.ArrowBatchWithSchemaIterator]],
* with two key differences:
* 1. there is no requirement to write the schema at the batch header
* 2. iteration halts when `rowCount` equals `limit`
*/
private[sql] class ArrowBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
limit: Long,
timeZoneId: String,
context: TaskContext)
extends Iterator[Array[Byte]] {

protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
private val allocator =
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}",
0,
Long.MaxValue)

private val root = VectorSchemaRoot.create(arrowSchema, allocator)
protected val unloader = new VectorUnloader(root)
protected val arrowWriter = ArrowWriter.create(root)

Option(context).foreach {
_.addTaskCompletionListener[Unit] { _ =>
root.close()
allocator.close()
}
}

override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || {
root.close()
allocator.close()
false
}

var rowCountInLastBatch: Long = 0
var rowCount: Long = 0

override def next(): Array[Byte] = {
val out = new ByteArrayOutputStream()
val writeChannel = new WriteChannel(Channels.newChannel(out))

rowCountInLastBatch = 0
var estimatedBatchSize = 0L
Utils.tryWithSafeFinally {

// Always write the first row.
while (rowIter.hasNext && (
// For maxEstimatedBatchSize and maxRecordsPerBatch, respect whichever is smaller.
// If the size in bytes is positive (set properly), always write the first row.
rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
// If the estimated size in bytes is 0 or negative, do not limit by batch size.
estimatedBatchSize <= 0 ||
estimatedBatchSize < maxEstimatedBatchSize ||
// If maxRecordsPerBatch is 0 or negative, do not limit by record count.
maxRecordsPerBatch <= 0 ||
rowCountInLastBatch < maxRecordsPerBatch ||
rowCount < limit)) {
val row = rowIter.next()
arrowWriter.write(row)
estimatedBatchSize += (row match {
case ur: UnsafeRow => ur.getSizeInBytes
// Trying to estimate the size of the current row
case _: InternalRow => schema.defaultSize
})
rowCountInLastBatch += 1
rowCount += 1
}
arrowWriter.finish()
val batch = unloader.getRecordBatch()
MessageSerializer.serialize(writeChannel, batch)

// Always write the Ipc options at the end.
ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)

batch.close()
} {
arrowWriter.reset()
}

out.toByteArray
}
}
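
// Resource-handling note: the Arrow root and allocator are released either by the task
// completion listener registered above or by hasNext once the iterator is exhausted
// (rowIter drained or rowCount reaching limit), so Arrow memory is freed even when the
// iterator is only partially consumed before the task finishes.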

// for testing
def fromBatchIterator(
arrowBatchIter: Iterator[Array[Byte]],
schema: StructType,
timeZoneId: String,
context: TaskContext): Iterator[InternalRow] = {
ArrowConverters.fromBatchIterator(arrowBatchIter, schema, timeZoneId, context)
}
}
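
Putting the new entry points together, a hypothetical driver-side caller could look like the sketch below (not part of this commit; the DataFrame, the batch sizing values, and the assumption that AQE is disabled are all illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.CollectLimitExec
import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
val limit = 20

spark.range(0, 1000000).selectExpr("id", "id * 2 AS doubled")
  .limit(limit).queryExecution.executedPlan match {
  case plan: CollectLimitExec =>
    // Collect serialized Arrow batches on the driver; the total row count may
    // exceed `limit`, since several partitions can contribute rows.
    val batches = KyuubiArrowConverters.takeAsArrowBatches(
      plan,
      maxRecordsPerBatch = 1000L,
      maxEstimatedBatchSize = 4L * 1024 * 1024,
      timeZoneId = timeZoneId)

    // Trim the surplus on the driver so that exactly `limit` rows remain.
    var remaining = limit.toLong
    val trimmed = batches.map { case (bytes, rowCount) =>
      val keep = math.min(remaining, rowCount)
      remaining -= keep
      KyuubiArrowConverters.slice(plan.schema, timeZoneId, bytes, 0, keep.toInt)
    }
    println(s"returned ${trimmed.length} Arrow batches")
  case other =>
    // With AQE enabled the root is AdaptiveSparkPlanExec; the actual patch
    // unwraps it reflectively, which this sketch does not attempt.
    println(s"unexpected root plan: ${other.nodeName}")
}
```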