Remove intermediate map for records. Allow serialization
gzm0 committed Jul 8, 2014
1 parent 457d699 commit c6c60e3
Showing 2 changed files with 40 additions and 14 deletions.
42 changes: 36 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
@@ -8,6 +8,7 @@ import org.apache.spark.sql.catalyst.types._
import scala.language.experimental.macros

import records._
import Macros.RecordMacros

import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
@@ -89,17 +90,46 @@ object SQLMacros {

    val analyzedPlan = analyzer(logicalPlan)

    val fields = analyzedPlan.output.zipWithIndex.map {
      case (attr, i) =>
        q"""${attr.name} -> row.${newTermName("get" + primitiveForType(attr.dataType))}($i)"""
    // TODO: This probably shouldn't be here, but somewhere generic
    // which defines the catalyst <-> Scala type mapping.
    def toScalaType(dt: DataType) = dt match {
      case IntegerType => definitions.IntTpe
      case LongType => definitions.LongTpe
      case ShortType => definitions.ShortTpe
      case ByteType => definitions.ByteTpe
      case DoubleType => definitions.DoubleTpe
      case FloatType => definitions.FloatTpe
      case BooleanType => definitions.BooleanTpe
      case StringType => definitions.StringClass.toType
    }

    val schema = analyzedPlan.output.map(attr => (attr.name, toScalaType(attr.dataType)))
    val dataImpl = {
      // Generate a case for each field
      val cases = analyzedPlan.output.zipWithIndex.map {
        case (attr, i) =>
          cq"""${attr.name} => row.${newTermName("get" + primitiveForType(attr.dataType))}($i)"""
      }

      // Implement __data using these cases.
      // TODO: Unfortunately, this still boxes. We cannot resolve this,
      // since the R abstraction depends on the fully generic __data.
      // The only way to change this is to create __dataLong, etc. on
      // R itself.
      q"""
        val res = fieldName match {
          case ..$cases
          case _ => ???
        }
        res.asInstanceOf[T]
      """
    }

    val record: c.Expr[Nothing] = new RecordMacros[c.type](c).record(schema)(tq"Serializable")()(dataImpl)
    val tree = q"""
      import records.R
      ..${args.zipWithIndex.map{ case (r,i) => q"""$r.registerAsTable(${s"table$i"})""" }}
      val result = sql($query)
      // TODO: Avoid double copy
      result.map(row => R(..$fields))
      result.map(row => $record)
    """

    println(tree)
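To make the generated code concrete, here is a hedged sketch (not part of the commit) of roughly what the macro could emit for a query whose result schema is name: String, age: Int. It assumes the scala-records base type R with the generic __data accessor referenced in the comments above; row stands for the Catalyst row being wrapped, and the exact shape of the RecordMacros.record expansion may differ:

    // Hypothetical expansion sketch; `row`, `name` and `age` are assumed.
    new R with Serializable {
      // One match case per schema field; the result is boxed and then
      // cast, which is the boxing limitation noted in the TODO above.
      def __data[T](fieldName: String): T = {
        val res = fieldName match {
          case "name" => row.getString(0)
          case "age"  => row.getInt(1)
          case _      => ???
        }
        res.asInstanceOf[T]
      }
    }

Because the generated record also extends Serializable, instances can be shipped inside RDD closures and shuffles, which is what unblocks the join test below.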
12 changes: 4 additions & 8 deletions sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
@@ -43,13 +43,9 @@ class TypedSqlSuite extends FunSuite {

ignore("nested results") { }

ignore("join query") {
val results = sql"""
SELECT a.name
FROM $people a
JOIN $people b ON a.age = b.age
"""
// TODO: R is not serializable.
// assert(results.first().name == "Michael")
test("join query") {
val results = sql"""SELECT a.name FROM $people a JOIN $people b ON a.age = b.age"""

assert(results.first().name == "Michael")
}
}
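Why this test can now be enabled: the join shuffles the macro-generated record objects between partitions, which is what the removed comment ("R is not serializable") alluded to. Since the record type now mixes in Serializable via the tq"Serializable" argument above, plain Java serialization of a record should succeed. A minimal sketch of that property, where rec stands for a hypothetical record instance:

    // Hedged sketch: serialize a generated record with Java serialization.
    // `rec` is a hypothetical record instance produced by the macro.
    val bytes = new java.io.ByteArrayOutputStream
    new java.io.ObjectOutputStream(bytes).writeObject(rec)
    // Before this commit, this would throw java.io.NotSerializableException.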
