From f5a4049e534da3c55e1b495ce34155236dfb6dee Mon Sep 17 00:00:00 2001 From: Xi Liu Date: Tue, 17 Jun 2014 13:14:40 +0200 Subject: [PATCH] [SPARK-2164][SQL] Allow Hive UDF on columns of type struct Author: Xi Liu Closes #796 from xiliu82/sqlbug and squashes the following commits: 328dfc4 [Xi Liu] [Spark SQL] remove a temporary function after test 354386a [Xi Liu] [Spark SQL] add test suite for UDF on struct 8fc6f51 [Xi Liu] [SparkSQL] allow UDF on struct --- .../org/apache/spark/sql/hive/hiveUdfs.scala | 3 + .../resources/data/files/testUdf/part-00000 | Bin 0 -> 153 bytes .../sql/hive/execution/HiveUdfSuite.scala | 127 ++++++++++++++++++ 3 files changed, 130 insertions(+) create mode 100755 sql/hive/src/test/resources/data/files/testUdf/part-00000 create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 771d2bccf43a7..ad5e24c62c621 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -335,6 +335,9 @@ private[hive] trait HiveInspectors { case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector + case StructType(fields) => + ObjectInspectorFactory.getStandardStructObjectInspector( + fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) } def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUdf/part-00000 new file mode 100755 index 0000000000000000000000000000000000000000..240a5c1a63c5c4016d096cbd13ddc8b787aee8da GIT binary patch literal 153 zcmWG`4P;ZyFG|--EJ#ewNY%?oOv%qL(96u%^DE8C2`|blNleN~)j?8GT##6ltyf%_ zqnD9cma3Opk(yjul9`{U7m`|B5|Ef#!~h0IwrxAkdN!tuT_U@F&YjICfr1 + |) + |PARTITIONED BY (partition STRING) + |ROW FORMAT SERDE '%s' + |STORED AS SEQUENCEFILE + """.stripMargin.format(classOf[PairSerDe].getName) + ) + + TestHive.hql( + "ALTER TABLE hiveUdfTestTable ADD IF NOT EXISTS PARTITION(partition='testUdf') LOCATION '%s'" + .format(this.getClass.getClassLoader.getResource("data/files/testUdf").getFile) + ) + + TestHive.hql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName)) + + TestHive.hql("SELECT testUdf(pair) FROM hiveUdfTestTable") + + TestHive.hql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") +} + +class TestPair(x: Int, y: Int) extends Writable with Serializable { + def this() = this(0, 0) + var entry: (Int, Int) = (x, y) + + override def write(output: DataOutput): Unit = { + output.writeInt(entry._1) + output.writeInt(entry._2) + } + + override def readFields(input: DataInput): Unit = { + val x = input.readInt() + val y = input.readInt() + entry = (x, y) + } +} + +class PairSerDe extends AbstractSerDe { + override def initialize(p1: Configuration, p2: Properties): Unit = {} + + override def getObjectInspector: ObjectInspector = { + ObjectInspectorFactory + .getStandardStructObjectInspector( + Seq("pair"), + Seq(ObjectInspectorFactory.getStandardStructObjectInspector( + Seq("id", "value"), + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector)) + )) + } + + override def getSerializedClass: Class[_ <: Writable] = classOf[TestPair] + + override def getSerDeStats: SerDeStats = null + + override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = null + + override def deserialize(value: Writable): AnyRef = { + val pair = value.asInstanceOf[TestPair] + + val row = new util.ArrayList[util.ArrayList[AnyRef]] + row.add(new util.ArrayList[AnyRef](2)) + row(0).add(Integer.valueOf(pair.entry._1)) + row(0).add(Integer.valueOf(pair.entry._2)) + + row + } +} + +class PairUdf extends GenericUDF { + override def initialize(p1: Array[ObjectInspector]): ObjectInspector = + ObjectInspectorFactory.getStandardStructObjectInspector( + Seq("id", "value"), + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector) + ) + + override def evaluate(args: Array[DeferredObject]): AnyRef = { + println("Type = %s".format(args(0).getClass.getName)) + Integer.valueOf(args(0).get.asInstanceOf[TestPair].entry._2) + } + + override def getDisplayString(p1: Array[String]): String = "" +} + + +