From 2d6612cc8b98f767d73c4d15e4065bf3d6c12ea7 Mon Sep 17 00:00:00 2001
From: Nathan Howell
Date: Wed, 6 May 2015 22:56:53 -0700
Subject: [PATCH] [SPARK-5938] [SPARK-5443] [SQL] Improve JsonRDD performance

This patch comprises a few related pieces of work:

* Schema inference is performed directly on the JSON token stream
* `String => Row` conversion populates Spark SQL structures without intermediate types
* Projection pushdown is implemented via CatalystScan for DataFrame queries
* Support for the legacy parser is retained by setting `spark.sql.json.useJacksonStreamingAPI` to `false`

Performance improvements depend on the schema and queries being executed, but it should be faster across the board. Below are benchmarks using the last.fm Million Song dataset:

```
Command                                            | Baseline | Patched
---------------------------------------------------|----------|--------
import sqlContext.implicits._                      |          |
val df = sqlContext.jsonFile("/tmp/lastfm.json")   | 70.0s    | 14.6s
df.count()                                         | 28.8s    | 6.2s
df.rdd.count()                                     | 35.3s    | 21.5s
df.where($"artist" === "Robert Hood").collect()    | 28.3s    | 16.9s
```

To prepare this dataset for benchmarking, follow these steps:

```
# Fetch the datasets from http://labrosa.ee.columbia.edu/millionsong/lastfm
wget http://labrosa.ee.columbia.edu/millionsong/sites/default/files/lastfm/lastfm_test.zip \
     http://labrosa.ee.columbia.edu/millionsong/sites/default/files/lastfm/lastfm_train.zip

# Decompress and combine, pipe through `jq -c` to ensure there is one record per line
unzip -p lastfm_test.zip lastfm_train.zip | jq -c . > lastfm.json
```

Author: Nathan Howell

Closes #5801 from NathanHowell/json-performance and squashes the following commits:

26fea31 [Nathan Howell] Recreate the baseRDD for each scan operation
a7ebeb2 [Nathan Howell] Increase coverage of inserts into a JSONRelation
e06a1dd [Nathan Howell] Add comments to the `useJacksonStreamingAPI` config flag
6822712 [Nathan Howell] Split up JsonRDD2 into multiple objects
fa8234f [Nathan Howell] Wrap long lines
b31917b [Nathan Howell] Rename `useJsonRDD2` to `useJacksonStreamingAPI`
15c5d1b [Nathan Howell] JSONRelation's baseRDD need not be lazy
f8add6e [Nathan Howell] Add comments on lack of support for precision and scale DecimalTypes
fa0be47 [Nathan Howell] Remove unused default case in the field parser
80dba17 [Nathan Howell] Add comments regarding null handling and empty strings
842846d [Nathan Howell] Point the empty schema inference test at JsonRDD2
ab6ee87 [Nathan Howell] Add projection pushdown support to JsonRDD/JsonRDD2
f636c14 [Nathan Howell] Enable JsonRDD2 by default, add a flag to switch back to JsonRDD
0bbc445 [Nathan Howell] Improve JSON parsing and type inference performance
7ca70c1 [Nathan Howell] Eliminate arrow pattern, replace with pattern matches
---
 .../catalyst/analysis/HiveTypeCoercion.scala  |  43 ++--
 .../apache/spark/sql/types/StructType.scala   |   4 +
 .../org/apache/spark/sql/DataFrame.scala      |   4 +-
 .../scala/org/apache/spark/sql/SQLConf.scala  |   8 +
 .../org/apache/spark/sql/SQLContext.scala     |  34 +--
 .../apache/spark/sql/json/InferSchema.scala   | 171 ++++++++++++++
 .../apache/spark/sql/json/JSONRelation.scala  |  99 ++++++--
 .../spark/sql/json/JacksonGenerator.scala     |  77 +++++++
 .../apache/spark/sql/json/JacksonParser.scala | 215 ++++++++++++++++++
 .../apache/spark/sql/json/JacksonUtils.scala  |  32 +++
 .../org/apache/spark/sql/json/JsonRDD.scala   |  50 ----
 .../org/apache/spark/sql/json/JsonSuite.scala |  51 +++--
 .../spark/sql/sources/InsertSuite.scala       |  55 ++++-
 13 files changed, 715 insertions(+), 128 deletions(-)
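As a quick smoke test of the two code paths described above, one can run something along the following lines against the `lastfm.json` file produced by the preparation steps. This is a sketch only: the `local[*]` master, application name, and file location are illustrative assumptions, while `jsonFile`, `$"artist"`, and the `spark.sql.json.useJacksonStreamingAPI` key come from this patch and its benchmark:

```
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("json-smoke-test"))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Default after this patch: schema inference and String => Row conversion
// run directly on the Jackson token stream.
val df = sqlContext.jsonFile("/tmp/lastfm.json")
df.printSchema()
df.where($"artist" === "Robert Hood").collect()

// Projection pushdown goes through CatalystScan, so a narrow select only
// materializes the requested columns from the token stream.
df.select($"artist").explain()

// Fall back to the legacy JsonRDD parser. JSONRelation captures the flag
// when it is constructed, so set it before jsonFile is called.
sqlContext.setConf("spark.sql.json.useJacksonStreamingAPI", "false")
val legacyDf = sqlContext.jsonFile("/tmp/lastfm.json")
```

Because the flag is read at relation construction time, toggling it has no effect on DataFrames that already exist; per the comment added to `SQLConf`, the fallback flag is temporary and slated for removal in Spark 1.5.0.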
create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 96e2aee4de15b..873c75c525c3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -26,7 +26,14 @@ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: private val numericPrecedence = - Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited) + IndexedSeq( + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType.Unlimited) /** * Find the tightest common type of two types that might be used in a binary expression. @@ -34,25 +41,21 @@ object HiveTypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. */ - def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { - val valueTypes = Seq(t1, t2).filter(t => t != NullType) - if (valueTypes.distinct.size > 1) { - // Promote numeric types to the highest of the two and all numeric types to unlimited decimal - if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) { - Some(numericPrecedence.filter(t => t == t1 || t == t2).last) - } else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) { - // Fixed-precision decimals can up-cast into unlimited - if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) { - Some(DecimalType.Unlimited) - } else { - None - } - } else { - None - } - } else { - Some(if (valueTypes.size == 0) NullType else valueTypes.head) - } + val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + case (t1, t2) if t1 == t2 => Some(t1) + case (NullType, t1) => Some(t1) + case (t1, NullType) => Some(t1) + + // Promote numeric types to the highest of the two and all numeric types to unlimited decimal + case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => + val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) + Some(numericPrecedence(index)) + + // Fixed-precision decimals can up-cast into unlimited + case (DecimalType.Unlimited, _: DecimalType) => Some(DecimalType.Unlimited) + case (_: DecimalType, DecimalType.Unlimited) => Some(DecimalType.Unlimited) + + case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d80ffca18ec9a..7e00a27dfe724 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -134,6 +134,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) 
} + private[sql] def getFieldIndex(name: String): Option[Int] = { + nameToIndex.get(name) + } + protected[sql] def toAttributes: Seq[AttributeReference] = map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 9d2cd7aae3b82..79fbf50300d4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.jdbc.JDBCWriteDetails -import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.json.{JacksonGenerator, JsonRDD} import org.apache.spark.sql.types._ import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} import org.apache.spark.util.Utils @@ -1415,7 +1415,7 @@ class DataFrame private[sql]( new Iterator[String] { override def hasNext: Boolean = iter.hasNext override def next(): String = { - JsonRDD.rowToJSON(rowSchema, gen)(iter.next()) + JacksonGenerator(rowSchema, gen)(iter.next()) gen.flush() val json = writer.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 3ffc2091d6ba1..bfaddd0f2ce1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -73,6 +73,8 @@ private[spark] object SQLConf { val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2" + val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -166,6 +168,12 @@ private[sql] class SQLConf extends Serializable { private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean + /** + * Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0 + */ + private[spark] def useJacksonStreamingAPI: Boolean = + getConf(USE_JACKSON_STREAMING_API, "true").toBoolean + /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to * a broadcast value during the physical executions of join operations. 
Setting this to -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7eabb93c1e3d6..0563430a6fdc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -659,13 +659,17 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { - val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord - val appliedSchema = - Option(schema).getOrElse( - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema, needsConversion = false) + if (conf.useJacksonStreamingAPI) { + baseRelationToDataFrame(new JSONRelation(() => json, None, 1.0, Some(schema))(this)) + } else { + val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord + val appliedSchema = + Option(schema).getOrElse( + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) + } } /** @@ -689,12 +693,16 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { - val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord - val appliedSchema = - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema, needsConversion = false) + if (conf.useJacksonStreamingAPI) { + baseRelationToDataFrame(new JSONRelation(() => json, None, samplingRatio, None)(this)) + } else { + val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord + val appliedSchema = + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) + createDataFrame(rowRDD, appliedSchema, needsConversion = false) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala new file mode 100644 index 0000000000000..9c58b8e4bb16a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.json + +import com.fasterxml.jackson.core._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion +import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.types._ + +private[sql] object InferSchema { + /** + * Infer the type of a collection of json records in three stages: + * 1. Infer the type of each record + * 2. Merge types by choosing the lowest type necessary to cover equal keys + * 3. Replace any remaining null fields with string, the top type + */ + def apply( + json: RDD[String], + samplingRatio: Double = 1.0, + columnNameOfCorruptRecords: String): StructType = { + require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") + val schemaData = if (samplingRatio > 0.99) { + json + } else { + json.sample(withReplacement = false, samplingRatio, 1) + } + + // perform schema inference on each row and merge afterwards + schemaData.mapPartitions { iter => + val factory = new JsonFactory() + iter.map { row => + try { + val parser = factory.createParser(row) + parser.nextToken() + inferField(parser) + } catch { + case _: JsonParseException => + StructType(Seq(StructField(columnNameOfCorruptRecords, StringType))) + } + } + }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) match { + case st: StructType => nullTypeToStringType(st) + } + } + + /** + * Infer the type of a json document from the parser's token stream + */ + private def inferField(parser: JsonParser): DataType = { + import com.fasterxml.jackson.core.JsonToken._ + parser.getCurrentToken match { + case null | VALUE_NULL => NullType + + case FIELD_NAME => + parser.nextToken() + inferField(parser) + + case VALUE_STRING if parser.getTextLength < 1 => + // Zero length strings and nulls have special handling to deal + // with JSON generators that do not distinguish between the two. + // To accurately infer types for empty strings that are really + // meant to represent nulls we assume that the two are isomorphic + // but will defer treating null fields as strings until all the + // record fields' types have been combined. + NullType + + case VALUE_STRING => StringType + case START_OBJECT => + val builder = Seq.newBuilder[StructField] + while (nextUntil(parser, END_OBJECT)) { + builder += StructField(parser.getCurrentName, inferField(parser), nullable = true) + } + + StructType(builder.result().sortBy(_.name)) + + case START_ARRAY => + // If this JSON array is empty, we use NullType as a placeholder. + // If this array is not empty in other JSON objects, we can resolve + // the type as we pass through all JSON objects. + var elementType: DataType = NullType + while (nextUntil(parser, END_ARRAY)) { + elementType = compatibleType(elementType, inferField(parser)) + } + + ArrayType(elementType) + + case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => + import JsonParser.NumberType._ + parser.getNumberType match { + // For Integer values, use LongType by default. + case INT | LONG => LongType + // Since we do not have a data type backed by BigInteger, + // when we see a Java BigInteger, we use DecimalType. 
+ case BIG_INTEGER | BIG_DECIMAL => DecimalType.Unlimited + case FLOAT | DOUBLE => DoubleType + } + + case VALUE_TRUE | VALUE_FALSE => BooleanType + } + } + + private def nullTypeToStringType(struct: StructType): StructType = { + val fields = struct.fields.map { + case StructField(fieldName, dataType, nullable, _) => + val newType = dataType match { + case NullType => StringType + case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) + case ArrayType(struct: StructType, containsNull) => + ArrayType(nullTypeToStringType(struct), containsNull) + case struct: StructType =>nullTypeToStringType(struct) + case other: DataType => other + } + + StructField(fieldName, newType, nullable) + } + + StructType(fields) + } + + /** + * Remove top-level ArrayType wrappers and merge the remaining schemas + */ + private def compatibleRootType: (DataType, DataType) => DataType = { + case (ArrayType(ty1, _), ty2) => compatibleRootType(ty1, ty2) + case (ty1, ArrayType(ty2, _)) => compatibleRootType(ty1, ty2) + case (ty1, ty2) => compatibleType(ty1, ty2) + } + + /** + * Returns the most general data type for two given data types. + */ + private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { + HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse { + // t1 or t2 is a StructType, ArrayType, or an unexpected type. + (t1, t2) match { + case (other: DataType, NullType) => other + case (NullType, other: DataType) => other + case (StructType(fields1), StructType(fields2)) => + val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { + case (name, fieldTypes) => + val dataType = fieldTypes.view.map(_.dataType).reduce(compatibleType) + StructField(name, dataType, nullable = true) + } + StructType(newFields.toSeq.sortBy(_.name)) + + case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + + // strings and every string is a Json object. 
+ case (_, _) => StringType + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index e3352d02787fd..c772cd1f53e53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -22,14 +22,16 @@ import java.io.IOException import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute, Row} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} private[sql] class DefaultSource - extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider { + extends RelationProvider + with SchemaRelationProvider + with CreatableRelationProvider { private def checkPath(parameters: Map[String, String]): String = { parameters.getOrElse("path", sys.error("'path' must be specified for json data.")) @@ -42,7 +44,7 @@ private[sql] class DefaultSource val path = checkPath(parameters) val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - JSONRelation(path, samplingRatio, None)(sqlContext) + new JSONRelation(path, samplingRatio, None, sqlContext) } /** Returns a new base relation with the given schema and parameters. */ @@ -53,7 +55,7 @@ private[sql] class DefaultSource val path = checkPath(parameters) val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - JSONRelation(path, samplingRatio, Some(schema))(sqlContext) + new JSONRelation(path, samplingRatio, Some(schema), sqlContext) } override def createRelation( @@ -101,32 +103,87 @@ private[sql] class DefaultSource } } -private[sql] case class JSONRelation( - path: String, - samplingRatio: Double, +private[sql] class JSONRelation( + // baseRDD is not immutable with respect to INSERT OVERWRITE + // and so it must be recreated at least as often as the + // underlying inputs are modified. To be safe, a function is + // used instead of a regular RDD value to ensure a fresh RDD is + // recreated for each and every operation. + baseRDD: () => RDD[String], + val path: Option[String], + val samplingRatio: Double, userSpecifiedSchema: Option[StructType])( @transient val sqlContext: SQLContext) extends BaseRelation with TableScan - with InsertableRelation { - - // TODO: Support partitioned JSON relation. 
- private def baseRDD = sqlContext.sparkContext.textFile(path) + with InsertableRelation + with CatalystScan { + + def this( + path: String, + samplingRatio: Double, + userSpecifiedSchema: Option[StructType], + sqlContext: SQLContext) = + this( + () => sqlContext.sparkContext.textFile(path), + Some(path), + samplingRatio, + userSpecifiedSchema)(sqlContext) + + private val useJacksonStreamingAPI: Boolean = sqlContext.conf.useJacksonStreamingAPI override val needConversion: Boolean = false - override val schema = userSpecifiedSchema.getOrElse( - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema( - baseRDD, + override lazy val schema = userSpecifiedSchema.getOrElse { + if (useJacksonStreamingAPI) { + InferSchema( + baseRDD(), samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord))) + sqlContext.conf.columnNameOfCorruptRecord) + } else { + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema( + baseRDD(), + samplingRatio, + sqlContext.conf.columnNameOfCorruptRecord)) + } + } - override def buildScan(): RDD[Row] = - JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord) + override def buildScan(): RDD[Row] = { + if (useJacksonStreamingAPI) { + JacksonParser( + baseRDD(), + schema, + sqlContext.conf.columnNameOfCorruptRecord) + } else { + JsonRDD.jsonStringToRow( + baseRDD(), + schema, + sqlContext.conf.columnNameOfCorruptRecord) + } + } + + override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = { + if (useJacksonStreamingAPI) { + JacksonParser( + baseRDD(), + StructType.fromAttributes(requiredColumns), + sqlContext.conf.columnNameOfCorruptRecord) + } else { + JsonRDD.jsonStringToRow( + baseRDD(), + StructType.fromAttributes(requiredColumns), + sqlContext.conf.columnNameOfCorruptRecord) + } + } override def insert(data: DataFrame, overwrite: Boolean): Unit = { - val filesystemPath = new Path(path) + val filesystemPath = path match { + case Some(p) => new Path(p) + case None => + throw new IOException(s"Cannot INSERT into table with no path defined") + } + val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) if (overwrite) { @@ -147,7 +204,7 @@ private[sql] case class JSONRelation( } } // Write the data. - data.toJSON.saveAsTextFile(path) + data.toJSON.saveAsTextFile(filesystemPath.toString) // Right now, we assume that the schema is not changed. We will not update the schema. // schema = data.schema } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala new file mode 100644 index 0000000000000..80bf74aa02602 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.json + +import scala.collection.Map + +import com.fasterxml.jackson.core._ + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +private[sql] object JacksonGenerator { + /** Transforms a single Row to JSON using Jackson + * + * @param rowSchema the schema object used for conversion + * @param gen a JsonGenerator object + * @param row The row to convert + */ + def apply(rowSchema: StructType, gen: JsonGenerator)(row: Row): Unit = { + def valWriter: (DataType, Any) => Unit = { + case (_, null) | (NullType, _) => gen.writeNull() + case (StringType, v: String) => gen.writeString(v) + case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) + case (IntegerType, v: Int) => gen.writeNumber(v) + case (ShortType, v: Short) => gen.writeNumber(v) + case (FloatType, v: Float) => gen.writeNumber(v) + case (DoubleType, v: Double) => gen.writeNumber(v) + case (LongType, v: Long) => gen.writeNumber(v) + case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v) + case (ByteType, v: Byte) => gen.writeNumber(v.toInt) + case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) + case (BooleanType, v: Boolean) => gen.writeBoolean(v) + case (DateType, v) => gen.writeString(v.toString) + case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v)) + + case (ArrayType(ty, _), v: Seq[_] ) => + gen.writeStartArray() + v.foreach(valWriter(ty,_)) + gen.writeEndArray() + + case (MapType(kv,vv, _), v: Map[_,_]) => + gen.writeStartObject() + v.foreach { p => + gen.writeFieldName(p._1.toString) + valWriter(vv,p._2) + } + gen.writeEndObject() + + case (StructType(ty), v: Row) => + gen.writeStartObject() + ty.zip(v.toSeq).foreach { + case (_, null) => + case (field, v) => + gen.writeFieldName(field.name) + valWriter(field.dataType, v) + } + gen.writeEndObject() + } + + valWriter(rowSchema, row) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala new file mode 100644 index 0000000000000..a8e69ae61174f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.json + +import java.io.ByteArrayOutputStream +import java.sql.Timestamp + +import scala.collection.Map + +import com.fasterxml.jackson.core._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.types._ + +private[sql] object JacksonParser { + def apply( + json: RDD[String], + schema: StructType, + columnNameOfCorruptRecords: String): RDD[Row] = { + parseJson(json, schema, columnNameOfCorruptRecords) + } + + /** + * Parse the current token (and related children) according to a desired schema + */ + private[sql] def convertField( + factory: JsonFactory, + parser: JsonParser, + schema: DataType): Any = { + import com.fasterxml.jackson.core.JsonToken._ + (parser.getCurrentToken, schema) match { + case (null | VALUE_NULL, _) => + null + + case (FIELD_NAME, _) => + parser.nextToken() + convertField(factory, parser, schema) + + case (VALUE_STRING, StringType) => + UTF8String(parser.getText) + + case (VALUE_STRING, _) if parser.getTextLength < 1 => + // guard the non string type + null + + case (VALUE_STRING, DateType) => + DateUtils.millisToDays(DateUtils.stringToTime(parser.getText).getTime) + + case (VALUE_STRING, TimestampType) => + new Timestamp(DateUtils.stringToTime(parser.getText).getTime) + + case (VALUE_NUMBER_INT, TimestampType) => + new Timestamp(parser.getLongValue) + + case (_, StringType) => + val writer = new ByteArrayOutputStream() + val generator = factory.createGenerator(writer, JsonEncoding.UTF8) + generator.copyCurrentStructure(parser) + generator.close() + UTF8String(writer.toByteArray) + + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) => + parser.getFloatValue + + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) => + parser.getDoubleValue + + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DecimalType()) => + // TODO: add fixed precision and scale handling + Decimal(parser.getDecimalValue) + + case (VALUE_NUMBER_INT, ByteType) => + parser.getByteValue + + case (VALUE_NUMBER_INT, ShortType) => + parser.getShortValue + + case (VALUE_NUMBER_INT, IntegerType) => + parser.getIntValue + + case (VALUE_NUMBER_INT, LongType) => + parser.getLongValue + + case (VALUE_TRUE, BooleanType) => + true + + case (VALUE_FALSE, BooleanType) => + false + + case (START_OBJECT, st: StructType) => + convertObject(factory, parser, st) + + case (START_ARRAY, ArrayType(st, _)) => + convertList(factory, parser, st) + + case (START_OBJECT, ArrayType(st, _)) => + // the business end of SPARK-3308: + // when an object is found but an array is requested just wrap it in a list + convertField(factory, parser, st) :: Nil + + case (START_OBJECT, MapType(StringType, kt, _)) => + convertMap(factory, parser, kt) + + case (_, udt: UserDefinedType[_]) => + udt.deserialize(convertField(factory, parser, udt.sqlType)) + } + } + + /** + * Parse an object from the token stream into a new Row representing the schema. + * + * Fields in the json that are not defined in the requested schema will be dropped. 
+ */ + private def convertObject(factory: JsonFactory, parser: JsonParser, schema: StructType): Row = { + val row = new GenericMutableRow(schema.length) + while (nextUntil(parser, JsonToken.END_OBJECT)) { + schema.getFieldIndex(parser.getCurrentName) match { + case Some(index) => + row.update(index, convertField(factory, parser, schema(index).dataType)) + + case None => + parser.skipChildren() + } + } + + row + } + + /** + * Parse an object as a Map, preserving all fields + */ + private def convertMap( + factory: JsonFactory, + parser: JsonParser, + valueType: DataType): Map[String, Any] = { + val builder = Map.newBuilder[String, Any] + while (nextUntil(parser, JsonToken.END_OBJECT)) { + builder += parser.getCurrentName -> convertField(factory, parser, valueType) + } + + builder.result() + } + + private def convertList( + factory: JsonFactory, + parser: JsonParser, + schema: DataType): Seq[Any] = { + val builder = Seq.newBuilder[Any] + while (nextUntil(parser, JsonToken.END_ARRAY)) { + builder += convertField(factory, parser, schema) + } + + builder.result() + } + + private def parseJson( + json: RDD[String], + schema: StructType, + columnNameOfCorruptRecords: String): RDD[Row] = { + + def failedRecord(record: String): Seq[Row] = { + // create a row even if no corrupt record column is present + val row = new GenericMutableRow(schema.length) + for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) { + require(schema(corruptIndex).dataType == StringType) + row.update(corruptIndex, record) + } + + Seq(row) + } + + json.mapPartitions { iter => + val factory = new JsonFactory() + + iter.flatMap { record => + try { + val parser = factory.createParser(record) + parser.nextToken() + + // to support both object and arrays (see SPARK-3308) we'll start + // by converting the StructType schema to an ArrayType and let + // convertField wrap an object into a single value array when necessary. + convertField(factory, parser, ArrayType(schema)) match { + case null => failedRecord(record) + case list: Seq[Row @unchecked] => list + case _ => + sys.error( + s"Failed to parse record $record. Please make sure that each line of the file " + + "(or each string in the RDD) is a valid JSON object or an array of JSON objects.") + } + } catch { + case _: JsonProcessingException => + failedRecord(record) + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala new file mode 100644 index 0000000000000..fde96852ce68e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.json + +import com.fasterxml.jackson.core.{JsonParser, JsonToken} + +private object JacksonUtils { + /** + * Advance the parser until a null or a specific token is found + */ + def nextUntil(parser: JsonParser, stopOn: JsonToken): Boolean = { + parser.nextToken() match { + case null => false + case x => x != stopOn + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 6e94e7056eb0b..f62973d5fcfab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -440,54 +440,4 @@ private[sql] object JsonRDD extends Logging { row } - - /** Transforms a single Row to JSON using Jackson - * - * @param rowSchema the schema object used for conversion - * @param gen a JsonGenerator object - * @param row The row to convert - */ - private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = { - def valWriter: (DataType, Any) => Unit = { - case (_, null) | (NullType, _) => gen.writeNull() - case (StringType, v: String) => gen.writeString(v) - case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) - case (IntegerType, v: Int) => gen.writeNumber(v) - case (ShortType, v: Short) => gen.writeNumber(v) - case (FloatType, v: Float) => gen.writeNumber(v) - case (DoubleType, v: Double) => gen.writeNumber(v) - case (LongType, v: Long) => gen.writeNumber(v) - case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v) - case (ByteType, v: Byte) => gen.writeNumber(v.toInt) - case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) - case (BooleanType, v: Boolean) => gen.writeBoolean(v) - case (DateType, v) => gen.writeString(v.toString) - case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v) - - case (ArrayType(ty, _), v: Seq[_] ) => - gen.writeStartArray() - v.foreach(valWriter(ty,_)) - gen.writeEndArray() - - case (MapType(kv,vv, _), v: Map[_,_]) => - gen.writeStartObject() - v.foreach { p => - gen.writeFieldName(p._1.toString) - valWriter(vv,p._2) - } - gen.writeEndObject() - - case (StructType(ty), v: Row) => - gen.writeStartObject() - ty.zip(v.toSeq).foreach { - case (_, null) => - case (field, v) => - gen.writeFieldName(field.name) - valWriter(field.dataType, v) - } - gen.writeEndObject() - } - - valWriter(rowSchema, row) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index fd0e2746dc045..263fafba930ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.json +import java.io.StringWriter import java.sql.{Date, Timestamp} +import com.fasterxml.jackson.core.JsonFactory import org.scalactic.Tolerance._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} +import org.apache.spark.sql.json.InferSchema.compatibleType import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -46,6 +48,18 @@ class JsonSuite extends QueryTest { s"${expected}(${expected.getClass}).") } + val factory = new JsonFactory() + def enforceCorrectType(value: Any, dataType: DataType): Any = { + val writer = new 
StringWriter() + val generator = factory.createGenerator(writer) + generator.writeObject(value) + generator.flush() + + val parser = factory.createParser(writer.toString) + parser.nextToken() + JacksonParser.convertField(factory, parser, dataType) + } + val intNumber: Int = 2147483647 checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType)) checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) @@ -439,7 +453,7 @@ class JsonSuite extends QueryTest { val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") - // Right now, the analyzer does not promote strings in a boolean expreesion. + // Right now, the analyzer does not promote strings in a boolean expression. // Number and Boolean conflict: resolve the type as boolean in this query. checkAnswer( sql("select num_bool from jsonTable where NOT num_bool"), @@ -508,7 +522,7 @@ class JsonSuite extends QueryTest { Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: Row(null, """{"field":false}""", null, null, "{}") :: Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: - Row(Seq(7), "{}","[str1,str2,33]", Row("str"), """{"field":true}""") :: Nil + Row(Seq(7), "{}","""["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil ) } @@ -566,19 +580,19 @@ class JsonSuite extends QueryTest { val analyzed = jsonDF.queryExecution.analyzed assert( analyzed.isInstanceOf[LogicalRelation], - "The DataFrame returned by jsonFile should be based on JSONRelation.") + "The DataFrame returned by jsonFile should be based on LogicalRelation.") val relation = analyzed.asInstanceOf[LogicalRelation].relation assert( relation.isInstanceOf[JSONRelation], "The DataFrame returned by jsonFile should be based on JSONRelation.") - assert(relation.asInstanceOf[JSONRelation].path === path) + assert(relation.asInstanceOf[JSONRelation].path === Some(path)) assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001)) val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = jsonFile(path, schema).queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] - assert(relationWithSchema.path === path) + assert(relationWithSchema.path === Some(path)) assert(relationWithSchema.schema === schema) assert(relationWithSchema.samplingRatio > 0.99) } @@ -1020,15 +1034,24 @@ class JsonSuite extends QueryTest { } test("JSONRelation equality test") { - val relation1 = - JSONRelation("path", 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(null) + val context = org.apache.spark.sql.test.TestSQLContext + val relation1 = new JSONRelation( + "path", + 1.0, + Some(StructType(StructField("a", IntegerType, true) :: Nil)), + context) val logicalRelation1 = LogicalRelation(relation1) - val relation2 = - JSONRelation("path", 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)))( - org.apache.spark.sql.test.TestSQLContext) + val relation2 = new JSONRelation( + "path", + 0.5, + Some(StructType(StructField("a", IntegerType, true) :: Nil)), + context) val logicalRelation2 = LogicalRelation(relation2) - val relation3 = - JSONRelation("path", 1.0, Some(StructType(StructField("b", StringType, true) :: Nil)))(null) + val relation3 = new JSONRelation( + "path", + 1.0, + Some(StructType(StructField("b", StringType, true) :: Nil)), + context) val logicalRelation3 = LogicalRelation(relation3) assert(relation1 === relation2) @@ -1046,7 +1069,7 @@ class JsonSuite extends QueryTest { 
test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception - val emptySchema = JsonRDD.inferSchema(empty, 1.0, "") + val emptySchema = InferSchema(empty, 1.0, "") assert(StructType(Seq()) === emptySchema) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 80efe9728fbc2..50629ea4dc066 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -21,7 +21,7 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.{SaveMode, AnalysisException, Row} import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with BeforeAndAfterAll { @@ -100,23 +100,48 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { test("INSERT OVERWRITE a JSONRelation multiple times") { sql( s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt + """.stripMargin) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + // Writing the table to less part files. + val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5) + jsonRDD(rdd1).registerTempTable("jt1") sql( s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 + """.stripMargin) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + // Writing the table to more part files. + val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10) + jsonRDD(rdd2).registerTempTable("jt2") sql( s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt - """.stripMargin) - + |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 + """.stripMargin) checkAnswer( sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) ) + + sql( + s""" + |INSERT OVERWRITE TABLE jsonTable SELECT a * 10, b FROM jt1 + """.stripMargin) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i * 10, s"str$i")) + ) + + dropTempTable("jt1") + dropTempTable("jt2") } test("INSERT INTO not supported for JSONRelation for now") { @@ -128,6 +153,20 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } } + test("save directly to the path of a JSON table") { + table("jt").selectExpr("a * 5 as a", "b").save(path.toString, "json", SaveMode.Overwrite) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i * 5, s"str$i")) + ) + + table("jt").save(path.toString, "json", SaveMode.Overwrite) + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, s"str$i")) + ) + } + test("it is not allowed to write to a table while querying it.") { val message = intercept[AnalysisException] { sql(