diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index a6cfd978b5f62..d618ec5298526 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities
 import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable}
 import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc}
 import org.apache.hadoop.hive.serde2.Deserializer
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.IdentityConverter
 import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector}
 import org.apache.hadoop.hive.serde2.objectinspector.primitive._
 import org.apache.hadoop.io.Writable
@@ -115,7 +116,7 @@ class HadoopTableReader(
       val hconf = broadcastedHiveConf.value.value
       val deserializer = deserializerClass.newInstance()
       deserializer.initialize(hconf, tableDesc.getProperties)
-      HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow)
+      HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer)
     }
 
     deserializedHadoopRDD
@@ -194,7 +195,7 @@ class HadoopTableReader(
 
         // fill the non partition key attributes
         HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs,
-          mutableRow, Some(tableSerDe))
+          mutableRow, tableSerDe)
       }
     }.toSeq
 
@@ -264,37 +265,27 @@ private[hive] object HadoopTableReader extends HiveInspectors {
    * Transform all given raw `Writable`s into `Row`s.
    *
    * @param iterator Iterator of all `Writable`s to be transformed
-   * @param deserializer The `Deserializer` associated with the input `Writable`
+   * @param rawDeser The `Deserializer` associated with the input `Writable`
    * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding
    *                             positions in the output schema
    * @param mutableRow A reusable `MutableRow` that should be filled
-   * @param convertdeserializer The `Deserializer` covert the `deserializer`
+   * @param tableDeser Table-level `Deserializer`; its `ObjectInspector` defines the output schema
    * @return An `Iterator[Row]` transformed from `iterator`
    */
   def fillObject(
       iterator: Iterator[Writable],
-      deserializer: Deserializer,
+      rawDeser: Deserializer,
       nonPartitionKeyAttrs: Seq[(Attribute, Int)],
       mutableRow: MutableRow,
-      convertdeserializer: Option[Deserializer] = None): Iterator[Row] = {
+      tableDeser: Deserializer): Iterator[Row] = {
 
-    val soi = convertdeserializer match {
-      case Some(convert) =>
-        // check need to convert
-        if (deserializer.getObjectInspector.equals(convert.getObjectInspector)) {
-          deserializer.getObjectInspector().asInstanceOf[StructObjectInspector]
-        }
-        else {
-          HiveShim.getConvertedOI(
-            deserializer.getObjectInspector(),
-            convert.getObjectInspector()).asInstanceOf[StructObjectInspector]
-        }
-      case None =>
-        deserializer.getObjectInspector().asInstanceOf[StructObjectInspector]
-    }
+    val soi = HiveShim.getConvertedOI(
+      rawDeser.getObjectInspector, tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector]
+
+    val inputFields = soi.getAllStructFieldRefs
 
     val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) =>
-      soi.getStructFieldRef(attr.name) -> ordinal
+      (inputFields.get(ordinal), ordinal)
     }.unzip
 
     // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern
@@ -335,17 +326,15 @@ private[hive] object HadoopTableReader extends HiveInspectors {
       }
     }
 
-    /**
-     * when the soi and deserializer.getObjectInspector is equal,
-     * we will get `IdentityConverter`,which mean it won't convert the
-     * value when schema match
-     */
-    val partTblObjectInspectorConverter = ObjectInspectorConverters.getConverter(
-      deserializer.getObjectInspector, soi)
+    val converter = if (rawDeser == tableDeser) {
+      new IdentityConverter
+    } else {
+      ObjectInspectorConverters.getConverter(rawDeser.getObjectInspector, soi)
+    }
 
     // Map each tuple to a row object
     iterator.map { value =>
-      val raw = partTblObjectInspectorConverter.convert(deserializer.deserialize(value))
+      val raw = converter.convert(rawDeser.deserialize(value))
       var i = 0
       while (i < fieldRefs.length) {
         val fieldValue = soi.getStructFieldData(raw, fieldRefs(i))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 33e859427a4b0..5ddde890da32a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -187,7 +187,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
     sql(s"CREATE TABLE table_with_partition(key int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ")
     sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData")
 
-    //test schema is the same
+    // test that the existing partition is still readable after the table schema changes
     sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT")
     checkAnswer(sql("select key,value from table_with_partition where ds='1' "),
       testData.toSchemaRDD.collect.toSeq
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index 4fe8e8621f7b8..d7a4b509b1468 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -242,12 +242,9 @@ private[hive] object HiveShim {
     }
   }
 
-  // make getConvertedOI compatible between 0.12.0 and 0.13.1
   def getConvertedOI(inputOI: ObjectInspector,
-                     outputOI: ObjectInspector,
-                     equalsCheck: java.lang.Boolean =
-                       new java.lang.Boolean(true)): ObjectInspector = {
-    ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, equalsCheck)
+                     outputOI: ObjectInspector): ObjectInspector = {
+    ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, true)
   }
 
   def prepareWritable(w: Writable): Writable = {
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index 55041417ecc17..3f2b6859e61ce 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -399,7 +399,6 @@ private[hive] object HiveShim {
     }
   }
 
-  // make getConvertedOI compatible between 0.12.0 and 0.13.1
   def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = {
     ObjectInspectorConverters.getConvertedOI(inputOI, outputOI)
   }
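
For reference, the conversion path the patch settles on can be sketched in isolation. This is a minimal illustration of the Hive ObjectInspectorConverters API as used in fillObject above, not code from the patch: it assumes the two-argument getConvertedOI from Hive 0.13.1 (the 0.12.0 shim supplies an extra equalsCheck flag), and the helper name converterFor plus its parameters are illustrative only, standing for two already-initialized Deserializers whose schemas may differ (e.g. int in the partition versus bigint in the table).

// Sketch only: mirrors the getConvertedOI + IdentityConverter logic added to fillObject.
import org.apache.hadoop.hive.serde2.Deserializer
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.IdentityConverter

def converterFor(rawDeser: Deserializer, tableDeser: Deserializer)
    : (StructObjectInspector, ObjectInspectorConverters.Converter) = {
  // The output inspector is derived from the *table* deserializer, so rows written
  // with an older partition schema are exposed with the current column types.
  val soi = ObjectInspectorConverters
    .getConvertedOI(rawDeser.getObjectInspector, tableDeser.getObjectInspector)
    .asInstanceOf[StructObjectInspector]
  // When partition and table share the same SerDe instance no conversion is needed,
  // so an IdentityConverter passes rows through untouched.
  val converter =
    if (rawDeser == tableDeser) new IdentityConverter
    else ObjectInspectorConverters.getConverter(rawDeser.getObjectInspector, soi)
  (soi, converter)
}

The returned converter is what fillObject applies to each deserialized Writable before the per-field unwrappers copy values into the reusable MutableRow.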