Move the TypedColumn lambda macro into TypedColumnMacroImpl, give With its own
file, and tighten the column-via-lambda tests.

Review notes on this revision of the patch:
  * applyImpl now requires the selection to be rooted at the lambda parameter
    and aborts (instead of throwing) when a path segment is not a declared term.
  * The java.lang.Math.abs negative test goes through ds.col, like its siblings.
  * Blob ids on the "index" lines were not recomputed after these edits.

diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala
index 6e74967d..a518d8c6 100644
--- a/dataset/src/main/scala/frameless/TypedColumn.scala
+++ b/dataset/src/main/scala/frameless/TypedColumn.scala
@@ -10,7 +10,7 @@ import shapeless.ops.record.Selector
 
 import scala.annotation.implicitNotFound
 import scala.reflect.ClassTag
-import scala.reflect.macros.whitebox
+
 import scala.language.experimental.macros
 
 sealed trait UntypedExpression[T] {
@@ -866,8 +866,7 @@ object SortedTypedColumn {
     implicit def caseTypedColumn[T, U : CatalystOrdered] = at[TypedColumn[T, U]](c => defaultAscending(c))
     implicit def caseTypeSortedColumn[T, U] = at[SortedTypedColumn[T, U]](identity)
   }
-  }
-
+}
 
 object TypedColumn {
   /** Evidence that type `T` has column `K` with type `V`. */
@@ -899,49 +898,17 @@ object TypedColumn {
     ): Exists[T, K, V] = new Exists[T, K, V] {}
   }
 
-  def apply[T, A](x: Function1[T, A]): TypedColumn[T, A] = macro macroImpl[T, A]
-
-  def macroImpl[T: c.WeakTypeTag, A: c.WeakTypeTag](c: whitebox.Context)(x: c.Tree) = {
-
-    import c.universe._
-
-    val t = c.weakTypeOf[T]
-    val a = c.weakTypeOf[A]
-
-    def buildExpression(columnNames: List[String]) = {
-      val columnName = columnNames.mkString(".")
-      c.Expr[TypedColumn[T, A]](q"new frameless.TypedColumn[$t, $a]((org.apache.spark.sql.functions.col($columnName)).expr)")
-    }
-
-    x match {
-      case q"((${_: TermName}:${_: Type}) => ${_: TermName}.${p: TermName})" => buildExpression(List(p.toString()))
-      case q"(_.${p: TermName})" => buildExpression(List(p.toString()))
-      case q"(_.${p: TermName}.${x: TermName})" => buildExpression(List(p.toString(), x.toString()))
-      case q"(_.${p: TermName}.${x: TermName}.${y: TermName})" => buildExpression(List(p.toString(), x.toString(), y.toString()))
-      case q"(_.${p: TermName}.${x: TermName}.${y: TermName}.${z: TermName})" => buildExpression(List(p.toString(), x.toString(), y.toString(), z.toString()))
-      case x => throw new IllegalArgumentException(s"$x is not supported")
-    }
-  }
-}
-
-/** Compute the intersection of two types:
- *
- * - With[A, A] = A
- * - With[A, B] = A with B (when A != B)
- *
- * This type function is needed to prevent IDEs from infering large types
- * with shape `A with A with ... with A`. These types could be confusing for
- * both end users and IDE's type checkers.
- */
-trait With[A, B] { type Out }
-
-trait LowPrioWith {
-  type Aux[A, B, W] = With[A, B] { type Out = W }
-  protected[this] val theInstance = new With[Any, Any] {}
-  protected[this] def of[A, B, W]: With[A, B] { type Out = W } = theInstance.asInstanceOf[Aux[A, B, W]]
-  implicit def identity[T]: Aux[T, T, T] = of[T, T, T]
-}
+  /**
+   * {{{
+   * import frameless.TypedColumn
+   *
+   * case class Foo(id: Int, bar: String)
+   *
+   * val colbar: TypedColumn[Foo, String] = TypedColumn { foo: Foo => foo.bar }
+   * val colid = TypedColumn[Foo, Int](_.id)
+   * }}}
+   */
+  def apply[T, U](x: T => U): TypedColumn[T, U] =
+    macro TypedColumnMacroImpl.applyImpl[T, U]
 
