diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala index 75e73f64e540..1f9419976f29 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution import org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.columnarbatch.ColumnarBatches +import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches} import org.apache.gluten.config.GlutenConfig import org.apache.gluten.execution.{RowToVeloxColumnarExec, VeloxColumnarToRowExec} import org.apache.gluten.iterator.Iterators @@ -171,11 +171,24 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with Logging { conf: SQLConf): RDD[CachedBatch] = { input.mapPartitions { it => + val lightBatches = it.map { + /* Native code needs a Velox offloaded batch, making sure to offload + if heavy batch is encountered */ + batch => + val heavy = ColumnarBatches.isHeavyBatch(batch) + if (heavy) { + val offloaded = VeloxColumnarBatches.toVeloxBatch( + ColumnarBatches.offload(ArrowBufferAllocators.contextInstance(), batch)) + offloaded + } else { + batch + } + } new Iterator[CachedBatch] { - override def hasNext: Boolean = it.hasNext + override def hasNext: Boolean = lightBatches.hasNext override def next(): CachedBatch = { - val batch = it.next() + val batch = lightBatches.next() val results = ColumnarBatchSerializerJniWrapper .create( diff --git a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java index 3914fb155ec4..5114853363bd 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java @@ -85,8 +85,7 @@ private static BatchType identifyBatchType(ColumnarBatch batch) { } /** Heavy batch: Data is readable from JVM and formatted as Arrow data. */ - @VisibleForTesting - static boolean isHeavyBatch(ColumnarBatch batch) { + public static boolean isHeavyBatch(ColumnarBatch batch) { return identifyBatchType(batch) == BatchType.HEAVY; } @@ -94,8 +93,7 @@ static boolean isHeavyBatch(ColumnarBatch batch) { * Light batch: Data is not readable from JVM, a long int handle (which is a pointer usually) is * used to bind the batch to a native side implementation. */ - @VisibleForTesting - static boolean isLightBatch(ColumnarBatch batch) { + public static boolean isLightBatch(ColumnarBatch batch) { return identifyBatchType(batch) == BatchType.LIGHT; } @@ -230,7 +228,8 @@ public static ColumnarBatch offload(BufferAllocator allocator, ColumnarBatch inp if (input.numCols() == 0) { throw new IllegalArgumentException("batch with zero columns cannot be offloaded"); } - // Batch-offloading doesn't involve any backend-specific native code. Use the internal + // Batch-offloading doesn't involve any backend-specific native code. Use the + // internal // backend to store native batch references only. final Runtime runtime = Runtimes.contextInstance(INTERNAL_BACKEND_KIND, "ColumnarBatches#offload");