Skip to content

Commit

Permalink
Now with more than one relation.
Browse files Browse the repository at this point in the history
  • Loading branch information
marmbrus committed Jul 7, 2014
1 parent 2fd5a85 commit 457d699
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 16 deletions.
44 changes: 32 additions & 12 deletions sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ object SQLMacros {

case class Schema(dataType: DataType, nullable: Boolean)

def sqlImpl[A <: Product : c.WeakTypeTag](c: Context)(relation: c.Expr[RDD[A]]) = {
def sqlImpl(c: Context)(args: c.Expr[Any]*) = {
import c.universe._

// TODO: Don't copy this function.
// TODO: Don't copy this function from ScalaReflection.
def schemaFor(tpe: `Type`): Schema = tpe match {
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Expand Down Expand Up @@ -69,33 +69,54 @@ object SQLMacros {
scala.StringContext.apply(..$rawParts))""" = c.prefix.tree

val parts = rawParts.map(_.toString.stripPrefix("\"").stripSuffix("\""))
val query = parts(0) + "table1" + parts(1)
val query = parts(0) + (0 until args.size).map { i =>
s"table$i" + parts(i + 1)
}.mkString("")

val parser = new SqlParser()
val logicalPlan = parser(query)
val catalog = new SimpleCatalog
val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, false)

val inputType = weakTypeTag[A]
val inputSchema = schemaFor(inputType.tpe).dataType.asInstanceOf[StructType].toAttributes
catalog.registerTable(None, "table1", LocalRelation(inputSchema:_*))
val tables = args.zipWithIndex.map { case (arg, i) =>
val TypeRef(_, _, Seq(schemaType)) = arg.actualType

val inputSchema = schemaFor(schemaType).dataType.asInstanceOf[StructType].toAttributes
(s"table$i", LocalRelation(inputSchema:_*))
}

tables.foreach(t => catalog.registerTable(None, t._1, t._2))

val analyzedPlan = analyzer(logicalPlan)

val fields = analyzedPlan.output.zipWithIndex.map {
case (attr, i) => q"${attr.name} -> row.getString($i)"
case (attr, i) =>
q"""${attr.name} -> row.${newTermName("get" + primitiveForType(attr.dataType))}($i)"""
}

val tree = q"""
import records.R
$relation.registerAsTable("table1")
..${args.zipWithIndex.map{ case (r,i) => q"""$r.registerAsTable(${s"table$i"})""" }}
val result = sql($query)
// TODO: Avoid double copy
result.map(row => R(..$fields))
"""

// TODO: Why do I need this cast?
c.Expr(tree).asInstanceOf[c.Expr[org.apache.spark.rdd.RDD[records.R{def name: String}]]]
println(tree)

c.Expr(tree)
}

// TODO: Duplicated from codegen PR...
protected def primitiveForType(dt: DataType) = dt match {
case IntegerType => "Int"
case LongType => "Long"
case ShortType => "Short"
case ByteType => "Byte"
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
case StringType => "String"
}
}

Expand All @@ -104,8 +125,7 @@ trait TypedSQL {

@Experimental
implicit class SqlInterpolator(val strCtx: StringContext) {
// TODO: Handle more than one relation
// TODO: Handle functions...
def sql[A <: Product](relation: RDD[A]) = macro SQLMacros.sqlImpl[A]
def sql(args: Any*): Any = macro SQLMacros.sqlImpl
}
}
24 changes: 20 additions & 4 deletions sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,30 @@ case class Person(name: String, age: Int)
class TypedSqlSuite extends FunSuite {
import TestSQLContext._

test("typed query") {
val people = sparkContext.parallelize(
Person("Michael", 30) ::
Person("Bob", 40) :: Nil)
val people = sparkContext.parallelize(
Person("Michael", 30) ::
Person("Bob", 40) :: Nil)

test("typed query") {
val results = sql"SELECT name FROM $people WHERE age = 30"
assert(results.first().name == "Michael")
}

test("int results") {
val results = sql"SELECT * FROM $people WHERE age = 30"
assert(results.first().name == "Michael")
assert(results.first().age == 30)
}

ignore("nested results") { }

ignore("join query") {
val results = sql"""
SELECT a.name
FROM $people a
JOIN $people b ON a.age = b.age
"""
// TODO: R is not serializable.
// assert(results.first().name == "Michael")
}
}

0 comments on commit 457d699

Please sign in to comment.