-object With extends LowPrioWith {
-  implicit def combine[A, B]: Aux[A, B, A with B] = of[A, B, A with B]
 }
diff --git a/dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala b/dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala
new file mode 100644
index 00000000..62fa2765
--- /dev/null
+++ b/dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala
@@ -0,0 +1,91 @@
+package frameless
+
+import scala.reflect.macros.whitebox
+
+/** Macro implementation backing `TypedColumn.apply` and `TypedDataset#col`. */
+private[frameless] object TypedColumnMacroImpl {
+
+  /**
+   * Expands a field-selection lambda such as `_.a.b` into a `TypedColumn[T, U]`
+   * that refers to the column named by the dot-joined selection path.
+   *
+   * Compilation is aborted unless `x` is a single-parameter function literal
+   * whose body is a chain of stable, declared members rooted at that parameter.
+   */
+  def applyImpl[T: c.WeakTypeTag, U: c.WeakTypeTag](c: whitebox.Context)(x: c.Tree): c.Expr[TypedColumn[T, U]] = {
+    import c.universe._
+
+    val t = c.weakTypeOf[T]
+    val u = c.weakTypeOf[U]
+
+    def buildExpression(path: List[String]): c.Expr[TypedColumn[T, U]] = {
+      val columnName = path.mkString(".")
+
+      c.Expr[TypedColumn[T, U]](q"new _root_.frameless.TypedColumn[$t, $u]((org.apache.spark.sql.functions.col($columnName)).expr)")
+    }
+
+    def abort(msg: String): Nothing = c.abort(c.enclosingPosition, msg)
+
+    // Flattens `root.a.b` into List(root, a, b).
+    @annotation.tailrec
+    def path(in: Select, out: List[TermName]): List[TermName] =
+      in.qualifier match {
+        case sub: Select =>
+          path(sub, in.name.toTermName :: out)
+
+        case id: Ident =>
+          id.name.toTermName :: in.name.toTermName :: out
+
+        case other =>
+          abort(s"Unsupported selection: $other")
+      }
+
+    // Follows the selection path through the declared members of `current`.
+    @annotation.tailrec
+    def check(current: Type, in: List[TermName]): Boolean = in match {
+      case next :: tail => {
+        val sym = current.decl(next)
+
+        // `decl` returns NoSymbol when nothing is declared under that name, and
+        // `asTerm` throws on NoSymbol, so `isTerm` has to be tested first.
+        if (!sym.isTerm || !sym.asTerm.isStable) {
+          abort(s"Stable term expected: ${current}.${next}")
+        }
+
+        check(sym.info, tail)
+      }
+
+      case _ =>
+        true
+    }
+
+    x match {
+      case fn: Function => fn.body match {
+        case select: Select if select.name.isTermName =>
+          val param: TermName = fn.vparams match {
+            case List(single) =>
+              single.name
+
+            case other =>
+              abort(s"Select expression must have a single parameter: ${other mkString ", "}")
+          }
+
+          path(select, List.empty) match {
+            // The path has to start at the lambda parameter itself; any other
+            // identifier in scope is not a column of `T`.
+            case root :: tail if root == param && check(t, tail) =>
+              buildExpression(tail.map(_.toString))
+
+            case _ =>
+              abort(s"Invalid select expression: $select")
+          }
+
+        case other =>
+          abort(s"Select expression expected: $other")
+      }
+
+      case _ =>
+        abort(s"Function expected: $x")
+    }
+  }
+}
diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala
index e89a2d4c..f8ff0fec 100644
--- a/dataset/src/main/scala/frameless/TypedDataset.scala
+++ b/dataset/src/main/scala/frameless/TypedDataset.scala
@@ -248,8 +248,8 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
    *
    * It is statically checked that column with such name exists and has type `A`.
    */
