From c6c60e38cd7bb8b8878bf1e010c910e88bb372c5 Mon Sep 17 00:00:00 2001
From: Tobias Schlatter
Date: Tue, 8 Jul 2014 13:41:57 +0200
Subject: [PATCH] Remove intermediate map for records. Allow serialization

---
 .../scala/org/apache/spark/sql/TypedSql.scala | 42 ++++++++++++++++---
 .../org/apache/spark/sql/TypedSqlSuite.scala  | 12 ++----
 2 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala b/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
index 813b522d02c24..275686dc709b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
@@ -8,6 +8,7 @@ import org.apache.spark.sql.catalyst.types._
 import scala.language.experimental.macros
 
 import records._
+import Macros.RecordMacros
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
@@ -89,17 +90,46 @@ object SQLMacros {
 
     val analyzedPlan = analyzer(logicalPlan)
 
-    val fields = analyzedPlan.output.zipWithIndex.map {
-      case (attr, i) =>
-        q"""${attr.name} -> row.${newTermName("get" + primitiveForType(attr.dataType))}($i)"""
+    // TODO: This probably shouldn't be here but somewhere generic
+    // which defines the Catalyst <-> Scala type mapping
+    def toScalaType(dt: DataType) = dt match {
+      case IntegerType => definitions.IntTpe
+      case LongType => definitions.LongTpe
+      case ShortType => definitions.ShortTpe
+      case ByteType => definitions.ByteTpe
+      case DoubleType => definitions.DoubleTpe
+      case FloatType => definitions.FloatTpe
+      case BooleanType => definitions.BooleanTpe
+      case StringType => definitions.StringClass.toType
     }
 
+    val schema = analyzedPlan.output.map(attr => (attr.name, toScalaType(attr.dataType)))
+
+    val dataImpl = {
+      // Generate a case for each field
+      val cases = analyzedPlan.output.zipWithIndex.map {
+        case (attr, i) =>
+          cq"""${attr.name} => row.${newTermName("get" + primitiveForType(attr.dataType))}($i)"""
+      }
+
+      // Implement __data using these cases.
+      // TODO: Unfortunately, this still boxes. We cannot resolve this,
+      // since the R abstraction depends on the fully generic __data.
+      // The only way to change this is to create __dataLong, etc. on
+      // R itself.
+      q"""
+        val res = fieldName match {
+          case ..$cases
+          case _ => ???
+        }
+        res.asInstanceOf[T]
+      """
+    }
+
+    val record: c.Expr[Nothing] = new RecordMacros[c.type](c).record(schema)(tq"Serializable")()(dataImpl)
+
     val tree = q"""
-      import records.R
       ..${args.zipWithIndex.map{ case (r,i) => q"""$r.registerAsTable(${s"table$i"})""" }}
       val result = sql($query)
-      // TODO: Avoid double copy
-      result.map(row => R(..$fields))
+      result.map(row => $record)
     """
 
     println(tree)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
index a5fb932d883b2..95c886f1b6ccf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
@@ -43,13 +43,9 @@ class TypedSqlSuite extends FunSuite {
   ignore("nested results") {
   }
 
-  ignore("join query") {
-    val results = sql"""
-      SELECT a.name
-      FROM $people a
-      JOIN $people b ON a.age = b.age
-      """
-    // TODO: R is not serializable.
-    // assert(results.first().name == "Michael")
+  test("join query") {
+    val results = sql"""SELECT a.name FROM $people a JOIN $people b ON a.age = b.age"""
+
+    assert(results.first().name == "Michael")
   }
 }
\ No newline at end of file
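
Note on the generated code: below is a rough hand-expansion of what the macro now emits for a query whose output schema is (name: String, age: Int). It is a sketch for illustration, not the literal expansion: the field names and ordinals are invented for the example, the __data signature is inferred from the dataImpl quasiquote above, and the refinement members added by RecordMacros.record are omitted. The point is the data path: each field access pattern-matches on the field name and reads the primitive straight out of the Spark Row, so the intermediate map is gone, and the mixed-in Serializable is what lets records survive the shuffle in the re-enabled join test.

    result.map { row =>
      // Anonymous record; R is the scala-records base type mentioned in the
      // TODOs above, and Serializable comes from the tq"Serializable" ancestor.
      new R with Serializable {
        // Fully generic accessor used by scala-records; still boxes, per the TODO.
        def __data[T](fieldName: String): T = {
          val res = fieldName match {
            case "name" => row.getString(0) // StringType column at ordinal 0
            case "age"  => row.getInt(1)    // IntegerType column at ordinal 1
            case _      => ???
          }
          res.asInstanceOf[T]
        }
      }
    }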