Commit: Add in checks for Parquet LEGACY date/time rebase (NVIDIA#435)

revans2 authored Jul 27, 2020
1 parent b1cb808 commit be9f426
Showing 8 changed files with 278 additions and 27 deletions.
12 changes: 6 additions & 6 deletions integration_tests/src/main/python/parquet_test.py
@@ -30,7 +30,7 @@ def read_parquet_sql(data_path):

parquet_gens_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen,
- TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))],
+ TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))],
pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/132'))]

@pytest.mark.parametrize('parquet_gens', parquet_gens_list, ids=idfn)
@@ -80,7 +80,7 @@ def test_compress_read_round_trip(spark_tmp_path, compress):
string_gen, date_gen,
# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
# timestamp_gen
- TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))]
+ TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))]

@pytest.mark.parametrize('parquet_gen', parquet_pred_push_gens, ids=idfn)
@pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
@@ -102,7 +102,7 @@ def test_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func):
def test_ts_read_round_trip(spark_tmp_path, ts_write, ts_rebase):
# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
# timestamp_gen
- gen = TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))
+ gen = TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
lambda spark : unary_op_df(spark, gen).write.parquet(data_path),
@@ -113,7 +113,7 @@ def test_ts_read_round_trip(spark_tmp_path, ts_write, ts_rebase):

parquet_gens_legacy_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
- TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))],
+ TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))],
pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133')),
pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133'))]

@@ -132,7 +132,7 @@ def test_simple_partitioned_read(spark_tmp_path):
# we should go with a more standard set of generators
parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
- TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))]
+ TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))]
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0'
with_cpu_session(
@@ -153,7 +153,7 @@ def test_read_merge_schema(spark_tmp_path):
# we should go with a more standard set of generators
parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
- TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))]
+ TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))]
first_gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0'
with_cpu_session(
82 changes: 82 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
@@ -0,0 +1,82 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark

import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import com.nvidia.spark.rapids.Arm

import org.apache.spark.sql.catalyst.util.RebaseDateTime
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.execution.TrampolineUtil

object RebaseHelper extends Arm {
private[this] def isDateTimeRebaseNeeded(column: ColumnVector,
startDay: Int,
startTs: Long): Boolean = {
val dtype = column.getType
if (dtype == DType.TIMESTAMP_DAYS) {
withResource(Scalar.timestampDaysFromInt(startDay)) { minGood =>
withResource(column.lessThan(minGood)) { hasBad =>
withResource(hasBad.any()) { a =>
a.getBoolean
}
}
}
} else if (dtype.isTimestamp) {
assert(dtype == DType.TIMESTAMP_MICROSECONDS)
withResource(
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
withResource(column.lessThan(minGood)) { hasBad =>
withResource(hasBad.any()) { a =>
a.getBoolean
}
}
}
} else {
false
}
}

def isDateTimeRebaseNeededWrite(column: ColumnVector): Boolean =
isDateTimeRebaseNeeded(column,
RebaseDateTime.lastSwitchGregorianDay,
RebaseDateTime.lastSwitchGregorianTs)

def isDateTimeRebaseNeededRead(column: ColumnVector): Boolean =
isDateTimeRebaseNeeded(column,
RebaseDateTime.lastSwitchJulianDay,
RebaseDateTime.lastSwitchJulianTs)

def newRebaseExceptionInRead(format: String): Exception = {
val config = if (format == "Parquet") {
SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key
} else if (format == "Avro") {
SQLConf.LEGACY_AVRO_REBASE_MODE_IN_READ.key
} else {
throw new IllegalStateException("unrecognized format " + format)
}
TrampolineUtil.makeSparkUpgradeException("3.0",
"reading dates before 1582-10-15 or timestamps before " +
s"1900-01-01T00:00:00Z from $format files can be ambiguous, as the files may be written by " +
"Spark 2.x or legacy versions of Hive, which uses a legacy hybrid calendar that is " +
"different from Spark 3.0+'s Proleptic Gregorian calendar. See more details in " +
s"SPARK-31404. The RAPIDS Accelerator does not support reading these 'LEGACY' files. To do " +
s"so you should disable $format support in the RAPIDS Accelerator " +
s"or set $config to 'CORRECTED' to read the datetime values as it is.",
null)
}
}
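
Note: the read-side wiring that consumes these helpers is presumably in one of the remaining changed files, which is not shown in this excerpt. As a rough, illustrative sketch only (not part of the diff), a caller could scan each column of a decoded cudf Table and fail fast, mirroring the write-side loop that appears later in this commit:

import ai.rapids.cudf.Table
import com.nvidia.spark.RebaseHelper

object RebaseCheckSketch {
  // Illustrative only: throw Spark's upgrade exception if any column holds
  // values old enough to need LEGACY (hybrid calendar) rebasing on read.
  def failIfRebaseNeededOnRead(table: Table): Unit = {
    (0 until table.getNumberOfColumns).foreach { i =>
      if (RebaseHelper.isDateTimeRebaseNeededRead(table.getColumn(i))) {
        throw RebaseHelper.newRebaseExceptionInRead("Parquet")
      }
    }
  }
}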
sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids

import scala.collection.mutable

- import ai.rapids.cudf.{HostBufferConsumer, HostMemoryBuffer, NvtxColor, NvtxRange, TableWriter}
+ import ai.rapids.cudf.{HostBufferConsumer, HostMemoryBuffer, NvtxColor, NvtxRange, Table, TableWriter}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
import org.apache.hadoop.mapreduce.TaskAttemptContext
@@ -60,7 +60,7 @@ abstract class ColumnarOutputWriterFactory extends Serializable {
* `org.apache.spark.sql.execution.datasources.OutputWriter`.
*/
abstract class ColumnarOutputWriter(path: String, context: TaskAttemptContext,
-     dataSchema: StructType, rangeName: String) extends HostBufferConsumer {
+     dataSchema: StructType, rangeName: String) extends HostBufferConsumer with Arm {

val tableWriter: TableWriter
val conf = context.getConfiguration
@@ -130,6 +130,10 @@ abstract class ColumnarOutputWriter(path: String, context: TaskAttemptContext,
}
}

protected def scanTableBeforeWrite(table: Table): Unit = {
// NOOP for now, but allows a child to override this
}

/**
* Writes the columnar batch and returns the time in ns taken to write
*
@@ -140,17 +144,12 @@ abstract class ColumnarOutputWriter(path: String, context: TaskAttemptContext,
var needToCloseBatch = true
try {
val startTimestamp = System.nanoTime
-      val nvtxRange = new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)
-      try {
-        val table = GpuColumnVector.from(batch)
-        try {
+      withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ =>
+        withResource(GpuColumnVector.from(batch)) { table =>
+          scanTableBeforeWrite(table)
           anythingWritten = true
           tableWriter.write(table)
-        } finally {
-          table.close()
         }
-      } finally {
-        nvtxRange.close()
       }

// Batch is no longer needed, write process from here does not use GPU.
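The hunk above replaces the nested try/finally blocks with the `withResource` helper from the newly mixed-in `Arm` trait. As a rough sketch of the assumed shape (the real trait in the plugin may provide additional overloads), `withResource` is simply try/finally over an `AutoCloseable`:

trait Arm {
  // Run the body with the resource and always close it, even if the body throws.
  def withResource[R <: AutoCloseable, T](resource: R)(body: R => T): T = {
    try {
      body(resource)
    } finally {
      resource.close()
    }
  }
}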
sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
@@ -17,6 +17,7 @@
package com.nvidia.spark.rapids

import ai.rapids.cudf._
import com.nvidia.spark.RebaseHelper
import org.apache.hadoop.mapreduce.{Job, OutputCommitter, TaskAttemptContext}
import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat}
import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel
@@ -25,11 +26,12 @@ import org.apache.parquet.hadoop.util.ContextUtil

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetWriteSupport}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.rapids.execution.TrampolineUtil
- import org.apache.spark.sql.types.{StructType, TimestampType}
+ import org.apache.spark.sql.types.{DateType, StructType, TimestampType}

object GpuParquetFileFormat {
def tagGpuSupport(
@@ -69,6 +71,21 @@ object GpuParquetFileFormat {
}
}

val schemaHasDates = schema.exists { field =>
TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[DateType])
}

sqlConf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE) match {
case "EXCEPTION" => //Good
case "CORRECTED" => //Good
case "LEGACY" =>
if (schemaHasDates || schemaHasTimestamps) {
meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported")
}
case other =>
meta.willNotWorkOnGpu(s"$other is not a supported rebase mode")
}

if (meta.canThisBeReplaced) {
Some(new GpuParquetFileFormat)
} else {
@@ -101,6 +118,9 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging {

val conf = ContextUtil.getConfiguration(job)

val dateTimeRebaseException =
"EXCEPTION".equals(conf.get(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key))

val committerClass =
conf.getClass(
SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key,
@@ -179,7 +199,7 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): ColumnarOutputWriter = {
-     new GpuParquetWriter(path, dataSchema, compressionType, context)
+     new GpuParquetWriter(path, dataSchema, compressionType, dateTimeRebaseException, context)
}

override def getFileExtension(context: TaskAttemptContext): String = {
@@ -193,9 +213,20 @@ class GpuParquetWriter(
path: String,
dataSchema: StructType,
compressionType: CompressionType,
dateTimeRebaseException: Boolean,
context: TaskAttemptContext)
extends ColumnarOutputWriter(path, context, dataSchema, "Parquet") {

override def scanTableBeforeWrite(table: Table): Unit = {
if (dateTimeRebaseException) {
(0 until table.getNumberOfColumns).foreach { i =>
if (RebaseHelper.isDateTimeRebaseNeededWrite(table.getColumn(i))) {
throw DataSourceUtils.newRebaseExceptionInWrite("Parquet")
}
}
}
}

override val tableWriter: TableWriter = {
val writeContext = new ParquetWriteSupport().init(conf)
val builder = ParquetWriterOptions.builder()
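Taken together with the tagging changes above, a GPU Parquet write now follows the session's rebase mode: LEGACY makes plans with date or timestamp columns fall back to the CPU, CORRECTED writes the values as-is, and EXCEPTION raises an upgrade exception when a value predates the calendar switch. A user-level illustration, assuming the Spark 3.0 key behind SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE and a hypothetical output path:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("rebase-demo").getOrCreate()

// Under EXCEPTION mode this pre-1582 date would trigger the new
// scanTableBeforeWrite check; CORRECTED writes it using the Proleptic
// Gregorian calendar without rebasing.
spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "CORRECTED")

spark.sql("SELECT DATE '1500-01-01' AS d")
  .write.mode("overwrite")
  .parquet("/tmp/rebase_demo")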