-  def col[A](x: Function1[T, A]): TypedColumn[T, A] = macro TypedColumn.macroImpl[T, A]
-
+  def col[A](x: Function1[T, A]): TypedColumn[T, A] =
+    macro TypedColumnMacroImpl.applyImpl[T, A]
 
   /** Projects the entire TypedDataset[T] into a single column of type TypedColumn[T,T]
     * {{{
diff --git a/dataset/src/main/scala/frameless/With.scala b/dataset/src/main/scala/frameless/With.scala
new file mode 100644
index 00000000..11ceaa35
--- /dev/null
+++ b/dataset/src/main/scala/frameless/With.scala
@@ -0,0 +1,27 @@
+package frameless
+
+/** Compute the intersection of two types:
+ *
+ * - With[A, A] = A
+ * - With[A, B] = A with B (when A != B)
+ *
+ * This type function is needed to prevent IDEs from inferring large types
+ * with shape `A with A with ... with A`. These types could be confusing for
+ * both end users and IDE's type checkers.
+ */
+trait With[A, B] { type Out }
+
+object With extends LowPrioWith {
+  implicit def combine[A, B]: Aux[A, B, A with B] = of[A, B, A with B]
+}
+
+private[frameless] sealed trait LowPrioWith {
+  type Aux[A, B, W] = With[A, B] { type Out = W }
+
+  protected[this] val theInstance = new With[Any, Any] {}
+
+  protected[this] def of[A, B, W]: With[A, B] { type Out = W } =
+    theInstance.asInstanceOf[Aux[A, B, W]]
+
+  implicit def identity[T]: Aux[T, T, T] = of[T, T, T]
+}
diff --git a/dataset/src/test/scala/frameless/ColumnTests.scala b/dataset/src/test/scala/frameless/ColumnTests.scala
index 17ee0f20..417e7735 100644
--- a/dataset/src/test/scala/frameless/ColumnTests.scala
+++ b/dataset/src/test/scala/frameless/ColumnTests.scala
@@ -10,7 +10,7 @@ import ceedubs.irrec.regex.gen.CharRegexGen.genCharRegexAndCandidate
 
 import scala.math.Ordering.Implicits._
 
