This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-183] Add Date/Timestamp type support #347

Merged
11 commits merged on Jun 17, 2021
Changes from 2 commits

@@ -19,25 +19,21 @@

import java.lang.*;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.*;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.holders.NullableVarCharHolder;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.types.pojo.Field;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.util.ArrowUtils;
import org.apache.spark.sql.vectorized.*;
@@ -84,7 +80,7 @@ public static BufferAllocator getOffRecordAllocator() {
*/
public static ArrowWritableColumnVector[] allocateColumns(
int capacity, StructType schema) {
String timeZoneId = SQLConf.get().sessionLocalTimeZone();
String timeZoneId = SparkSchemaUtils.getGandivaCompatibleTimeZoneID();
Schema arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId);
VectorSchemaRoot new_root =
VectorSchemaRoot.create(arrowSchema, SparkMemoryUtils.contextAllocator());
@@ -169,7 +165,7 @@ public ArrowWritableColumnVector(int capacity, DataType dataType) {
super(capacity, dataType);
vectorCount.getAndIncrement();
refCnt.getAndIncrement();
String timeZoneId = SQLConf.get().sessionLocalTimeZone();
String timeZoneId = SparkSchemaUtils.getGandivaCompatibleTimeZoneID();
List<Field> fields =
Arrays.asList(ArrowUtils.toArrowField("col", dataType, true, timeZoneId));
Schema arrowSchema = new Schema(fields);
@@ -232,8 +228,8 @@ private void createVectorAccessor(ValueVector vector, ValueVector dictionary) {
accessor = new BinaryAccessor((VarBinaryVector) vector);
} else if (vector instanceof DateDayVector) {
accessor = new DateAccessor((DateDayVector) vector);
} else if (vector instanceof TimeStampMicroTZVector) {
accessor = new TimestampAccessor((TimeStampMicroTZVector) vector);
} else if (vector instanceof TimeStampMicroVector) {
accessor = new TimestampAccessor((TimeStampMicroVector) vector);
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
accessor = new ArrayAccessor(listVector);
@@ -274,8 +270,8 @@ private ArrowVectorWriter createVectorWriter(ValueVector vector) {
return new BinaryWriter((VarBinaryVector) vector);
} else if (vector instanceof DateDayVector) {
return new DateWriter((DateDayVector) vector);
} else if (vector instanceof TimeStampMicroTZVector) {
return new TimestampWriter((TimeStampMicroTZVector) vector);
} else if (vector instanceof TimeStampVector) {
return new TimestampWriter((TimeStampVector) vector);
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
ArrowVectorWriter elementVector = createVectorWriter(listVector.getDataVector());
@@ -1144,9 +1140,9 @@ final UTF8String getUTF8String(int rowId) {
}

private static class TimestampAccessor extends ArrowVectorAccessor {
private final TimeStampMicroTZVector accessor;
private final TimeStampMicroVector accessor;

TimestampAccessor(TimeStampMicroTZVector vector) {
TimestampAccessor(TimeStampMicroVector vector) {
super(vector);
this.accessor = vector;
}
@@ -1798,9 +1794,9 @@ final void setNulls(int rowId, int count) {
}

private static class TimestampWriter extends ArrowVectorWriter {
private final TimeStampMicroTZVector writer;
private final TimeStampVector writer;

TimestampWriter(TimeStampMicroTZVector vector) {
TimestampWriter(TimeStampVector vector) {
super(vector);
this.writer = vector;
}
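As a side note on the change above: the accessor now targets Arrow's `TimeStampMicroVector` and the writer the more general `TimeStampVector`, instead of the time-zone-aware `TimeStampMicroTZVector`. Below is a minimal standalone sketch of that vector type, not part of this diff, assuming only the Arrow Java library:

```scala
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.TimeStampMicroVector

object TimeStampMicroSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)
    // TimeStampMicroVector holds epoch microseconds with no time-zone metadata,
    // which is why the TZ-aware variant is no longer required here.
    val vector = new TimeStampMicroVector("ts", allocator)
    vector.allocateNew(2)
    vector.setSafe(0, 1623888000000000L) // 2021-06-17 00:00:00 UTC, in microseconds
    vector.setNull(1)
    vector.setValueCount(2)
    println(vector.get(0))    // 1623888000000000
    println(vector.isNull(1)) // true
    vector.close()
    allocator.close()
  }
}
```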
@@ -17,8 +17,11 @@

package org.apache.spark.sql.execution.datasources.v2.arrow

import java.util.TimeZone

import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

@@ -31,4 +34,18 @@ object SparkSchemaUtils {
def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
ArrowUtils.toArrowSchema(schema, timeZoneId)
}

def getGandivaCompatibleTimeZoneID(): String = {
val zone = SQLConf.get.sessionLocalTimeZone
if (TimeZone.getTimeZone(zone)
.toZoneId
.getRules
.getOffset(java.time.Instant.now())
.getTotalSeconds != 0) {
throw new RuntimeException("Running Spark with Native SQL engine in non-UTC timezone" +
" environment is forbidden. Consider setting session timezone within Spark config " +
"spark.sql.session.timeZone. E.g. spark.sql.session.timeZone = UTC")
}
null
}
}
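The new `getGandivaCompatibleTimeZoneID` helper above refuses to run when the session time zone has a non-zero UTC offset at the current instant, and otherwise returns `null` so that no time-zone ID is attached to Arrow timestamp fields. A standalone sketch of the same check, not part of this diff, with the zone passed in directly instead of read from `SQLConf` (the `requireUtcZone` name is hypothetical):

```scala
import java.time.Instant
import java.util.TimeZone

object UtcGuardSketch {
  // Hypothetical stand-in for the guard: throws for any zone whose current
  // UTC offset is non-zero, returns null (no zone ID) otherwise.
  def requireUtcZone(zone: String): String = {
    val offsetSeconds = TimeZone.getTimeZone(zone)
      .toZoneId
      .getRules
      .getOffset(Instant.now())
      .getTotalSeconds
    if (offsetSeconds != 0) {
      throw new RuntimeException(
        s"Session time zone '$zone' is not UTC; set spark.sql.session.timeZone=UTC")
    }
    null
  }

  def main(args: Array[String]): Unit = {
    println(requireUtcZone("UTC"))           // prints null
    println(requireUtcZone("Asia/Shanghai")) // throws RuntimeException
  }
}
```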
@@ -18,6 +18,7 @@
package com.intel.oap.spark.sql.execution.datasources.v2.arrow

import java.net.URI
import java.time.ZoneId

import scala.collection.JavaConverters._

@@ -28,6 +29,7 @@ import org.apache.arrow.vector.types.pojo.Schema
import org.apache.hadoop.fs.FileStatus

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkSchemaUtils}
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
import org.apache.spark.sql.internal.SQLConf
@@ -83,7 +85,7 @@

def toArrowSchema(t: StructType): Schema = {
// fixme this might be platform dependent
SparkSchemaUtils.toArrowSchema(t, SQLConf.get.sessionLocalTimeZone)
SparkSchemaUtils.toArrowSchema(t, SparkSchemaUtils.getGandivaCompatibleTimeZoneID())
}

def loadBatch(input: ArrowRecordBatch, partitionValues: InternalRow,
@@ -26,6 +26,7 @@ import com.intel.oap.vectorized.{ArrowWritableColumnVector, CloseableColumnBatch
import org.apache.arrow.gandiva.expression.TreeBuilder
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -43,12 +44,13 @@ import org.apache.spark.sql.types.{DataType, DateType, DecimalType, DoubleType,
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ExecutorManager

import scala.collection.JavaConverters._
import scala.collection.immutable.Stream.Empty
import scala.collection.mutable.ListBuffer
import scala.util.Random

import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils

case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
@@ -221,7 +223,8 @@

val evaluator = new ExpressionEvaluator()
val resultSchema = new Schema(resultField.getChildren)
val arrowSchema = ArrowUtils.toArrowSchema(child.schema, SQLConf.get.sessionLocalTimeZone)
val arrowSchema = ArrowUtils.toArrowSchema(child.schema,
SparkSchemaUtils.getGandivaCompatibleTimeZoneID())
evaluator.build(arrowSchema,
List(TreeBuilder.makeExpression(window,
resultField)).asJava, resultSchema, true)
@@ -440,8 +443,8 @@ object ColumnarWindowExec extends Logging {
Cast(we.copy(
windowFunction =
ae.copy(aggregateFunction = Min(Cast(Cast(e, TimestampType,
Some(DateTimeUtils.TimeZoneUTC.getID)), LongType)))),
TimestampType), DateType, Some(DateTimeUtils.TimeZoneUTC.getID))
Some(SparkSchemaUtils.getGandivaCompatibleTimeZoneID())), LongType)))),
TimestampType), DateType, Some(SparkSchemaUtils.getGandivaCompatibleTimeZoneID()))
case _ => we
}
case Max(e) => e.dataType match {
@@ -454,8 +457,8 @@
Cast(we.copy(
windowFunction =
ae.copy(aggregateFunction = Max(Cast(Cast(e, TimestampType,
Some(DateTimeUtils.TimeZoneUTC.getID)), LongType)))),
TimestampType), DateType, Some(DateTimeUtils.TimeZoneUTC.getID))
Some(SparkSchemaUtils.getGandivaCompatibleTimeZoneID())), LongType)))),
TimestampType), DateType, Some(SparkSchemaUtils.getGandivaCompatibleTimeZoneID()))
case _ => we
}
case _ => we
@@ -22,12 +22,13 @@ import org.apache.arrow.vector.IntVector
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
import org.apache.arrow.vector.types.pojo.ArrowType

import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils

object CodeGeneration {
val timeZoneId = SQLConf.get.sessionLocalTimeZone
val timeZoneId = SparkSchemaUtils.getGandivaCompatibleTimeZoneID()

def getResultType(left: ArrowType, right: ArrowType): ArrowType = {
//TODO(): remove this API
@@ -18,7 +18,6 @@
package com.intel.oap.expression

import com.google.common.collect.Lists

import org.apache.arrow.gandiva.evaluator._
import org.apache.arrow.gandiva.exceptions.GandivaException
import org.apache.arrow.gandiva.expression._
@@ -31,9 +30,10 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

import org.apache.arrow.vector.types.TimeUnit

/**
* A version of add that supports columnar processing for longs.
*/
@@ -450,6 +450,13 @@ class ColumnarCast(
throw new UnsupportedOperationException(
s"${child.dataType} is not supported in castDECIMAL")
}
} else if (dataType.isInstanceOf[TimestampType]) {
val supported = List(LongType, DateType, StringType)
if (supported.indexOf(child.dataType) == -1 &&
!child.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${child.dataType} is not supported in castDECIMAL")
}
} else {
throw new UnsupportedOperationException(s"not currently supported: ${dataType}.")
}
@@ -459,7 +466,12 @@
val (child_node, childType): (TreeNode, ArrowType) =
child.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = CodeGeneration.getResultType(dataType)
val toType = CodeGeneration.getResultType(dataType)
val child_node0 = childType match {
case _: ArrowType.Timestamp =>
ConverterUtils.convertTimestampToMilli(child_node, childType)._1
case _ => child_node
}
if (dataType == StringType) {
val limitLen: java.lang.Long = childType match {
case int: ArrowType.Int if int.getBitWidth == 8 => 4
@@ -476,6 +488,7 @@
case decimal: ArrowType.Decimal =>
// Add two to precision for decimal point and negative sign
(decimal.getPrecision() + 2)
case _: ArrowType.Timestamp => 24
case _ =>
throw new UnsupportedOperationException(
s"ColumnarCast to String doesn't support ${childType}")
@@ -484,63 +497,74 @@
val funcNode =
TreeBuilder.makeFunction(
"castVARCHAR",
Lists.newArrayList(child_node, limitLenNode),
resultType)
(funcNode, resultType)
Lists.newArrayList(child_node0, limitLenNode),
toType)
(funcNode, toType)
} else if (dataType == ByteType) {
val funcNode =
TreeBuilder.makeFunction("castBYTE", Lists.newArrayList(child_node), resultType)
(funcNode, resultType)
TreeBuilder.makeFunction("castBYTE", Lists.newArrayList(child_node0), toType)
(funcNode, toType)
} else if (dataType == IntegerType) {
val funcNode = child.dataType match {
case d: DecimalType =>
val half_node = TreeBuilder.makeDecimalLiteral("0.5", 2, 1)
val round_down_node = TreeBuilder.makeFunction(
"subtract",
Lists.newArrayList(child_node, half_node),
Lists.newArrayList(child_node0, half_node),
childType)
val long_node = TreeBuilder.makeFunction(
"castBIGINT",
Lists.newArrayList(round_down_node),
new ArrowType.Int(64, true))
TreeBuilder.makeFunction("castINT", Lists.newArrayList(long_node), resultType)
TreeBuilder.makeFunction("castINT", Lists.newArrayList(long_node), toType)
case other =>
TreeBuilder.makeFunction("castINT", Lists.newArrayList(child_node), resultType)
TreeBuilder.makeFunction("castINT", Lists.newArrayList(child_node0), toType)
}
(funcNode, resultType)
(funcNode, toType)
} else if (dataType == LongType) {
val funcNode =
TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(child_node), resultType)
(funcNode, resultType)
TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(child_node0), toType)
(funcNode, toType)
//(child_node, childType)
} else if (dataType == FloatType) {
val funcNode = child.dataType match {
case d: DecimalType =>
val double_node = TreeBuilder.makeFunction(
"castFLOAT8",
Lists.newArrayList(child_node),
Lists.newArrayList(child_node0),
new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))
TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(double_node), resultType)
TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(double_node), toType)
case other =>
TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(child_node), resultType)
TreeBuilder.makeFunction("castFLOAT4", Lists.newArrayList(child_node0), toType)
}
(funcNode, resultType)
(funcNode, toType)
} else if (dataType == DoubleType) {
val funcNode =
TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(child_node), resultType)
(funcNode, resultType)
TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(child_node0), toType)
(funcNode, toType)
} else if (dataType == DateType) {
val funcNode =
TreeBuilder.makeFunction("castDATE", Lists.newArrayList(child_node), resultType)
(funcNode, resultType)
TreeBuilder.makeFunction("castDATE", Lists.newArrayList(child_node0), toType)
(funcNode, toType)
} else if (dataType.isInstanceOf[DecimalType]) {
dataType match {
case d: DecimalType =>
val dType = CodeGeneration.getResultType(d)
val funcNode =
TreeBuilder.makeFunction("castDECIMAL", Lists.newArrayList(child_node), dType)
TreeBuilder.makeFunction("castDECIMAL", Lists.newArrayList(child_node0), dType)
(funcNode, dType)
}
} else if (dataType.isInstanceOf[TimestampType]) {
val arrowTsType = toType match {
case ts: ArrowType.Timestamp => ts
case _ => throw new IllegalArgumentException("Not an Arrow timestamp type: " + toType)
}
// convert to milli, then convert to micro
val intermediateType = new ArrowType.Timestamp(TimeUnit.MILLISECOND, arrowTsType.getTimezone)
val funcNode =
TreeBuilder.makeFunction("castTIMESTAMP", Lists.newArrayList(child_node0),
intermediateType)
ConverterUtils.convertTimestampToMicro(funcNode, intermediateType)
} else {
throw new UnsupportedOperationException(s"not currently supported: ${dataType}.")
}
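One note on the timestamp cast path above: `castTIMESTAMP` is built with an intermediate MILLISECOND-precision Arrow type and then passed through `ConverterUtils.convertTimestampToMicro`, since Spark's `TimestampType` stores values as microseconds since the epoch. A minimal sketch of that unit step, not part of this diff (plain Scala arithmetic, not a Gandiva expression tree):

```scala
object TimestampUnitSketch extends App {
  // Milliseconds, as produced at the intermediate MILLISECOND step of the cast.
  val epochMillis: Long = 1623888000000L      // 2021-06-17 00:00:00 UTC
  // Microseconds, the representation Spark's TimestampType stores internally.
  val epochMicros: Long = epochMillis * 1000L // 1623888000000000
  println(epochMicros)
}
```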