This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-183] Add Date/Timestamp type support (#347)
Closes #183
zhztheplayer authored Jun 17, 2021
1 parent 6d7511d commit 19bfe50
Showing 29 changed files with 969 additions and 96 deletions.
ArrowWritableColumnVector.java
@@ -19,25 +19,21 @@

import java.lang.*;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.*;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.holders.NullableVarCharHolder;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.types.pojo.Field;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.util.ArrowUtils;
import org.apache.spark.sql.vectorized.*;
@@ -84,7 +80,7 @@ public static BufferAllocator getOffRecordAllocator() {
*/
public static ArrowWritableColumnVector[] allocateColumns(
int capacity, StructType schema) {
String timeZoneId = SQLConf.get().sessionLocalTimeZone();
String timeZoneId = SparkSchemaUtils.getLocalTimezoneID();
Schema arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId);
VectorSchemaRoot new_root =
VectorSchemaRoot.create(arrowSchema, SparkMemoryUtils.contextAllocator());
@@ -169,7 +165,7 @@ public ArrowWritableColumnVector(int capacity, DataType dataType) {
super(capacity, dataType);
vectorCount.getAndIncrement();
refCnt.getAndIncrement();
String timeZoneId = SQLConf.get().sessionLocalTimeZone();
String timeZoneId = SparkSchemaUtils.getLocalTimezoneID();
List<Field> fields =
Arrays.asList(ArrowUtils.toArrowField("col", dataType, true, timeZoneId));
Schema arrowSchema = new Schema(fields);
@@ -232,8 +228,8 @@ private void createVectorAccessor(ValueVector vector, ValueVector dictionary) {
accessor = new BinaryAccessor((VarBinaryVector) vector);
} else if (vector instanceof DateDayVector) {
accessor = new DateAccessor((DateDayVector) vector);
} else if (vector instanceof TimeStampMicroTZVector) {
accessor = new TimestampAccessor((TimeStampMicroTZVector) vector);
} else if (vector instanceof TimeStampMicroVector || vector instanceof TimeStampMicroTZVector) {
accessor = new TimestampMicroAccessor((TimeStampVector) vector);
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
accessor = new ArrayAccessor(listVector);
@@ -274,8 +270,8 @@ private ArrowVectorWriter createVectorWriter(ValueVector vector) {
return new BinaryWriter((VarBinaryVector) vector);
} else if (vector instanceof DateDayVector) {
return new DateWriter((DateDayVector) vector);
} else if (vector instanceof TimeStampMicroTZVector) {
return new TimestampWriter((TimeStampMicroTZVector) vector);
} else if (vector instanceof TimeStampMicroVector || vector instanceof TimeStampMicroTZVector) {
return new TimestampMicroWriter((TimeStampVector) vector);
} else if (vector instanceof ListVector) {
ListVector listVector = (ListVector) vector;
ArrowVectorWriter elementVector = createVectorWriter(listVector.getDataVector());
@@ -288,7 +284,7 @@ private ArrowVectorWriter createVectorWriter(ValueVector vector) {
}
return new StructWriter(structVector, children);
} else {
throw new UnsupportedOperationException("Unsupported data type: ");
throw new UnsupportedOperationException("Unsupported data type: " + vector.getMinorType());
}
}

@@ -1143,10 +1139,10 @@ final UTF8String getUTF8String(int rowId) {
}
}

private static class TimestampAccessor extends ArrowVectorAccessor {
private final TimeStampMicroTZVector accessor;
private static class TimestampMicroAccessor extends ArrowVectorAccessor {
private final TimeStampVector accessor;

TimestampAccessor(TimeStampMicroTZVector vector) {
TimestampMicroAccessor(TimeStampVector vector) {
super(vector);
this.accessor = vector;
}
@@ -1797,10 +1793,10 @@ final void setNulls(int rowId, int count) {
}
}

private static class TimestampWriter extends ArrowVectorWriter {
private final TimeStampMicroTZVector writer;
private static class TimestampMicroWriter extends ArrowVectorWriter {
private final TimeStampVector writer;

TimestampWriter(TimeStampMicroTZVector vector) {
TimestampMicroWriter(TimeStampVector vector) {
super(vector);
this.writer = vector;
}
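The hunks above replace the TZ-only accessor and writer with variants typed against the common TimeStampVector base class, so both TimeStampMicroVector and TimeStampMicroTZVector are accepted. A minimal stand-alone sketch (not part of the commit; names and values are illustrative) of why one code path can serve both: each flavor stores epoch microseconds as a plain long.

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{TimeStampMicroTZVector, TimeStampMicroVector, TimeStampVector}

object TimestampMicroSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator()
    // TZ-naive and TZ-aware micro-timestamp vectors share the TimeStampVector base class.
    val naive = new TimeStampMicroVector("ts", allocator)
    val withTz = new TimeStampMicroTZVector("ts_utc", allocator, "UTC")
    val micros = 1623888000000000L // an arbitrary epoch-microsecond value

    Seq[TimeStampVector](naive, withTz).foreach { v =>
      v.allocateNew(1)
      v.setSafe(0, micros)        // write the value as a long
      v.setValueCount(1)
      assert(v.get(0) == micros)  // read back the same long, regardless of TZ flavor
    }
    naive.close(); withTz.close(); allocator.close()
  }
}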
SparkSchemaUtils.scala
@@ -17,8 +17,12 @@

package org.apache.spark.sql.execution.datasources.v2.arrow

import java.util.Objects
import java.util.TimeZone

import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

@@ -31,4 +35,46 @@ object SparkSchemaUtils {
def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
ArrowUtils.toArrowSchema(schema, timeZoneId)
}

@deprecated // experimental
def getGandivaCompatibleTimeZoneID(): String = {
val zone = SQLConf.get.sessionLocalTimeZone
validateGandivaCompatibleTimezoneID(zone)
zone
}

def getLocalTimezoneID(): String = {
SQLConf.get.sessionLocalTimeZone
}

def validateGandivaCompatibleTimezoneID(zoneId: String): Unit = {
throw new UnsupportedOperationException("not implemented") // fixme 20210602 hongze: validation is disabled for now, which makes the check below unreachable
if (!isTimeZoneIDGandivaCompatible(zoneId)) {
throw new RuntimeException("Running Spark with Native SQL engine in non-UTC timezone" +
" environment is forbidden. Consider setting session timezone within Spark config " +
"spark.sql.session.timeZone. E.g. spark.sql.session.timeZone = UTC")
}
}

def isTimeZoneIDGandivaCompatible(zoneId: String): Boolean = {
// Only UTC supported by Gandiva kernels so far
isTimeZoneIDEquivalentToUTC(zoneId)
}

def timeZoneIDEquals(one: String, other: String): Boolean = {
getTimeZoneIDOffset(one) == getTimeZoneIDOffset(other)
}

def isTimeZoneIDEquivalentToUTC(zoneId: String): Boolean = {
getTimeZoneIDOffset(zoneId) == 0
}

def getTimeZoneIDOffset(zoneId: String): Int = {
Objects.requireNonNull(zoneId)
TimeZone.getTimeZone(zoneId)
.toZoneId
.getRules
.getOffset(java.time.Instant.now())
.getTotalSeconds
}
}
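For reference, a hypothetical usage sketch of the helpers added above (not from the commit): timezone equivalence is decided by comparing current UTC offsets, not by comparing zone-ID strings.

import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils

object TimezoneHelpersSketch {
  def main(args: Array[String]): Unit = {
    assert(SparkSchemaUtils.getTimeZoneIDOffset("UTC") == 0)
    assert(SparkSchemaUtils.isTimeZoneIDEquivalentToUTC("GMT"))             // same offset as UTC
    assert(SparkSchemaUtils.timeZoneIDEquals("Etc/UTC", "GMT"))             // equal by offset, not by name
    assert(!SparkSchemaUtils.isTimeZoneIDEquivalentToUTC("Asia/Shanghai"))  // +08:00, not UTC-equivalent
  }
}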
ArrowUtils.scala
@@ -18,6 +18,7 @@
package com.intel.oap.spark.sql.execution.datasources.v2.arrow

import java.net.URI
import java.time.ZoneId

import scala.collection.JavaConverters._

@@ -28,6 +29,7 @@ import org.apache.arrow.vector.types.pojo.Schema
import org.apache.hadoop.fs.FileStatus

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkSchemaUtils}
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
import org.apache.spark.sql.internal.SQLConf
@@ -83,7 +85,7 @@ object ArrowUtils {

def toArrowSchema(t: StructType): Schema = {
// fixme this might be platform dependent
SparkSchemaUtils.toArrowSchema(t, SQLConf.get.sessionLocalTimeZone)
SparkSchemaUtils.toArrowSchema(t, SparkSchemaUtils.getLocalTimezoneID())
}

def loadBatch(input: ArrowRecordBatch, partitionValues: InternalRow,
ColumnarWindowExec.scala
@@ -26,6 +26,7 @@ import com.intel.oap.vectorized.{ArrowWritableColumnVector, CloseableColumnBatch
import org.apache.arrow.gandiva.expression.TreeBuilder
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -43,12 +44,13 @@ import org.apache.spark.sql.types.{DataType, DateType, DecimalType, DoubleType,
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ExecutorManager

import scala.collection.JavaConverters._
import scala.collection.immutable.Stream.Empty
import scala.collection.mutable.ListBuffer
import scala.util.Random

import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils

case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
@@ -221,7 +223,8 @@

val evaluator = new ExpressionEvaluator()
val resultSchema = new Schema(resultField.getChildren)
val arrowSchema = ArrowUtils.toArrowSchema(child.schema, SQLConf.get.sessionLocalTimeZone)
val arrowSchema = ArrowUtils.toArrowSchema(child.schema,
SparkSchemaUtils.getLocalTimezoneID())
evaluator.build(arrowSchema,
List(TreeBuilder.makeExpression(window,
resultField)).asJava, resultSchema, true)
@@ -440,8 +443,8 @@ object ColumnarWindowExec extends Logging {
Cast(we.copy(
windowFunction =
ae.copy(aggregateFunction = Min(Cast(Cast(e, TimestampType,
Some(DateTimeUtils.TimeZoneUTC.getID)), LongType)))),
TimestampType), DateType, Some(DateTimeUtils.TimeZoneUTC.getID))
Some(SparkSchemaUtils.getLocalTimezoneID())), LongType)))),
TimestampType), DateType, Some(SparkSchemaUtils.getLocalTimezoneID()))
case _ => we
}
case Max(e) => e.dataType match {
@@ -454,8 +457,8 @@
Cast(we.copy(
windowFunction =
ae.copy(aggregateFunction = Max(Cast(Cast(e, TimestampType,
Some(DateTimeUtils.TimeZoneUTC.getID)), LongType)))),
TimestampType), DateType, Some(DateTimeUtils.TimeZoneUTC.getID))
Some(SparkSchemaUtils.getLocalTimezoneID())), LongType)))),
TimestampType), DateType, Some(SparkSchemaUtils.getLocalTimezoneID()))
case _ => we
}
case _ => we
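The Min/Max rewrite above casts DateType through TimestampType to LongType and back, now using the session-local timezone on both casts instead of UTC. A small sketch (an assumption about why the swap is safe, not code from the commit): the days-to-microseconds conversion round-trips exactly as long as the same zone is used in both directions.

import java.time.ZoneId
import org.apache.spark.sql.catalyst.util.DateTimeUtils

object DateCastRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val zone = ZoneId.of("Asia/Shanghai") // any zone works if used consistently on both casts
    val days = 18800                      // an arbitrary DateType value (days since epoch)
    val micros = DateTimeUtils.daysToMicros(days, zone)
    assert(DateTimeUtils.microsToDays(micros, zone) == days)
  }
}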
CodeGeneration.scala
@@ -22,19 +22,24 @@ import org.apache.arrow.vector.IntVector
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
import org.apache.arrow.vector.types.pojo.ArrowType

import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils

object CodeGeneration {
val timeZoneId = SQLConf.get.sessionLocalTimeZone
private val defaultTimeZoneId = SparkSchemaUtils.getLocalTimezoneID()

def getResultType(left: ArrowType, right: ArrowType): ArrowType = {
//TODO(): remove this API
left
}

def getResultType(dataType: DataType): ArrowType = {
getResultType(dataType, defaultTimeZoneId)
}

def getResultType(dataType: DataType, timeZoneId: String): ArrowType = {
dataType match {
case other =>
ArrowUtils.toArrowType(dataType, timeZoneId)
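CodeGeneration now resolves a default timezone once via SparkSchemaUtils and adds a single-argument getResultType overload that delegates to it. A sketch of the Arrow types this resolves to (an assumption based on Spark's internal ArrowUtils mapping; the sketch is placed inside Spark's sql.util package because that ArrowUtils is private[sql]): DateType maps to day-granularity dates, and the timezone id only matters for TimestampType.

package org.apache.spark.sql.util // ArrowUtils is private[sql], so this sketch lives in that namespace

import org.apache.arrow.vector.types.{DateUnit, TimeUnit}
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.spark.sql.types.{DateType, TimestampType}

object ArrowTypeMappingSketch {
  def main(args: Array[String]): Unit = {
    val dateType = ArrowUtils.toArrowType(DateType, null).asInstanceOf[ArrowType.Date]
    assert(dateType.getUnit == DateUnit.DAY)                // dates carry no timezone

    val tsType = ArrowUtils.toArrowType(TimestampType, "UTC").asInstanceOf[ArrowType.Timestamp]
    assert(tsType.getUnit == TimeUnit.MICROSECOND)
    assert(tsType.getTimezone == "UTC")                     // timestamps keep the zone id
  }
}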
(Diffs for the remaining changed files are not shown here.)
