[GLUTEN-8453][VL] Follow-up to #8454 to add an ensureVeloxBatch API for limited use cases (#8463)
zhztheplayer authored Jan 9, 2025
1 parent e0194dc commit 73ee147
Showing 4 changed files with 75 additions and 39 deletions.
File: VeloxColumnarBatches.java
@@ -17,6 +17,7 @@
 package org.apache.gluten.columnarbatch;
 
 import org.apache.gluten.backendsapi.BackendsApiManager;
+import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators;
 import org.apache.gluten.runtime.Runtime;
 import org.apache.gluten.runtime.Runtimes;
 
@@ -56,6 +57,7 @@ public static void checkNonVeloxBatch(ColumnarBatch batch) {
   }
 
   public static ColumnarBatch toVeloxBatch(ColumnarBatch input) {
+    ColumnarBatches.checkOffloaded(input);
     if (ColumnarBatches.isZeroColumnBatch(input)) {
       return input;
     }
@@ -86,6 +88,26 @@ public static ColumnarBatch toVeloxBatch(ColumnarBatch input) {
     return input;
   }
 
+  /**
+   * Check whether a columnar batch is in Velox format. If not, convert it to Velox format and
+   * return the converted batch. If it is already in Velox format, return the batch directly.
+   *
+   * <p>Should only be used in the limited cases where explicit to-Velox transitions cannot be
+   * inserted through the query planner.
+   *
+   * <p>For example, it is used by {@link org.apache.spark.sql.execution.ColumnarCachedBatchSerializer},
+   * since Spark calls the API ColumnarCachedBatchSerializer#convertColumnarBatchToCachedBatch
+   * directly for a query plan that reports supportsColumnar=true, without generating a cache-write
+   * query plan node.
+   */
+  public static ColumnarBatch ensureVeloxBatch(ColumnarBatch input) {
+    final ColumnarBatch light =
+        ColumnarBatches.ensureOffloaded(ArrowBufferAllocators.contextInstance(), input);
+    if (isVeloxBatch(light)) {
+      return light;
+    }
+    return toVeloxBatch(light);
+  }
+
   /**
    * Combine multiple columnar batches horizontally, assuming each of them is already offloaded.
    * Otherwise {@link UnsupportedOperationException} will be thrown.
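For illustration, here is a minimal usage sketch of the new API (the caller class is hypothetical). It assumes execution inside a Spark task, since ensureVeloxBatch draws on the task-scoped ArrowBufferAllocators.contextInstance() and, via ensureOffloaded, may close the input batch:

```java
import org.apache.gluten.columnarbatch.VeloxColumnarBatches;
import org.apache.spark.sql.vectorized.ColumnarBatch;

// Hypothetical caller, for illustration only.
final class EnsureVeloxExample {
  // Normalizes a batch of unknown flavor (heavy Arrow-backed or light
  // native-backed) into a Velox batch before handing it to native code.
  static ColumnarBatch toNativeVelox(ColumnarBatch incoming) {
    // One call replaces the manual isHeavyBatch/offload/toVeloxBatch dance.
    // Do not reuse `incoming` afterwards; it may have been closed.
    return VeloxColumnarBatches.ensureVeloxBatch(incoming);
  }
}
```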
File: ColumnarCachedBatchSerializer.scala
@@ -171,24 +171,16 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with Logging {
       conf: SQLConf): RDD[CachedBatch] = {
     input.mapPartitions {
       it =>
-        val lightBatches = it.map {
+        val veloxBatches = it.map {
           /* Native code needs a Velox offloaded batch, making sure to offload
            if heavy batch is encountered */
-          batch =>
-            val heavy = ColumnarBatches.isHeavyBatch(batch)
-            if (heavy) {
-              val offloaded = VeloxColumnarBatches.toVeloxBatch(
-                ColumnarBatches.offload(ArrowBufferAllocators.contextInstance(), batch))
-              offloaded
-            } else {
-              batch
-            }
+          batch => VeloxColumnarBatches.ensureVeloxBatch(batch)
         }
         new Iterator[CachedBatch] {
-          override def hasNext: Boolean = lightBatches.hasNext
+          override def hasNext: Boolean = veloxBatches.hasNext
 
           override def next(): CachedBatch = {
-            val batch = lightBatches.next()
+            val batch = veloxBatches.next()
             val results =
               ColumnarBatchSerializerJniWrapper
                 .create(
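For comparison, the removed branch was equivalent to the following Java paraphrase (not the literal old Scala). Note that this same commit narrows ColumnarBatches.isHeavyBatch to package-private, so the pattern can no longer be rebuilt outside the org.apache.gluten.columnarbatch package; routing callers through ensureVeloxBatch is the point of the change:

```java
import org.apache.gluten.columnarbatch.ColumnarBatches;
import org.apache.gluten.columnarbatch.VeloxColumnarBatches;
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators;
import org.apache.spark.sql.vectorized.ColumnarBatch;

// Java paraphrase of the removed Scala branch, for illustration only.
final class OldCachePathSketch {
  static ColumnarBatch toVeloxManually(ColumnarBatch batch) {
    if (ColumnarBatches.isHeavyBatch(batch)) {
      // Heavy (JVM-readable Arrow) batch: offload to native memory first,
      // then rewrite into Velox's internal format.
      ColumnarBatch offloaded =
          ColumnarBatches.offload(ArrowBufferAllocators.contextInstance(), batch);
      return VeloxColumnarBatches.toVeloxBatch(offloaded);
    }
    // Already light: the old code assumed the batch was usable as-is;
    // ensureVeloxBatch additionally verifies the Velox format here.
    return batch;
  }
}
```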
File: ArrowCsvScanSuite.scala
@@ -67,33 +67,26 @@ class ArrowCsvScanSuiteV2 extends ArrowCsvScanSuite {
   }
 }
 
-/** Since https://github.com/apache/incubator-gluten/pull/5850. */
-abstract class ArrowCsvScanSuite extends VeloxWholeStageTransformerSuite {
-  override protected val resourcePath: String = "N/A"
-  override protected val fileFormat: String = "N/A"
-
-  protected val rootPath: String = getClass.getResource("/").getPath
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    createCsvTables()
-  }
-
-  override def afterAll(): Unit = {
-    super.afterAll()
-  }
-
+class ArrowCsvScanWithTableCacheSuite extends ArrowCsvScanSuiteBase {
   override protected def sparkConf: SparkConf = {
     super.sparkConf
-      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
-      .set("spark.sql.files.maxPartitionBytes", "1g")
-      .set("spark.sql.shuffle.partitions", "1")
-      .set("spark.memory.offHeap.size", "2g")
-      .set("spark.unsafe.exceptionOnMemoryLeak", "true")
-      .set("spark.sql.autoBroadcastJoinThreshold", "-1")
-      .set(GlutenConfig.NATIVE_ARROW_READER_ENABLED.key, "true")
+      .set("spark.sql.sources.useV1SourceList", "csv")
+      .set(GlutenConfig.COLUMNAR_TABLE_CACHE_ENABLED.key, "true")
   }
 
+  /**
+   * Test for GLUTEN-8453: https://github.com/apache/incubator-gluten/issues/8453. To make sure no
+   * error is thrown when caching an Arrow Java query plan.
+   */
+  test("csv scan v1 with table cache") {
+    val df = spark.sql("select * from student")
+    df.cache()
+    assert(df.collect().length == 3)
+  }
+}
+
+/** Since https://github.com/apache/incubator-gluten/pull/5850. */
+abstract class ArrowCsvScanSuite extends ArrowCsvScanSuiteBase {
   test("csv scan with option string as null") {
     val df = runAndCompare("select * from student_option_str")()
     val plan = df.queryExecution.executedPlan
@@ -152,6 +145,33 @@ abstract class ArrowCsvScanSuite extends VeloxWholeStageTransformerSuite {
     val df = runAndCompare("select count(1) from student")()
     checkLengthAndPlan(df, 1)
   }
+}
+
+abstract class ArrowCsvScanSuiteBase extends VeloxWholeStageTransformerSuite {
+  override protected val resourcePath: String = "N/A"
+  override protected val fileFormat: String = "N/A"
+
+  protected val rootPath: String = getClass.getResource("/").getPath
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    createCsvTables()
+  }
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+  }
+
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
+      .set("spark.sql.files.maxPartitionBytes", "1g")
+      .set("spark.sql.shuffle.partitions", "1")
+      .set("spark.memory.offHeap.size", "2g")
+      .set("spark.unsafe.exceptionOnMemoryLeak", "true")
+      .set("spark.sql.autoBroadcastJoinThreshold", "-1")
+      .set(GlutenConfig.NATIVE_ARROW_READER_ENABLED.key, "true")
+  }
 
   private def createCsvTables(): Unit = {
     spark.read
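The new suite pins down the GLUTEN-8453 scenario: caching the output of an Arrow Java scan. A hypothetical minimal reproduction, assuming a SparkSession already configured like the suite's sparkConf (Gluten enabled, CSV kept on the V1 source path, columnar table cache on, and a three-row student table registered):

```java
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Hypothetical reproduction sketch for GLUTEN-8453, for illustration only.
final class TableCacheRepro {
  static void run(SparkSession spark) {
    Dataset<Row> df = spark.sql("select * from student");
    // The scan reports supportsColumnar=true, so Spark hands its batches
    // straight to ColumnarCachedBatchSerializer without a cache-write plan
    // node where a to-Velox transition could have been inserted.
    df.cache();
    // Materializing the cache used to fail before this commit because the
    // serializer could receive a non-Velox (Arrow) batch.
    List<Row> rows = df.collectAsList();
    assert rows.size() == 3; // the sample table has three rows
  }
}
```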
File: ColumnarBatches.java
@@ -85,15 +85,17 @@ private static BatchType identifyBatchType(ColumnarBatch batch) {
   }
 
   /** Heavy batch: Data is readable from JVM and formatted as Arrow data. */
-  public static boolean isHeavyBatch(ColumnarBatch batch) {
+  @VisibleForTesting
+  static boolean isHeavyBatch(ColumnarBatch batch) {
     return identifyBatchType(batch) == BatchType.HEAVY;
   }
 
   /**
    * Light batch: Data is not readable from JVM, a long int handle (which is a pointer usually) is
    * used to bind the batch to a native side implementation.
    */
-  public static boolean isLightBatch(ColumnarBatch batch) {
+  @VisibleForTesting
+  static boolean isLightBatch(ColumnarBatch batch) {
     return identifyBatchType(batch) == BatchType.LIGHT;
   }
 
@@ -128,7 +130,7 @@ public static ColumnarBatch select(String backendName, ColumnarBatch batch, int[
    * Ensure the input batch is offloaded as native-based columnar batch (See {@link IndicatorVector}
    * and {@link PlaceholderVector}). This method will close the input column batch after offloaded.
    */
-  private static ColumnarBatch ensureOffloaded(BufferAllocator allocator, ColumnarBatch batch) {
+  static ColumnarBatch ensureOffloaded(BufferAllocator allocator, ColumnarBatch batch) {
     if (ColumnarBatches.isLightBatch(batch)) {
       return batch;
     }
@@ -140,7 +142,7 @@ private static ColumnarBatch ensureOffloaded(BufferAllocator allocator, Columnar
    * take place if loading is required, which means when the input batch is not loaded yet. This
    * method will close the input column batch after loaded.
    */
-  private static ColumnarBatch ensureLoaded(BufferAllocator allocator, ColumnarBatch batch) {
+  static ColumnarBatch ensureLoaded(BufferAllocator allocator, ColumnarBatch batch) {
     if (isHeavyBatch(batch)) {
       return batch;
     }
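With isHeavyBatch and isLightBatch reduced to package-private @VisibleForTesting members, only same-package code can still assert on batch flavor. A hypothetical same-package sketch of the heavy/light round trip described by the Javadocs above (the arguments are stand-ins for test fixtures):

```java
package org.apache.gluten.columnarbatch;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.spark.sql.vectorized.ColumnarBatch;

// Hypothetical same-package sketch, for illustration only.
final class BatchTypeRoundTripSketch {
  static void checkRoundTrip(ColumnarBatch heavy, BufferAllocator allocator) {
    // A JVM-readable Arrow batch is "heavy".
    assert ColumnarBatches.isHeavyBatch(heavy);

    // ensureOffloaded moves the data behind a native handle ("light")
    // and closes the heavy input.
    ColumnarBatch light = ColumnarBatches.ensureOffloaded(allocator, heavy);
    assert ColumnarBatches.isLightBatch(light);

    // ensureLoaded is the inverse: it loads the data back into JVM memory
    // as Arrow and closes the light input.
    ColumnarBatch loaded = ColumnarBatches.ensureLoaded(allocator, light);
    assert ColumnarBatches.isHeavyBatch(loaded);
  }
}
```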
