Skip to content

Commit

Permalink
Review
Browse files Browse the repository at this point in the history
  • Loading branch information
cchantep committed Sep 6, 2021
1 parent 87075ef commit 044e7bc
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 83 deletions.
61 changes: 14 additions & 47 deletions dataset/src/main/scala/frameless/TypedColumn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import shapeless.ops.record.Selector

import scala.annotation.implicitNotFound
import scala.reflect.ClassTag
import scala.reflect.macros.whitebox

import scala.language.experimental.macros

sealed trait UntypedExpression[T] {
Expand Down Expand Up @@ -866,8 +866,7 @@ object SortedTypedColumn {
// Case for a plain TypedColumn: the column is passed to `defaultAscending`.
// The context bound requires a CatalystOrdered[U] instance to be in scope.
implicit def caseTypedColumn[T, U : CatalystOrdered] = at[TypedColumn[T, U]](c => defaultAscending(c))
// Case for a column that is already a SortedTypedColumn: returned unchanged (`identity`).
implicit def caseTypeSortedColumn[T, U] = at[SortedTypedColumn[T, U]](identity)
}
}

}

object TypedColumn {
/** Evidence that type `T` has column `K` with type `V`. */
Expand Down Expand Up @@ -899,49 +898,17 @@ object TypedColumn {
): Exists[T, K, V] = new Exists[T, K, V] {}
}

/** Builds a `TypedColumn[T, A]` from a function `T => A`.
  *
  * This is a macro definition: the call is expanded at compile time
  * by `macroImpl`, which receives the tree of `x`.
  */
def apply[T, A](x: Function1[T, A]): TypedColumn[T, A] = macro macroImpl[T, A]

/** Macro implementation behind `apply`.
  *
  * Recognises a fixed set of field-selection shapes:
  *
  *  - `(x: T) => x.a`
  *  - `_.a`, `_.a.b`, `_.a.b.c`, `_.a.b.c.d`
  *
  * and expands to `new frameless.TypedColumn[T, A](col("a.b...").expr)`,
  * where the column name is the selected path joined with `.`.
  *
  * Any other tree is rejected with `c.abort`, so the user sees a regular
  * compile error positioned at the macro call rather than an exception
  * thrown from inside the macro.
  *
  * @param c the macro context
  * @param x the tree of the function passed to the macro
  */
def macroImpl[T: c.WeakTypeTag, A: c.WeakTypeTag](c: whitebox.Context)(x: c.Tree) = {

  import c.universe._

  val t = c.weakTypeOf[T]
  val a = c.weakTypeOf[A]

  // Builds the resulting expression from the selected field names.
  def buildExpression(columnNames: List[String]) = {
    val columnName = columnNames.mkString(".")
    c.Expr[TypedColumn[T, A]](q"new frameless.TypedColumn[$t, $a]((org.apache.spark.sql.functions.col($columnName)).expr)")
  }

  // Pattern variables are named f1..f4 so that none of them shadows the parameter `x`.
  x match {
    case q"((${_: TermName}:${_: Type}) => ${_: TermName}.${f1: TermName})" => buildExpression(List(f1.toString()))
    case q"(_.${f1: TermName})" => buildExpression(List(f1.toString()))
    case q"(_.${f1: TermName}.${f2: TermName})" => buildExpression(List(f1.toString(), f2.toString()))
    case q"(_.${f1: TermName}.${f2: TermName}.${f3: TermName})" => buildExpression(List(f1.toString(), f2.toString(), f3.toString()))
    case q"(_.${f1: TermName}.${f2: TermName}.${f3: TermName}.${f4: TermName})" => buildExpression(List(f1.toString(), f2.toString(), f3.toString(), f4.toString()))
    case other => c.abort(c.enclosingPosition, s"$other is not supported")
  }
}
}

/** Type function computing the intersection of two types, exposed as `Out`:
  *
  *  - `With[A, A]` gives `A`
  *  - `With[A, B]` gives `A with B` (when `A != B`)
  *
  * It exists to prevent IDEs from inferring large types of the shape
  * `A with A with ... with A`, which can confuse both end users and the
  * IDE's type checker.
  */
