Merge pull request #433 from chris-twiner/temp/Spark3
resolve #427 - Spark 3 support
imarios authored Aug 31, 2020
2 parents 732dbc4 + 4713716 commit 614986b
Showing 14 changed files with 126 additions and 83 deletions.
15 changes: 3 additions & 12 deletions .travis.yml
@@ -7,22 +7,13 @@ jobs:
include:
- stage: Documentation 2.12
env: PHASE=A
-    scala: 2.12.8
-  - stage: Documentation 2.11
-    env: PHASE=A
-    scala: 2.11.12
+    scala: 2.12.10
- stage: Unit Tests 2.12
env: PHASE=B
-    scala: 2.12.8
-  - stage: Unit Tests 2.11
-    env: PHASE=B
-    scala: 2.11.12
-  - stage: Publish 2.11
-    env: PHASE=C
-    scala: 2.11.12
+    scala: 2.12.10
- stage: Publish 2.12
env: PHASE=C
-    scala: 2.12.8
+    scala: 2.12.10

script:
- scripts/travis-publish.sh
10 changes: 7 additions & 3 deletions README.md
@@ -33,7 +33,7 @@ The compatible versions of [Spark](http://spark.apache.org/) and
| 0.6.1 | 2.3.0 | 1.x | 0.8 | 2.11
| 0.7.0 | 2.3.1 | 1.x | 1.x | 2.11
| 0.8.0 | 2.4.0 | 1.x | 1.x | 2.11/2.12
-
+| 0.9.0 | 3.0.0 | 1.x | 1.x | 2.12

Versions 0.5.x and 0.6.x have identical features. The first is compatible with Spark 2.2.1 and the second with 2.3.0.
@@ -48,6 +48,10 @@ This essentially allows you to use any version of Frameless with any version of
The aforementioned table simply provides the versions of Spark we officially compile
and test Frameless with, but other versions may probably work as well.

+### Breaking changes in 0.9
+
+* Spark 3 introduces a new ExpressionEncoder approach; the schema for a single-value DataFrame is now ["value"](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala#L270), not "_1" (see the sketch below).
+
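A minimal sketch (not part of the PR) of what the renaming means for a single-column `TypedDataset`; the session setup is assumed, as in the quick start below:

```scala
import frameless.TypedDataset

// Sketch only: assumes a SparkSession and its implicits are in scope.
val ds: TypedDataset[Int] = TypedDataset.create(Seq(1, 2, 3))

ds.dataset.schema.fieldNames.toList
// Frameless 0.8 on Spark 2.4: List("_1")
// Frameless 0.9 on Spark 3.0: List("value")
```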
## Why?

Frameless introduces a new Spark API, called `TypedDataset`.
@@ -74,12 +78,12 @@ detailed comparison of `TypedDataset` with Spark's `Dataset` API.
* [Proof of Concept: TypedDataFrame](http://typelevel.org/frameless/TypedDataFrame.html)

## Quick Start
-Frameless is compiled against Scala 2.11.x (and Scala 2.12.x since Frameless 0.8.0)
+Since the 0.9.x release, Frameless is compiled only against Scala 2.12.x.

To use Frameless in your project add the following in your `build.sbt` file as needed:

```scala
val framelessVersion = "0.8.0" // for Spark 2.4.0
val framelessVersion = "0.9.0" // for Spark 3.0.0

libraryDependencies ++= List(
"org.typelevel" %% "frameless-dataset" % framelessVersion,
4 changes: 2 additions & 2 deletions build.sbt
@@ -1,4 +1,4 @@
val sparkVersion = "2.4.6"
val sparkVersion = "3.0.0"
val catsCoreVersion = "2.0.0"
val catsEffectVersion = "2.0.0"
val catsMtlVersion = "0.7.0"
@@ -85,7 +85,7 @@ lazy val docs = project

lazy val framelessSettings = Seq(
organization := "org.typelevel",
crossScalaVersions := Seq("2.11.12", "2.12.10"),
crossScalaVersions := Seq("2.12.10"),
scalaVersion := crossScalaVersions.value.last,
scalacOptions ++= commonScalacOptions(scalaVersion.value),
licenses += ("Apache-2.0", url("http://opensource.org/licenses/Apache-2.0")),
19 changes: 4 additions & 15 deletions dataset/src/main/scala/frameless/RecordEncoder.scala
@@ -1,7 +1,6 @@
package frameless

import org.apache.spark.sql.FramelessInternals
-import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance}
import org.apache.spark.sql.types._
@@ -162,23 +161,13 @@ class RecordEncoder[F, G <: HList, H <: HList]

def fromCatalyst(path: Expression): Expression = {
val exprs = fields.value.value.map { field =>
-      val fieldPath = path match {
-        case BoundReference(ordinal, dataType, nullable) =>
-          GetColumnByOrdinal(field.ordinal, field.encoder.jvmRepr)
-        case other =>
-          GetStructField(path, field.ordinal, Some(field.name))
-      }
-      field.encoder.fromCatalyst(fieldPath)
+      field.encoder.fromCatalyst(GetStructField(path, field.ordinal, Some(field.name)))
}

val newArgs = newInstanceExprs.value.from(exprs)
val newExpr = NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)
-    path match {
-      case BoundReference(0, _, _) => newExpr
-      case _ => {
-        val nullExpr = Literal.create(null, jvmRepr)
-        If(IsNull(path), nullExpr, newExpr)
-      }
-    }
+
+    val nullExpr = Literal.create(null, jvmRepr)
+    If(IsNull(path), nullExpr, newExpr)
}
}
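For orientation, a sketch of the deserializer shape the new `fromCatalyst` produces for a hypothetical two-field record. Names and types here are illustrative only, and each field would normally also go through its encoder's own `fromCatalyst` conversion (e.g. `UTF8String` to `String`), which this sketch omits:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
import org.apache.spark.sql.types._

case class Person(name: String, age: Int)

val personType = ObjectType(classOf[Person])

// `path` stands in for whatever expression addresses the struct column.
val path: Expression = BoundReference(0, StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType))), nullable = true)

// One GetStructField per record field, in declaration order...
val args = Seq(
  GetStructField(path, 0, Some("name")),
  GetStructField(path, 1, Some("age")))

// ...fed to the constructor, guarded by a single null check on `path`;
// the BoundReference special case from the Spark 2 code path is gone.
val newExpr = NewInstance(classOf[Person], args, personType, propagateNull = true)
val deserializer = If(IsNull(path), Literal.create(null, personType), newExpr)
```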
10 changes: 5 additions & 5 deletions dataset/src/main/scala/frameless/TypedDataset.scala
@@ -7,7 +7,7 @@ import frameless.ops._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal}
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.types.StructType
import shapeless._
@@ -620,7 +620,7 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
import FramelessInternals._
val leftPlan = logicalPlan(dataset)
val rightPlan = logicalPlan(other.dataset)
-    val join = disambiguate(Join(leftPlan, rightPlan, Inner, Some(condition.expr)))
+    val join = disambiguate(Join(leftPlan, rightPlan, Inner, Some(condition.expr), JoinHint.NONE))
val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan)
val joinedDs = mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, U)])
TypedDataset.create[(T, U)](joinedDs)
@@ -902,11 +902,11 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
i2: Tupler.Aux[Out0, Out],
i3: TypedEncoder[Out]
): TypedDataset[Out] = {
-    val selected = dataset.toDF()
+    val base = dataset.toDF()
       .select(columns.toList[UntypedExpression[T]].map(c => new Column(c.expr)):_*)
-      .as[Out](TypedExpressionEncoder[Out])
+    val selected = base.as[Out](TypedExpressionEncoder[Out])

-      TypedDataset.create[Out](selected)
+    TypedDataset.create[Out](selected)
}
}

6 changes: 3 additions & 3 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
@@ -231,7 +231,7 @@ object TypedEncoder {

case ByteType => path

-      case _ => MapObjects(encodeT.toCatalyst, path, encodeT.jvmRepr, encodeT.nullable)
+      case _ => MapObjects(encodeT.toCatalyst _, path, encodeT.jvmRepr, encodeT.nullable)
}

def fromCatalyst(path: Expression): Expression =
@@ -246,7 +246,7 @@
case ByteType => path

case _ =>
-        Invoke(MapObjects(encodeT.fromCatalyst, path, encodeT.catalystRepr, encodeT.nullable), "array", jvmRepr)
+        Invoke(MapObjects(encodeT.fromCatalyst _, path, encodeT.catalystRepr, encodeT.nullable), "array", jvmRepr)
}
}

@@ -265,7 +265,7 @@
def toCatalyst(path: Expression): Expression =
if (ScalaReflection.isNativeType(encodeT.value.jvmRepr))
NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
-      else MapObjects(encodeT.value.toCatalyst, path, encodeT.value.jvmRepr, encodeT.value.nullable)
+      else MapObjects(encodeT.value.toCatalyst _, path, encodeT.value.jvmRepr, encodeT.value.nullable)

def fromCatalyst(path: Expression): Expression =
MapObjects(
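The only change in this file is the trailing underscore, which eta-expands each method into the `Expression => Expression` function value `MapObjects` expects; presumably this was needed to compile cleanly against Spark 3 (assumed reasoning, the PR does not spell it out). A minimal standalone sketch of explicit eta-expansion:

```scala
// Standalone Scala 2.12 sketch of explicit eta-expansion.
object EtaExpansion extends App {
  def addOne(x: Int): Int = x + 1

  // `addOne _` converts the method into a Function1 value explicitly;
  // the compiler does not have to infer the conversion from context.
  val f: Int => Int = addOne _

  println(Seq(1, 2, 3).map(f)) // List(2, 3, 4)
}
```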
28 changes: 13 additions & 15 deletions dataset/src/main/scala/frameless/TypedExpressionEncoder.scala
@@ -1,48 +1,46 @@
package frameless

+import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, If, Literal}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, If}
import org.apache.spark.sql.types.StructType

object TypedExpressionEncoder {

/** In Spark, DataFrame has always schema of StructType
*
-  * DataFrames of primitive types become records with a single field called "_1".
+  * DataFrames of primitive types become records with a single field called "value", set in ExpressionEncoder.
*/
def targetStructType[A](encoder: TypedEncoder[A]): StructType = {
encoder.catalystRepr match {
case x: StructType =>
if (encoder.nullable) StructType(x.fields.map(_.copy(nullable = true)))
else x
case dt => new StructType().add("_1", dt, nullable = encoder.nullable)
case dt => new StructType().add("value", dt, nullable = encoder.nullable)
}
}

-  def apply[T: TypedEncoder]: ExpressionEncoder[T] = {
+  def apply[T: TypedEncoder]: Encoder[T] = {
val encoder = TypedEncoder[T]
-    val schema = targetStructType(encoder)

val in = BoundReference(0, encoder.jvmRepr, encoder.nullable)

-    val (out, toRowExpressions) = encoder.toCatalyst(in) match {
-      case If(_, _, x: CreateNamedStruct) =>
-        val out = BoundReference(0, encoder.catalystRepr, encoder.nullable)
+    val (out, serializer) = encoder.toCatalyst(in) match {
+      case it @ If(_, _, _: CreateNamedStruct) =>
+        val out = GetColumnByOrdinal(0, encoder.catalystRepr)

-        (out, x.flatten)
+        (out, it)
      case other =>
        val out = GetColumnByOrdinal(0, encoder.catalystRepr)

-        (out, CreateNamedStruct(Literal("_1") :: other :: Nil).flatten)
+        (out, other)
}

new ExpressionEncoder[T](
-      schema = schema,
-      flat = false,
-      serializer = toRowExpressions,
-      deserializer = encoder.fromCatalyst(out),
+      objSerializer = serializer,
+      objDeserializer = encoder.fromCatalyst(out),
clsTag = encoder.classTag
)
}
}

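`apply` now advertises the broader `Encoder[T]` interface while still building a Spark 3 `ExpressionEncoder` from a single `objSerializer`/`objDeserializer` expression pair (the `schema`/`flat`/`serializer`/`deserializer` constructor of Spark 2 is gone). A hypothetical usage sketch, not from the PR, handing the encoder to a plain Spark API:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

case class Foo(a: Int, b: String)

// Sketch: pass frameless's encoder explicitly where Spark wants an implicit Encoder.
def toDataset(spark: SparkSession, data: Seq[Foo]): Dataset[Foo] =
  spark.createDataset(data)(frameless.TypedExpressionEncoder[Foo])
```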
dataset/src/main/scala/frameless/functions/AggregateFunctions.scala
@@ -63,9 +63,10 @@ trait AggregateFunctions {
def sum[A, T, Out](column: TypedColumn[T, A])(
implicit
summable: CatalystSummable[A, Out],
-    oencoder: TypedEncoder[Out]
+    oencoder: TypedEncoder[Out],
+    aencoder: TypedEncoder[A]
): TypedAggregate[T, Out] = {
-    val zeroExpr = Literal.create(summable.zero, TypedEncoder[Out].catalystRepr)
+    val zeroExpr = Literal.create(summable.zero, TypedEncoder[A].catalystRepr)
val sumExpr = expr(sparkFunctions.sum(column.untyped))
val sumOrZero = Coalesce(Seq(sumExpr, zeroExpr))

@@ -79,9 +80,10 @@
def sumDistinct[A, T, Out](column: TypedColumn[T, A])(
implicit
summable: CatalystSummable[A, Out],
-    oencoder: TypedEncoder[Out]
+    oencoder: TypedEncoder[Out],
+    aencoder: TypedEncoder[A]
): TypedAggregate[T, Out] = {
-    val zeroExpr = Literal.create(summable.zero, TypedEncoder[Out].catalystRepr)
+    val zeroExpr = Literal.create(summable.zero, TypedEncoder[A].catalystRepr)
val sumExpr = expr(sparkFunctions.sumDistinct(column.untyped))
val sumOrZero = Coalesce(Seq(sumExpr, zeroExpr))
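The fix follows from the types: `summable.zero` is an `A` (the input column type), not an `Out` (the result type), so the fallback literal fed to `Coalesce` must be created with `A`'s Catalyst representation. A hedged usage sketch showing the two types diverging (`CatalystSummable[Int, Long]` widens `Int` sums to `Long`):

```scala
import frameless.TypedDataset
import frameless.functions.aggregate.sum

case class Purchase(item: String, amount: Int)

// Sketch only: assumes a SparkSession and its implicits are in scope.
val ds = TypedDataset.create(Seq(Purchase("a", 1), Purchase("b", 2)))

// amount: Int, but the aggregate result is Long; the zero used by
// Coalesce on empty input is the Int 0, hence A's Catalyst type.
val total = ds.agg(sum(ds('amount)))
```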

50 changes: 44 additions & 6 deletions dataset/src/main/scala/frameless/functions/Udf.scala
@@ -2,9 +2,9 @@ package frameless
package functions

import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression}
-import org.apache.spark.sql.catalyst.expressions.codegen._, Block._
-import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
+import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, NonSQLExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import Block._
import org.apache.spark.sql.types.DataType
import shapeless.syntax.std.tuple._

@@ -90,7 +90,7 @@ case class FramelessUdf[T, R](
override def nullable: Boolean = rencoder.nullable
override def toString: String = s"FramelessUdf(${children.mkString(", ")})"

-  def eval(input: InternalRow): Any = {
+  lazy val evalCode = {
val ctx = new CodegenContext()
val eval = genCode(ctx)

@@ -123,7 +123,11 @@
val (clazz, _) = CodeGenerator.compile(code)
val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef]

-    codegen(input)
+    codegen
}

+  def eval(input: InternalRow): Any = {
+    evalCode(input)
+  }

def dataType: DataType = rencoder.catalystRepr
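Moving the body of `eval` into `lazy val evalCode` takes code generation off the per-row path: the generated class is compiled once, on first use, and the compiled `InternalRow => AnyRef` is reused for every subsequent row. A minimal standalone sketch of the memoization pattern (illustrative names, not frameless API):

```scala
// Sketch: memoize an expensive compile step behind a lazy val.
final class CompiledOnce[I, O](compile: () => I => O) {
  // Built on first eval, then cached for all later inputs.
  lazy val compiled: I => O = compile()

  def eval(input: I): O = compiled(input)
}

val inc = new CompiledOnce[Int, Int]({ () =>
  println("compiling once") // runs a single time, on the first eval
  (x: Int) => x + 1
})
// inc.eval(1) prints "compiling once" and returns 2; inc.eval(2) just returns 3.
```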
@@ -152,7 +156,8 @@ case class FramelessUdf[T, R](
val internalTpe = CodeGenerator.boxedType(rencoder.jvmRepr)
val internalTerm = ctx.addMutableState(internalTpe, ctx.freshName("internal"))
val internalNullTerm = ctx.addMutableState("boolean", ctx.freshName("internalNull"))
-    val internalExpr = LambdaVariable(internalTerm, internalNullTerm, rencoder.jvmRepr)
+    // CTw - can't inject the term, may have to duplicate old code for parity
+    val internalExpr = Spark2_4_LambdaVariable(internalTerm, internalNullTerm, rencoder.jvmRepr, true)

val resultEval = rencoder.toCatalyst(internalExpr).genCode(ctx)

@@ -171,6 +176,39 @@
}
}

+case class Spark2_4_LambdaVariable(
+    value: String,
+    isNull: String,
+    dataType: DataType,
+    nullable: Boolean = true) extends LeafExpression with NonSQLExpression {
+
+  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)
+
+  // Interpreted execution of `LambdaVariable` always gets the 0-index element from the input row.
+  override def eval(input: InternalRow): Any = {
+    assert(input.numFields == 1,
+      "The input row of interpreted LambdaVariable should have only 1 field.")
+    if (nullable && input.isNullAt(0)) {
+      null
+    } else {
+      accessor(input, 0)
+    }
+  }
+
+  override def genCode(ctx: CodegenContext): ExprCode = {
+    val isNullValue = if (nullable) {
+      JavaCode.isNullVariable(isNull)
+    } else {
+      FalseLiteral
+    }
+    ExprCode(value = JavaCode.variable(value, dataType), isNull = isNullValue)
+  }
+
+  // This won't be called as `genCode` is overridden; it exists only to make
+  // `LambdaVariable` non-abstract.
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev
+}

object FramelessUdf {
// Spark needs case class with `children` field to mutate it
def apply[T, R](
19 changes: 15 additions & 4 deletions dataset/src/test/scala/frameless/SchemaTests.scala
@@ -2,18 +2,29 @@ package frameless

import frameless.functions.aggregate._
import frameless.functions._
+import org.apache.spark.sql.types.StructType
import org.scalacheck.Prop
import org.scalacheck.Prop._
import org.scalatest.matchers.should.Matchers

class SchemaTests extends TypedDatasetSuite with Matchers {

-  def prop[A](dataset: TypedDataset[A]): Prop = {
+  def structToNonNullable(struct: StructType): StructType = {
+    StructType(struct.fields.map(f => f.copy(nullable = false)))
+  }
+
+  def prop[A](dataset: TypedDataset[A], ignoreNullable: Boolean = false): Prop = {
val schema = dataset.dataset.schema

Prop.all(
-      dataset.schema ?= schema,
-      TypedExpressionEncoder.targetStructType(dataset.encoder) ?= schema
+      if (!ignoreNullable)
+        dataset.schema ?= schema
+      else
+        structToNonNullable(dataset.schema) ?= structToNonNullable(schema),
+      if (!ignoreNullable)
+        TypedExpressionEncoder.targetStructType(dataset.encoder) ?= schema
+      else
+        structToNonNullable(TypedExpressionEncoder.targetStructType(dataset.encoder)) ?= structToNonNullable(schema)
)
}

@@ -24,7 +35,7 @@

val df = df0.groupBy(_a).agg(sum(_b))

-    check(prop(df))
+    check(prop(df, true))
}

test("schema of select(lit(1L))") {