[SPARK-13255][SQL] Update vectorized reader to directly return ColumnarBatch instead of InternalRows. #11435

Status: Closed (wants to merge 10 commits)
@@ -37,7 +37,6 @@
import org.apache.parquet.schema.Type;

import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
@@ -57,10 +56,14 @@
*
* TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch.
* All of these can be handled efficiently and easily with codegen.
*
* This class can either return InternalRows or ColumnarBatches. With whole stage codegen
* enabled, this class returns ColumnarBatches which offers significant performance gains.
* TODO: make this always return ColumnarBatches.
*/
public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase<InternalRow> {
public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase<Object> {
/**
* Batch of unsafe rows that we assemble and the current index we've returned. Everytime this
* Batch of unsafe rows that we assemble and the current index we've returned. Every time this
* batch is used up (batchIdx == numBatched), we populated the batch.
*/
private UnsafeRow[] rows = new UnsafeRow[64];
@@ -115,11 +118,15 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
* code between the path that uses the MR decoders and the vectorized ones.
*
* TODOs:
* - Implement all the encodings to support vectorized.
* - Implement v2 page formats (just make sure we create the correct decoders).
*/
private ColumnarBatch columnarBatch;

/**
* If true, this class returns batches instead of rows.
*/
private boolean returnColumnarBatch;

/**
* The default config on whether columnarBatch should be offheap.
*/
@@ -169,6 +176,8 @@ public void close() throws IOException {

@Override
public boolean nextKeyValue() throws IOException, InterruptedException {
if (returnColumnarBatch) return nextBatch();

if (batchIdx >= numBatched) {
if (vectorizedDecode()) {
if (!nextBatch()) return false;
@@ -181,7 +190,9 @@ public boolean nextKeyValue() throws IOException, InterruptedException {
}

@Override
public InternalRow getCurrentValue() throws IOException, InterruptedException {
public Object getCurrentValue() throws IOException, InterruptedException {
if (returnColumnarBatch) return columnarBatch;

if (vectorizedDecode()) {
return columnarBatch.getRow(batchIdx - 1);
} else {
@@ -210,6 +221,14 @@ public ColumnarBatch resultBatch(MemoryMode memMode) {
return columnarBatch;
}

/**
* Can be called before any rows are returned to enable returning columnar batches directly.
*/
public void enableReturningBatches() {
assert(vectorizedDecode());
returnColumnarBatch = true;
}
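
A minimal sketch of how a caller could drive the reader in batch mode. The reader's construction and split initialization are omitted and assumed to have happened already; the import path for the reader is an assumption based on where the class lived at the time of this PR.

```scala
// Sketch only: drive the reader in batch mode. Assumes `reader` has already been
// created and initialized against a Parquet split; error handling is omitted.
import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader
import org.apache.spark.sql.execution.vectorized.ColumnarBatch

def countRows(reader: UnsafeRowParquetRecordReader): Long = {
  // Must be called before the first nextKeyValue(); only valid when the
  // vectorized decode path supports the schema.
  reader.enableReturningBatches()

  var total = 0L
  while (reader.nextKeyValue()) {
    // In batch mode, getCurrentValue() returns a ColumnarBatch rather than a row.
    val batch = reader.getCurrentValue.asInstanceOf[ColumnarBatch]
    total += batch.numRows()
  }
  total
}
```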

/**
* Advances to the next batch of rows. Returns false if there are no more.
*/
@@ -26,16 +26,73 @@

import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;

/**
* Utilities to help manipulate data associate with ColumnVectors. These should be used mostly
* for debugging or other non-performance critical paths.
* These utilities are mostly used to convert ColumnVectors into other formats.
*/
public class ColumnVectorUtils {
/**
* Populates the entire `col` with `row[fieldIdx]`
*/
public static void populate(ColumnVector col, InternalRow row, int fieldIdx) {
int capacity = col.capacity;
DataType t = col.dataType();

if (row.isNullAt(fieldIdx)) {
col.putNulls(0, capacity);
} else {
if (t == DataTypes.BooleanType) {
col.putBooleans(0, capacity, row.getBoolean(fieldIdx));
} else if (t == DataTypes.ByteType) {
col.putBytes(0, capacity, row.getByte(fieldIdx));
} else if (t == DataTypes.ShortType) {
col.putShorts(0, capacity, row.getShort(fieldIdx));
} else if (t == DataTypes.IntegerType) {
col.putInts(0, capacity, row.getInt(fieldIdx));
} else if (t == DataTypes.LongType) {
col.putLongs(0, capacity, row.getLong(fieldIdx));
} else if (t == DataTypes.FloatType) {
col.putFloats(0, capacity, row.getFloat(fieldIdx));
} else if (t == DataTypes.DoubleType) {
col.putDoubles(0, capacity, row.getDouble(fieldIdx));
} else if (t == DataTypes.StringType) {
UTF8String v = row.getUTF8String(fieldIdx);
byte[] bytes = v.getBytes();
for (int i = 0; i < capacity; i++) {
col.putByteArray(i, bytes);
}
} else if (t instanceof DecimalType) {
DecimalType dt = (DecimalType)t;
Decimal d = row.getDecimal(fieldIdx, dt.precision(), dt.scale());
if (dt.precision() <= Decimal.MAX_INT_DIGITS()) {
col.putInts(0, capacity, (int)d.toUnscaledLong());
} else if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
col.putLongs(0, capacity, d.toUnscaledLong());
} else {
final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
byte[] bytes = integer.toByteArray();
for (int i = 0; i < capacity; i++) {
col.putByteArray(i, bytes, 0, bytes.length);
}
}
} else if (t instanceof CalendarIntervalType) {
CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t);
col.getChildColumn(0).putInts(0, capacity, c.months);
col.getChildColumn(1).putLongs(0, capacity, c.microseconds);
} else if (t instanceof DateType) {
Date date = (Date)row.get(fieldIdx, t);
col.putInts(0, capacity, DateTimeUtils.fromJavaDate(date));
}
}
}
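
As a rough illustration of how `populate` broadcasts one partition value across an entire column; the schema and values below are made up for the example.

```scala
// Illustrative only: fill column 1 of a freshly allocated batch with the constant
// partition value 7, leaving column 0 for data scanned from the source.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVectorUtils}
import org.apache.spark.sql.types._

val schema = new StructType().add("id", LongType).add("part", IntegerType)
val batch = ColumnarBatch.allocate(schema)

// A one-field row holding the partition value for this physical partition.
val partitionValues = InternalRow(7)

// Every slot of the "part" column now holds 7, up to the column's capacity.
ColumnVectorUtils.populate(batch.column(1), partitionValues, 0)
```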

/**
* Returns the array data as the java primitive array.
* For example, an array of IntegerType will return an int[].
@@ -22,6 +22,7 @@
import org.apache.commons.lang.NotImplementedException;

import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -315,6 +316,17 @@ public int numValidRows() {
*/
public ColumnVector column(int ordinal) { return columns[ordinal]; }

/**
* Sets (replaces) the column at `ordinal` with column. This can be used to do very efficient
* projections.
*/
public void setColumn(int ordinal, ColumnVector column) {
if (column instanceof OffHeapColumnVector) {
throw new NotImplementedException("Need to ref count columns.");
}
columns[ordinal] = column;
}
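
For instance, a projection that merely reorders columns can be expressed as reference assignments. This is a sketch under the assumption that the input batch is on-heap; the column names and types are illustrative.

```scala
// Sketch: swap the first two columns of an on-heap batch without copying any data.
import org.apache.spark.sql.execution.vectorized.ColumnarBatch
import org.apache.spark.sql.types._

def swapFirstTwoColumns(input: ColumnarBatch): ColumnarBatch = {
  // Target schema with the two columns flipped; types assumed for the example.
  val projected = ColumnarBatch.allocate(
    new StructType().add("b", StringType).add("a", IntegerType))
  // setColumn stores a reference to the existing vectors; no values are copied.
  projected.setColumn(0, input.column(1))
  projected.setColumn(1, input.column(0))
  projected.setNumRows(input.numRows())
  projected
}
```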

/**
* Returns the row in this batch at `rowId`. Returned row is reused across calls.
*/
@@ -62,9 +62,6 @@ public final long nullsNativeAddress() {

@Override
public final void close() {
nulls = null;
intData = null;
doubleData = null;
}

//
@@ -139,24 +139,77 @@ private[sql] case class PhysicalRDD(
// Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen
// never requires UnsafeRow as input.
override protected def doProduce(ctx: CodegenContext): String = {
val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
val input = ctx.freshName("input")
val idx = ctx.freshName("batchIdx")
val batch = ctx.freshName("batch")
// PhysicalRDD always just has one input
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
ctx.addMutableState("int", idx, s"$idx = 0;")

val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
val numOutputRows = metricTerm(ctx, "numOutputRows")
ctx.INPUT_ROW = row
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))

// The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
// by looking at the first value of the RDD and then calling the function which will process
// the remaining. It is faster to return batches.
// TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
// here which path to use. Fix this.

Contributor: Once we return ColumnarBatch in SqlNewHadoopRDD, the counter of the number of records in SqlNewHadoopRDD will be wrong. Should we have a BatchedSqlNewHadoopRDD for this purpose?

Contributor Author: I think we need to clean this up, but let's do it in a follow-up. That counter is too expensive to maintain right now, and it's not clear why we would keep it if we maintain SQL metrics.

Contributor: They are used in different places; we can clean this up later.


val scanBatches = ctx.freshName("processBatches")
ctx.addNewFunction(scanBatches,
s"""
| private void $scanBatches() throws java.io.IOException {
| while (true) {
| int numRows = $batch.numRows();
| if ($idx == 0) $numOutputRows.add(numRows);
|
| while ($idx < numRows) {
| InternalRow $row = $batch.getRow($idx++);
| ${columns.map(_.code).mkString("\n").trim}

Contributor: We could generate code that uses the ColumnarBatch object instead of InternalRow; not sure how big the difference would be.

Contributor Author: Something to try. Would be easy enough in a follow-up.

| ${consume(ctx, columns).trim}
| if (shouldStop()) return;
| }
|
| if (!$input.hasNext()) {
| $batch = null;
| break;
| }
| $batch = ($columnarBatchClz)$input.next();
| $idx = 0;
| }
| }""".stripMargin)

val scanRows = ctx.freshName("processRows")
ctx.addNewFunction(scanRows,
s"""
| private void $scanRows(InternalRow $row) throws java.io.IOException {
| while (true) {
| $numOutputRows.add(1);
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) return;
| if (!$input.hasNext()) break;
| $row = (InternalRow)$input.next();
| }
| }""".stripMargin)

s"""
| while ($input.hasNext()) {
| InternalRow $row = (InternalRow) $input.next();
| $numOutputRows.add(1);
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) {
| return;
| if ($batch != null) {
| $scanBatches();
| } else if ($input.hasNext()) {
| Object value = $input.next();
| if (value instanceof $columnarBatchClz) {
| $batch = ($columnarBatchClz)value;
| $scanBatches();
| } else {
| $scanRows((InternalRow)value);
| }
| }
""".stripMargin
@@ -23,8 +23,9 @@ import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
@@ -33,8 +34,9 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.ExecutedCommand
import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVectorUtils}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.spark.util.collection.BitSet
@@ -220,6 +222,44 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
sparkPlan
}

/**
* Creates a ColumnarBatch that contains the values for `requiredColumns`. These columns can
* either come from `input` (columns scanned from the data source) or from the partitioning
* values (data from `partitionValues`). This is done *once* per physical partition. When
* the column is from `input`, it just references the same underlying column. When using
* partition columns, the column is populated once.
* TODO: there's probably a cleaner way to do this.
*/
private def projectedColumnBatch(
input: ColumnarBatch,
requiredColumns: Seq[Attribute],
dataColumns: Seq[Attribute],
partitionColumnSchema: StructType,
partitionValues: InternalRow) : ColumnarBatch = {
val result = ColumnarBatch.allocate(StructType.fromAttributes(requiredColumns))
var resultIdx = 0
var inputIdx = 0

while (resultIdx < requiredColumns.length) {
val attr = requiredColumns(resultIdx)
if (inputIdx < dataColumns.length && requiredColumns(resultIdx) == dataColumns(inputIdx)) {
result.setColumn(resultIdx, input.column(inputIdx))
inputIdx += 1
} else {
require(partitionColumnSchema.fields.filter(_.name.equals(attr.name)).length == 1)
var partitionIdx = 0
partitionColumnSchema.fields.foreach { f => {
if (f.name.equals(attr.name)) {
ColumnVectorUtils.populate(result.column(resultIdx), partitionValues, partitionIdx)
}
partitionIdx += 1
}}
}
resultIdx += 1
}
result
}
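
A hypothetical call, written against made-up attributes, to show the intended split between referenced data columns and populated partition columns. It would only compile where the private helper is visible (i.e. inside DataSourceStrategy); everything below is an assumption for illustration.

```scala
// Hypothetical example: requiredColumns = [a, b, p], dataColumns = [a, b],
// and a single partition column p with value 2016 for this physical partition.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.vectorized.ColumnarBatch
import org.apache.spark.sql.types._

val a = AttributeReference("a", LongType)()
val b = AttributeReference("b", StringType)()
val p = AttributeReference("p", IntegerType)()

// Stand-in for a batch produced by the vectorized Parquet reader.
val scannedBatch = ColumnarBatch.allocate(StructType.fromAttributes(Seq(a, b)))

val partitionSchema = new StructType().add("p", IntegerType)
val partitionValues = InternalRow(2016)

// Columns a and b of the result reference scannedBatch's vectors directly;
// column p is populated once with 2016 and reused for every batch.
val merged = projectedColumnBatch(
  scannedBatch, Seq(a, b, p), Seq(a, b), partitionSchema, partitionValues)
```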

private def mergeWithPartitionValues(
requiredColumns: Seq[Attribute],
dataColumns: Seq[Attribute],
@@ -239,25 +279,43 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
}
}

val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[InternalRow]) => {
val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[Object]) => {
// Note that we can't use an `UnsafeRowJoiner` to replace the following `JoinedRow` and
// `UnsafeProjection`. Because the projection may also adjust column order.
val mutableJoinedRow = new JoinedRow()
val unsafePartitionValues = UnsafeProjection.create(partitionColumnSchema)(partitionValues)
val unsafeProjection =
UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns)

iterator.map { unsafeDataRow =>
unsafeProjection(mutableJoinedRow(unsafeDataRow, unsafePartitionValues))
}
// If we are returning batches directly, we need to augment them with the partitioning
// columns. We want to do this without a row by row operation.
var columnBatch: ColumnarBatch = null
var mergedBatch: ColumnarBatch = null

iterator.map { input => {
if (input.isInstanceOf[InternalRow]) {
unsafeProjection(mutableJoinedRow(
input.asInstanceOf[InternalRow], unsafePartitionValues))
} else {
require(input.isInstanceOf[ColumnarBatch])
val inputBatch = input.asInstanceOf[ColumnarBatch]
if (inputBatch != mergedBatch) {
mergedBatch = inputBatch
columnBatch = projectedColumnBatch(inputBatch, requiredColumns,
dataColumns, partitionColumnSchema, partitionValues)
}
columnBatch.setNumRows(inputBatch.numRows())
columnBatch
}
}}
}

// This is an internal RDD whose call site the user should not be concerned with
// Since we create many of these (one per partition), the time spent on computing
// the call site may add up.
Utils.withDummyCallSite(dataRows.sparkContext) {
new MapPartitionsRDD(dataRows, mapPartitionsFunc, preservesPartitioning = false)
}
}.asInstanceOf[RDD[InternalRow]]
} else {
dataRows
}
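
The reuse check in the mapPartitionsFunc above (`inputBatch != mergedBatch`) relies on the upstream reader handing back the same ColumnarBatch instance for every call within a partition. Isolated, the caching idiom looks roughly like this; the class and names are made up for the sketch.

```scala
// Sketch of the batch-caching idiom: rebuild the merged batch only when the
// upstream reader hands us a new ColumnarBatch instance.
import org.apache.spark.sql.execution.vectorized.ColumnarBatch

class BatchMerger(project: ColumnarBatch => ColumnarBatch) {
  private var lastInput: ColumnarBatch = null
  private var merged: ColumnarBatch = null

  def apply(input: ColumnarBatch): ColumnarBatch = {
    if (input ne lastInput) {
      // New instance: re-wire the data columns and re-populate partition columns.
      lastInput = input
      merged = project(input)
    }
    // The row count can change from batch to batch even when the instance is reused.
    merged.setNumRows(input.numRows())
    merged
  }
}
```

Because the vectorized reader reuses a single batch object, `project` ends up running once per physical partition rather than once per batch, which is the point of the "done *once* per physical partition" note above.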