Skip to content

Commit

Permalink
Close #257 - Support refined types
Browse files Browse the repository at this point in the history
  • Loading branch information
cchantep committed Oct 13, 2021
1 parent b28cf3c commit f32f410
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 10 deletions.
28 changes: 21 additions & 7 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,21 @@ ThisBuild / scalaVersion := (ThisBuild / crossScalaVersions).value.last
ThisBuild / mimaFailOnNoPrevious := false

lazy val root = Project("frameless", file("." + "frameless")).in(file("."))
.aggregate(core, cats, dataset, ml, docs)
.aggregate(core, cats, dataset, refined, ml, docs)
.settings(framelessSettings: _*)
.settings(noPublishSettings: _*)
.settings(noPublishSettings)
.settings(mimaPreviousArtifacts := Set())

lazy val core = project
.settings(name := "frameless-core")
.settings(framelessSettings: _*)
.settings(publishSettings: _*)
.settings(framelessSettings)
.settings(publishSettings)


lazy val cats = project
.settings(name := "frameless-cats")
.settings(framelessSettings: _*)
.settings(publishSettings: _*)
.settings(framelessSettings)
.settings(publishSettings)
.settings(
addCompilerPlugin("org.typelevel" % "kind-projector" % "0.13.2" cross CrossVersion.full),
scalacOptions += "-Ypartial-unification"
Expand All @@ -57,7 +57,7 @@ lazy val dataset = project
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % sparkVersion % Provided,
"org.apache.spark" %% "spark-sql" % sparkVersion % Provided,
"net.ceedubs" %% "irrec-regex-gen" % irrecVersion % Test
"net.ceedubs" %% "irrec-regex-gen" % irrecVersion % Test,
),
mimaBinaryIssueFilters ++= {
import com.typesafe.tools.mima.core._
Expand All @@ -78,6 +78,20 @@ lazy val dataset = project
))
.dependsOn(core % "test->test;compile->compile")

// Sub-project published as `frameless-refined`.
// Spark artifacts are declared in the `Provided` scope, and both the compile
// and test configurations of `dataset` are depended upon.
lazy val refined = project
  .settings(name := "frameless-refined")
  .settings(framelessSettings)
  .settings(framelessTypedDatasetREPL)
  .settings(publishSettings)
  .settings(
    libraryDependencies ++= Seq(
      "org.apache.spark" %% "spark-core" % sparkVersion % Provided,
      "org.apache.spark" %% "spark-sql" % sparkVersion % Provided,
      "eu.timepit" %% "refined" % "0.9.27"
    )
  )
  .dependsOn(dataset % "test->test;compile->compile")

lazy val ml = project
.settings(name := "frameless-ml")
.settings(framelessSettings: _*)
Expand Down
10 changes: 8 additions & 2 deletions dataset/src/main/scala/frameless/RecordEncoder.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package frameless