trait With[A, B] {
  type Out
}

/** Base trait for the `With` companion: declares the `Aux` alias, the
  * shared instance and the `identity` case (`With[T, T]` with `Out = T`).
  */
trait LowPrioWith {
  type Aux[A, B, W] = With[A, B] { type Out = W }

  // `With` has no members besides the type `Out`, so a single value is
  // cast to whichever instance is requested.
  protected[this] val theInstance: With[Any, Any] = new With[Any, Any] {}

  protected[this] def of[A, B, W]: Aux[A, B, W] = {
    val shared = theInstance
    shared.asInstanceOf[Aux[A, B, W]]
  }

  implicit def identity[T]: Aux[T, T, T] =
    of[T, T, T]
}
/**
* Creates a `TypedColumn[T, U]` from a function selecting a field of `T`.
*
* This is a macro definition: the call is expanded at compile time by
* `TypedColumnMacroImpl.applyImpl`.
*
* {{{
* import frameless.TypedColumn
*
* case class Foo(id: Int, bar: String)
*
* val colbar: TypedColumn[Foo, String] = TypedColumn { foo: Foo => foo.bar }
* val colid = TypedColumn[Foo, Int](_.id)
* }}}
*/
def apply[T, U](x: T => U): TypedColumn[T, U] =
macro TypedColumnMacroImpl.applyImpl[T, U]

/** Companion of `With`, providing the general case: `With[A, B]` with `Out = A with B`. */
object With extends LowPrioWith {
  implicit def combine[A, B]: Aux[A, B, A with B] =
    of[A, B, A with B]
}
84 changes: 84 additions & 0 deletions dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package frameless

import scala.reflect.macros.whitebox

/** Macro implementation that turns a field-selection lambda into a `TypedColumn`. */
private[frameless] object TypedColumnMacroImpl {

  /** Expands a lambda such as `_.a.b` or `(x: T) => x.a.b` to
    * `new TypedColumn[T, U](col("a.b").expr)`.
    *
    * The expansion is rejected with a compile error (via `c.abort`) when:
    *
    *  - `x` is not a function with exactly one parameter;
    *  - the function body is not a chain of term selections;
    *  - the chain does not start from the function's own parameter;
    *  - a selected member does not exist or is not a stable term.
    *
    * @param c the macro context
    * @param x the tree of the function passed to the macro
    * @return the expression building the typed column
    */
  def applyImpl[T: c.WeakTypeTag, U: c.WeakTypeTag](c: whitebox.Context)(x: c.Tree): c.Expr[TypedColumn[T, U]] = {
    import c.universe._

    val t = c.weakTypeOf[T]
    val u = c.weakTypeOf[U]

    // Builds the resulting expression from the selected field names.
    def buildExpression(path: List[String]): c.Expr[TypedColumn[T, U]] = {
      val columnName = path.mkString(".")

      c.Expr[TypedColumn[T, U]](q"new _root_.frameless.TypedColumn[$t, $u]((org.apache.spark.sql.functions.col($columnName)).expr)")
    }

    def abort(msg: String): Nothing = c.abort(c.enclosingPosition, msg)

    // Flattens nested selections `root.a.b` into `List(root, a, b)`.
    @annotation.tailrec
    def path(in: Select, out: List[TermName]): List[TermName] =
      in.qualifier match {
        case sub: Select =>
          path(sub, in.name.toTermName :: out)

        case id: Ident =>
          id.name.toTermName :: in.name.toTermName :: out

        case unsupported =>
          abort(s"Unsupported selection: $unsupported")
      }

    // Walks the path from `current`, requiring each step to be an existing stable term.
    @annotation.tailrec
    def check(current: Type, in: List[TermName]): Boolean = in match {
      case next :: tail =>
        // `member` (unlike `decl`) also finds inherited members; it returns
        // `NoSymbol` when nothing matches, on which `asTerm` would throw.
        val sym = current.member(next)

        if (!sym.isTerm || !sym.asTerm.isStable) {
          abort(s"Stable term expected: ${current}.${next}")
        }

        // See the member from `current` so that type parameters of generic
        // classes are substituted, and unwrap the accessor's result type.
        check(sym.infoIn(current).resultType, tail)

      case _ =>
        true
    }

    x match {
      case fn: Function => fn.body match {
        case select: Select if select.name.isTermName =>
          val param: TermName = fn.vparams match {
            case List(single) =>
              single.name

            case params =>
              abort(s"Select expression must have a single parameter: ${params mkString ", "}")
          }

          path(select, List.empty) match {
            // The selection must start from the lambda parameter itself, not
            // from some other value that happens to be in scope.
            case root :: tail if root == param && check(t, tail) =>
              buildExpression(tail.map(_.toString))

            case _ =>
              abort(s"Invalid select expression: $select")
          }

        case body =>
          abort(s"Select expression expected: $body")
      }

      case _ =>
        abort(s"Function expected: $x")
    }
  }
}
4 changes: 2 additions & 2 deletions dataset/src/main/scala/frameless/TypedDataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
*
* It is statically checked that column with such name exists and has type `A`.
*/
def col[A](x: Function1[T, A]): TypedColumn[T, A] = macro TypedColumn.macroImpl[T, A]