-class ColumnTests extends TypedDatasetSuite with Matchers {
+final class ColumnTests extends TypedDatasetSuite with Matchers {
 
   private implicit object OrderingImplicits {
     implicit val sqlDateOrdering: Ordering[SQLDate] = Ordering.by(_.days)
@@ -440,7 +440,6 @@ class ColumnTests extends TypedDatasetSuite with Matchers {
   }
 
   test("col through lambda") {
-
     case class MyClass1(a: Int, b: String, c: MyClass2)
     case class MyClass2(d: Long)
 
@@ -455,6 +454,6 @@ class ColumnTests extends TypedDatasetSuite with Matchers {
     "ds.col(x => java.lang.Math.abs(x.a))" shouldNot typeCheck
 
-    // we should be able to block the following as well...
-    //"ds.col(_.a.toInt)" shouldNot typeCheck
+    // selections that end in a method call are rejected as well
+    "ds.col(_.a.toInt)" shouldNot typeCheck
   }
 }
diff --git a/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala b/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala
index c9a78c51..4abce50e 100644
--- a/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala
+++ b/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala
@@ -8,54 +8,52 @@ case class MyClass2(d: Long, e: MyClass3)
 case class MyClass3(f: Double)
 case class MyClass4(h: Boolean)
 
-class ColumnViaLambdaTests extends TypedDatasetSuite with Matchers {
+final class ColumnViaLambdaTests extends TypedDatasetSuite with Matchers {
 
   def ds = {
     TypedDataset.create(Seq(
-      MyClass1(1, "2", MyClass2(3L, MyClass3(7.0)), Some(MyClass4(true))),
-      MyClass1(4, "5", MyClass2(6L, MyClass3(8.0)), None)))
+      MyClass1(1, "2", MyClass2(3L, MyClass3(7.0D)), Some(MyClass4(true))),
+      MyClass1(4, "5", MyClass2(6L, MyClass3(8.0D)), None)))
   }
 
   test("col(_.a)") {
-    val col = ds.col(_.a)
-    val actual = ds.select(col).collect.run()
-    val expected = Seq(1, 4)
-    actual shouldEqual expected
+    val col = TypedColumn[MyClass1, Int](_.a)
+
+    ds.select(col).collect.run() shouldEqual Seq(1, 4)
   }
 
   test("col(x => x.a") {
-    val col = ds.col(x => x.a)
-    val actual = ds.select(col).collect.run()
-    val expected = Seq(1, 4)
-    actual shouldEqual expected
+    val col = TypedColumn[MyClass1, Int](x => x.a)
+
+    ds.select(col).collect.run() shouldEqual Seq(1, 4)
   }
 
   test("col((x: MyClass1) => x.a") {
-    val col = ds.col((x: MyClass1) => x.a)
-    val actual = ds.select(col).collect.run()
-    val expected = Seq(1, 4)
-    actual shouldEqual expected
+    val col = TypedColumn { (x: MyClass1) => x.a }
+
+    ds.select(col).collect.run() shouldEqual Seq(1, 4)
   }
 
-  test("col((x:MyClass1) => x.a") {
-    val col = ds.col((x: MyClass1) => x.a)
-    val actual = ds.select(col).collect.run()
-    val expected = Seq(1, 4)
-    actual shouldEqual expected
+  test("col((x: MyClass1) => x.c.e.f") {
+    val col = TypedColumn { (x: MyClass1) => x.c.e.f }
+
+    ds.select(col).collect.run() shouldEqual Seq(7.0D, 8.0D)
   }
 
   test("col(_.c.d)") {
-    val col = ds.col(_.c.d)
-    val actual = ds.select(col).collect.run()
-    val expected = Seq(3, 6)
-    actual shouldEqual expected
+    val col = TypedColumn[MyClass1, Long](_.c.d)
+
+    ds.select(col).collect.run() shouldEqual Seq(3L, 6L)
   }
 
   test("col(_.c.e.f)") {
-    val col = ds.col(_.c.e.f)
-    val actual = ds.select(col).collect.run()
-    val expected = Seq(7.0, 8.0)
-    actual shouldEqual expected
+    val col = TypedColumn[MyClass1, Double](_.c.e.f)
+
+    ds.select(col).collect.run() shouldEqual Seq(7.0D, 8.0D)
+  }
+
+  test("col(_.c.d) as int does not compile (is long)") {
+    illTyped("TypedColumn[MyClass1, Int](_.c.d)")
   }
 
   test("col(_.g.h does not compile") {
@@ -63,12 +61,19 @@ class ColumnViaLambdaTests extends TypedDatasetSuite with Matchers {
     illTyped("""ds.col(_.g.h)""")
   }
 
-  test("col(_.a.toString) should not compile") {
+  test("col(_.a.toString) does not compile") {
     illTyped("""ds.col(_.a.toString)""")
   }
 
-  test("col(x => java.lang.Math.abs(x.a)) should not compile") {
-    illTyped("""col(x => java.lang.Math.abs(x.a))""")
+  test("col(_.a.toString.size) does not compile") {
+    illTyped("""ds.col(_.a.toString.size)""")
+  }
+
+  test("col((x: MyClass1) => x.toString.size) does not compile") {
+    illTyped("""ds.col((x: MyClass1) => x.toString.size)""")
   }
 
+  test("col(x => java.lang.Math.abs(x.a)) does not compile") {
+    illTyped("""ds.col(x => java.lang.Math.abs(x.a))""")
+  }
 }