import org.apache.spark.sql.FramelessInternals

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.{
Invoke, NewInstance, UnwrapOption, WrapOption
Expand All @@ -22,6 +23,9 @@ case class RecordEncoderField(

/**
 * Carries the list of [[RecordEncoderField]] associated with the
 * record representation `T`.
 */
trait RecordEncoderFields[T <: HList] extends Serializable {
  /** The field descriptors held by this instance. */
  def value: List[RecordEncoderField]

  /** Renders as `RecordEncoderFields[f1, f2, ...]`. */
  override def toString: String =
    value.mkString("RecordEncoderFields[", ", ", "]")
}

object RecordEncoderFields {
Expand Down Expand Up @@ -164,11 +168,13 @@ class RecordEncoder[F, G <: HList, H <: HList]

def fromCatalyst(path: Expression): Expression = {
val exprs = fields.value.value.map { field =>
field.encoder.fromCatalyst( GetStructField(path, field.ordinal, Some(field.name)) )
field.encoder.fromCatalyst(
GetStructField(path, field.ordinal, Some(field.name)))
}

val newArgs = newInstanceExprs.value.from(exprs)
val newExpr = NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)
val newExpr = NewInstance(
classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)

val nullExpr = Literal.create(null, jvmRepr)

Expand Down
4 changes: 4 additions & 0 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ object TypedEncoder {

/**
 * Builds the expression reading the value back from Catalyst:
 * `toString` is invoked on `path`, typed as this encoder's `jvmRepr`.
 */
def fromCatalyst(path: Expression): Expression =
  Invoke(path, "toString", jvmRepr)

// Declared as a `def` (not a `val`), like the other `toString` overrides:
// no extra instance field is created, and the name is available even
// before the instance initializer has run.
override def toString: String = "stringEncoder"
}

implicit val booleanEncoder: TypedEncoder[Boolean] = new TypedEncoder[Boolean] {
Expand All @@ -72,6 +74,8 @@ object TypedEncoder {

// No conversion in either direction: the `path` expression is returned as is.
def toCatalyst(path: Expression): Expression = path
def fromCatalyst(path: Expression): Expression = path

/** Fixed, human-readable name of this encoder. */
override def toString: String = "intEncoder"
}

implicit val longEncoder: TypedEncoder[Long] = new TypedEncoder[Long] {
Expand Down
1 change: 0 additions & 1 deletion dataset/src/test/resources/log4j.properties
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,3 @@ log4j.logger.Remoting=ERROR

# To debug expressions:
#log4j.logger.org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator=DEBUG

Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package frameless.refined

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.{
Invoke, NewInstance, UnwrapOption, WrapOption
}
import org.apache.spark.sql.types._

import eu.timepit.refined.api.RefType

import frameless.{ TypedEncoder, RecordFieldEncoder }

/**
 * Record field encoders for refined types `F[T, R]`,
 * either bare or wrapped in an `Option`.
 */
private[refined] trait RefinedFieldEncoders {
  /**
   * Field encoder for an optional refined value.
   *
   * The Catalyst representation is the one of the base type `T`,
   * and the field is always declared nullable.
   *
   * @tparam F the refinement wrapper (a `RefType` instance is required)
   * @tparam T the base type being refined (e.g. `String`)
   * @tparam R the refinement predicate
   */
  implicit def optionRefined[F[_, _], T, R](
    implicit
    i0: RefType[F],
    i1: TypedEncoder[T],
    i2: ClassTag[F[T, R]]
  ): RecordFieldEncoder[Option[F[T, R]]] =
    RecordFieldEncoder[Option[F[T, R]]](new TypedEncoder[Option[F[T, R]]] {
      // Object type of the refined wrapper itself (inside the `Option`)
      val innerJvmRepr = ObjectType(i2.runtimeClass)

      def nullable = true

      // `Refined` is a Value class: https://github.com/fthomas/refined/blob/master/modules/core/shared/src/main/scala-3.0-/eu/timepit/refined/api/Refined.scala#L8
      def jvmRepr = ObjectType(classOf[Option[F[T, R]]])

      def catalystRepr: DataType = i1.catalystRepr

      // Decode the base value, instantiate the wrapper class around it,
      // then wrap the result as an `Option`.
      def fromCatalyst(path: Expression): Expression =
        WrapOption(
          NewInstance(
            i2.runtimeClass, Seq(i1.fromCatalyst(path)), innerJvmRepr),
          innerJvmRepr)

      // Unwrap the `Option`, read the `value` member of the wrapper,
      // then encode it with the encoder of the base type.
      @inline def toCatalyst(path: Expression): Expression = {
        val unwrapped = UnwrapOption(innerJvmRepr, path)

        i1.toCatalyst(Invoke(unwrapped, "value", i1.jvmRepr, Nil))
      }

      override def toString = s"optionRefined[${i2.runtimeClass.getName}]"
    })

  /**
   * Field encoder for a bare refined value.
   *
   * Every member delegates to the encoder of the base type `T`.
   *
   * @tparam F the refinement wrapper (a `RefType` instance is required)
   * @tparam T the base type being refined (e.g. `String`)
   * @tparam R the refinement predicate
   */
  implicit def refined[F[_, _], T, R](
    implicit
    i0: RefType[F],
    i1: TypedEncoder[T],
    i2: ClassTag[F[T, R]]
  ): RecordFieldEncoder[F[T, R]] =
    RecordFieldEncoder[F[T, R]](new TypedEncoder[F[T, R]] {
      def nullable = i1.nullable

      // `Refined` is a Value class: https://github.com/fthomas/refined/blob/master/modules/core/shared/src/main/scala-3.0-/eu/timepit/refined/api/Refined.scala#L8
      def jvmRepr = i1.jvmRepr

      def catalystRepr: DataType = i1.catalystRepr

      def fromCatalyst(path: Expression): Expression = i1.fromCatalyst(path)

      @inline def toCatalyst(path: Expression): Expression = i1.toCatalyst(path)

      override def toString = s"refined[${i2.runtimeClass.getName}]"
    })
}

33 changes: 33 additions & 0 deletions refined/src/main/scala/frameless/refined/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package frameless

import scala.reflect.ClassTag

import eu.timepit.refined.api.{ RefType, Validate }

/** Implicit instances making refined types usable with frameless encoders. */
package object refined extends RefinedFieldEncoders {
  /**
   * Injection between a refined type `F[T, R]` and its base type `T`.
   *
   * Going to `T` unwraps the value; coming back from `T` re-checks
   * the predicate `R` and throws an `IllegalArgumentException` when
   * the value does not satisfy it.
   */
  implicit def refinedInjection[F[_, _], T, R](
    implicit
    refType: RefType[F],
    validate: Validate[T, R]
  ): Injection[F[T, R], T] = Injection(
    refType.unwrap,
    { value =>
      refType.refine[R](value).fold[F[T, R]](
        errMsg => throw new IllegalArgumentException(
          s"Value $value does not satisfy refinement predicate: $errMsg"),
        res => res)
    })

  /**
   * Encoder for a refined type, obtained from the encoder of the base
   * type `T` through [[refinedInjection]].
   */
  implicit def refinedEncoder[F[_, _], T, R](
    implicit
    i0: RefType[F],
    i1: Validate[T, R],
    i2: TypedEncoder[T],
    i3: ClassTag[F[T, R]]
  ): TypedEncoder[F[T, R]] = {
    val injection = refinedInjection[F, T, R]

    TypedEncoder.usingInjection(i3, injection, i2)
  }
}

121 changes: 121 additions & 0 deletions refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package frameless

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{
IntegerType, ObjectType, StringType, StructField, StructType
}

import org.scalatest.matchers.should.Matchers

/**
 * Exercises the encoders for refined types: as a bare value,
 * as a case class field, and as an optional case class field.
 */
class RefinedFieldEncoderTests extends TypedDatasetSuite with Matchers {
  test("Encode a bare refined type") {
    import eu.timepit.refined.auto._
    import eu.timepit.refined.types.string.NonEmptyString

    // Only `refinedEncoder` is brought into scope for this resolution
    val encoder: TypedEncoder[NonEmptyString] = {
      import frameless.refined.refinedEncoder
      TypedEncoder[NonEmptyString]
    }

    val ss = session
    import ss.implicits._

    encoder.catalystRepr shouldBe StringType

    val nes: NonEmptyString = "Non Empty String"

    // Dataset built from the raw string, then read back as refined values
    val rawFrame = sc.parallelize(Seq(nes.value)).toDF()
    val unsafeDs = TypedDataset.createUnsafe(rawFrame)(encoder)

    unsafeDs.collect.run() shouldBe Seq(nes)
  }

  test("Encode case class with a refined field") {
    import RefinedTypesTests._

    // JVM representation
    encoderA.jvmRepr shouldBe ObjectType(classOf[A])

    // Catalyst representation
    val expectedAStructType = StructType(Seq(
      StructField("a", IntegerType, nullable = false),
      StructField("s", StringType, nullable = false)))

    encoderA.catalystRepr shouldBe expectedAStructType

    val expected = Seq(as)

    // Unsafe creation, from rows holding the raw values
    val unsafeDs: TypedDataset[A] = {
      val rows = sc.parallelize(Seq(Row(as.a, as.s.toString)))

      TypedDataset.createUnsafe(
        session.createDataFrame(rows, expectedAStructType))(encoderA)
    }

    unsafeDs.collect.run() shouldBe expected

    // Safe creation, from the typed values
    TypedDataset.create(expected).collect.run() shouldBe expected
  }

  test("Encode case class with a refined optional field") {
    import RefinedTypesTests._

    // JVM representation
    encoderB.jvmRepr shouldBe ObjectType(classOf[B])

    // Catalyst representation
    val expectedBStructType = StructType(Seq(
      StructField("a", IntegerType, nullable = false),
      StructField("s", StringType, nullable = true)))

    encoderB.catalystRepr shouldBe expectedBStructType

    val expected = Seq(bs, B(2, None))

    // Unsafe creation: one row with a value, one row with a null string
    val unsafeDs: TypedDataset[B] = {
      val rows = sc.parallelize(Seq(
        Row(bs.a, bs.s.mkString),
        Row(2, null.asInstanceOf[String])))

      TypedDataset.createUnsafe(
        session.createDataFrame(rows, expectedBStructType))(encoderB)
    }

    unsafeDs.collect.run() shouldBe expected

    // Safe creation, from the typed values
    TypedDataset.create(expected).collect.run() shouldBe expected
  }
}

/** Fixtures (case classes, sample values, encoders) used by the refined encoder tests. */
object RefinedTypesTests {
  import eu.timepit.refined.auto._
  import eu.timepit.refined.types.string.NonEmptyString

  /** Record with a mandatory refined field. */
  case class A(a: Int, s: NonEmptyString)

  /** Record with an optional refined field. */
  case class B(a: Int, s: Option[NonEmptyString])

  val nes: NonEmptyString = "Non Empty String"

  val as: A = A(-42, nes)
  val bs: B = B(-42, Some(nes))

  // Implicit instances for refined types, needed by the derivations below
  import frameless.refined._

  implicit val encoderA: TypedEncoder[A] = TypedEncoder.usingDerivation

  implicit val encoderB: TypedEncoder[B] = TypedEncoder.usingDerivation
}

0 comments on commit f32f410

Please sign in to comment.