def col[A](x: Function1[T, A]): TypedColumn[T, A] =
macro TypedColumnMacroImpl.applyImpl[T, A]

/** Projects the entire TypedDataset[T] into a single column of type TypedColumn[T,T]
* {{{
Expand Down
27 changes: 27 additions & 0 deletions dataset/src/main/scala/frameless/With.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package frameless

/** Type function computing the intersection of two types, exposed as `Out`:
  *
  *  - `With[A, A]` gives `A`
  *  - `With[A, B]` gives `A with B` (when `A != B`)
  *
  * It exists to prevent IDEs from inferring large types of the shape
  * `A with A with ... with A`, which can confuse both end users and the
  * IDE's type checker.
  */
trait With[A, B] {
  type Out
}

/** Companion of `With`, providing the general case: `With[A, B]` with `Out = A with B`. */
object With extends LowPrioWith {
  implicit def combine[A, B]: Aux[A, B, A with B] =
    of[A, B, A with B]
}

/** Base trait for the `With` companion: declares the `Aux` alias, the
  * shared instance and the `identity` case (`With[T, T]` with `Out = T`).
  */
private[frameless] sealed trait LowPrioWith {
  type Aux[A, B, W] = With[A, B] { type Out = W }

  // `With` has no members besides the type `Out`, so a single value is
  // cast to whichever instance is requested.
  protected[this] val theInstance: With[Any, Any] = new With[Any, Any] {}

  protected[this] def of[A, B, W]: Aux[A, B, W] = {
    val shared = theInstance
    shared.asInstanceOf[Aux[A, B, W]]
  }

  implicit def identity[T]: Aux[T, T, T] =
    of[T, T, T]
}
5 changes: 2 additions & 3 deletions dataset/src/test/scala/frameless/ColumnTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import ceedubs.irrec.regex.gen.CharRegexGen.genCharRegexAndCandidate

import scala.math.Ordering.Implicits._

class ColumnTests extends TypedDatasetSuite with Matchers {
final class ColumnTests extends TypedDatasetSuite with Matchers {

private implicit object OrderingImplicits {
implicit val sqlDateOrdering: Ordering[SQLDate] = Ordering.by(_.days)
Expand Down Expand Up @@ -440,7 +440,6 @@ class ColumnTests extends TypedDatasetSuite with Matchers {
}

test("col through lambda") {

case class MyClass1(a: Int, b: String, c: MyClass2)
case class MyClass2(d: Long)

Expand All @@ -455,6 +454,6 @@ class ColumnTests extends TypedDatasetSuite with Matchers {
"ds.col(x => java.lang.Math.abs(x.a))" shouldNot typeCheck

// we should be able to block the following as well...
//"ds.col(_.a.toInt)" shouldNot typeCheck
"ds.col(_.a.toInt)" shouldNot typeCheck
}
}
67 changes: 36 additions & 31 deletions dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,67 +8,72 @@ case class MyClass2(d: Long, e: MyClass3)
case class MyClass3(f: Double)
case class MyClass4(h: Boolean)

class ColumnViaLambdaTests extends TypedDatasetSuite with Matchers {
final class ColumnViaLambdaTests extends TypedDatasetSuite with Matchers {

def ds = {
TypedDataset.create(Seq(
MyClass1(1, "2", MyClass2(3L, MyClass3(7.0)), Some(MyClass4(true))),
MyClass1(4, "5", MyClass2(6L, MyClass3(8.0)), None)))
MyClass1(1, "2", MyClass2(3L, MyClass3(7.0D)), Some(MyClass4(true))),
MyClass1(4, "5", MyClass2(6L, MyClass3(8.0D)), None)))
}

test("col(_.a)") {
val col = ds.col(_.a)
val actual = ds.select(col).collect.run()
val expected = Seq(1, 4)
actual shouldEqual expected
val col = TypedColumn[MyClass1, Int](_.a)

ds.select(col).collect.run() shouldEqual Seq(1, 4)
}

test("col(x => x.a") {
val col = ds.col(x => x.a)
val actual = ds.select(col).collect.run()
val expected = Seq(1, 4)
actual shouldEqual expected
val col = TypedColumn[MyClass1, Int](x => x.a)

ds.select(col).collect.run() shouldEqual Seq(1, 4)
}

test("col((x: MyClass1) => x.a") {
val col = ds.col((x: MyClass1) => x.a)
val actual = ds.select(col).collect.run()
val expected = Seq(1, 4)
actual shouldEqual expected
val col = TypedColumn { (x: MyClass1) => x.a }

ds.select(col).collect.run() shouldEqual Seq(1, 4)
}

test("col((x:MyClass1) => x.a") {
val col = ds.col((x: MyClass1) => x.a)
val actual = ds.select(col).collect.run()
val expected = Seq(1, 4)
actual shouldEqual expected
test("col((x: MyClass1) => x.c.e.f") {
val col = TypedColumn { (x: MyClass1) => x.c.e.f }

ds.select(col).collect.run() shouldEqual Seq(7.0D, 8.0D)
}

test("col(_.c.d)") {
val col = ds.col(_.c.d)
val actual = ds.select(col).collect.run()
val expected = Seq(3, 6)
actual shouldEqual expected
val col = TypedColumn[MyClass1, Long](_.c.d)

ds.select(col).collect.run() shouldEqual Seq(3L, 6L)
}

test("col(_.c.e.f)") {
val col = ds.col(_.c.e.f)
val actual = ds.select(col).collect.run()
val expected = Seq(7.0, 8.0)
actual shouldEqual expected
val col = TypedColumn[MyClass1, Double](_.c.e.f)

ds.select(col).collect.run() shouldEqual Seq(7.0D, 8.0D)
}

test("col(_.c.d) as int does not compile (is long)") {
illTyped("TypedColumn[MyClass1, Int](_.c.d)")
}

test("col(_.g.h does not compile") {
val col = ds.col(_.g) // the path "ends" at .g (can't access h)
illTyped("""ds.col(_.g.h)""")
}

test("col(_.a.toString) should not compile") {
test("col(_.a.toString) does not compile") {
illTyped("""ds.col(_.a.toString)""")
}

test("col(x => java.lang.Math.abs(x.a)) should not compile") {
illTyped("""col(x => java.lang.Math.abs(x.a))""")
test("col(_.a.toString.size) does not compile") {
illTyped("""ds.col(_.a.toString.size)""")
}

test("col((x: MyClass1) => x.toString.size) does not compile") {
illTyped("""ds.col((x: MyClass1) => x.toString.size)""")
}

test("col(x => java.lang.Math.abs(x.a)) does not compile") {
illTyped("""col(x => java.lang.Math.abs(x.a))""")
}
}

0 comments on commit 044e7bc

Please sign in to comment.