diff --git a/build.sbt b/build.sbt
index 30f2f6b1..5bfd5534 100644
--- a/build.sbt
+++ b/build.sbt
@@ -68,6 +68,8 @@ lazy val dataset = project
       // TODO: Remove have version bump
       Seq(
+        imt("frameless.TypedEncoder.mapEncoder"),
+        imt("frameless.TypedEncoder.arrayEncoder"),
         imt("frameless.RecordEncoderFields.deriveRecordCons"),
         imt("frameless.RecordEncoderFields.deriveRecordLast"),
         mc("frameless.functions.FramelessLit"),
diff --git a/dataset/src/main/scala/frameless/RecordEncoder.scala b/dataset/src/main/scala/frameless/RecordEncoder.scala
index b51a51eb..2670a3a8 100644
--- a/dataset/src/main/scala/frameless/RecordEncoder.scala
+++ b/dataset/src/main/scala/frameless/RecordEncoder.scala
@@ -177,7 +177,11 @@ class RecordEncoder[F, G <: HList, H <: HList]
 }
 
 final class RecordFieldEncoder[T](
-  val encoder: TypedEncoder[T]) extends Serializable
+  val encoder: TypedEncoder[T],
+  private[frameless] val jvmRepr: DataType,
+  private[frameless] val fromCatalyst: Expression => Expression,
+  private[frameless] val toCatalyst: Expression => Expression
+) extends Serializable
 
 object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
@@ -197,31 +201,51 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
       i4: IsHCons.Aux[KS, K, HNil],
       i5: TypedEncoder[V],
       i6: ClassTag[F]
-    ): RecordFieldEncoder[Option[F]] = RecordFieldEncoder[Option[F]](new TypedEncoder[Option[F]] {
-      val nullable = true
-
-      val jvmRepr = ObjectType(classOf[Option[F]])
+    ): RecordFieldEncoder[Option[F]] = {
+      val fieldName = i4.head(i3()).name
+      val innerJvmRepr = ObjectType(i6.runtimeClass)
 
-      @inline def catalystRepr: DataType = i5.catalystRepr
+      val catalyst: Expression => Expression = { path =>
+        val value = UnwrapOption(innerJvmRepr, path)
+        val javaValue = Invoke(value, fieldName, i5.jvmRepr, Nil)
 
-      val innerJvmRepr = ObjectType(i6.runtimeClass)
+        i5.toCatalyst(javaValue)
+      }
 
-      def fromCatalyst(path: Expression): Expression = {
+      val fromCatalyst: Expression => Expression = { path =>
         val javaValue = i5.fromCatalyst(path)
         val value = NewInstance(i6.runtimeClass, Seq(javaValue), innerJvmRepr)
 
         WrapOption(value, innerJvmRepr)
       }
 
-      @inline def toCatalyst(path: Expression): Expression = {
-        val value = UnwrapOption(innerJvmRepr, path)
+      val jvmr = ObjectType(classOf[Option[F]])
 
-        val fieldName = i4.head(i3()).name
-        val javaValue = Invoke(value, fieldName, i5.jvmRepr, Nil)
+      new RecordFieldEncoder[Option[F]](
+        encoder = new TypedEncoder[Option[F]] {
+          val nullable = true
 
-        i5.toCatalyst(javaValue)
-      }
-    })
+          val jvmRepr = jvmr
+
+          @inline def catalystRepr: DataType = i5.catalystRepr
+
+          def fromCatalyst(path: Expression): Expression = {
+            val javaValue = i5.fromCatalyst(path)
+            val value = NewInstance(
+              i6.runtimeClass, Seq(javaValue), innerJvmRepr)
+
+            WrapOption(value, innerJvmRepr)
+          }
+
+          def toCatalyst(path: Expression): Expression = catalyst(path)
+
+          override def toString: String = s"RecordFieldEncoder.optionValueClass[${i6.runtimeClass.getName}]('${fieldName}', $i5)"
+        },
+        jvmRepr = jvmr,
+        fromCatalyst = fromCatalyst,
+        toCatalyst = catalyst
+      )
+    }
 
   /**
    * @tparam F the value class
@@ -229,28 +253,50 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
    * @tparam H the single field of the value class (with guarantee it's not a `Unit` value)
    * @tparam V the inner value type
    */
-  implicit def valueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_, HNil], V]
+  implicit def valueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]]
     (implicit
       i0: LabelledGeneric.Aux[F, G],
       i1: DropUnitValues.Aux[G, H],
-      i2: IsHCons.Aux[H, _ <: FieldType[_, V], HNil],
-      i3: TypedEncoder[V],
-      i4: ClassTag[F]
-    ): RecordFieldEncoder[F] = RecordFieldEncoder[F](new TypedEncoder[F] {
-      def nullable = i3.nullable
-
-      def jvmRepr = i3.jvmRepr
-
-      def catalystRepr: DataType = i3.catalystRepr
-
-      def fromCatalyst(path: Expression): Expression =
-        i3.fromCatalyst(path)
-
-      @inline def toCatalyst(path: Expression): Expression =
-        i3.toCatalyst(path)
-    })
+      i2: IsHCons.Aux[H, _ <: FieldType[K, V], HNil],
+      i3: Keys.Aux[H, KS],
+      i4: IsHCons.Aux[KS, K, HNil],
+      i5: TypedEncoder[V],
+      i6: ClassTag[F]
+    ): RecordFieldEncoder[F] = {
+      val cls = i6.runtimeClass
+      val jvmr = i5.jvmRepr
+      val fieldName = i4.head(i3()).name
+
+      new RecordFieldEncoder[F](
+        encoder = new TypedEncoder[F] {
+          def nullable = i5.nullable
+
+          def jvmRepr = jvmr
+
+          def catalystRepr: DataType = i5.catalystRepr
+
+          def fromCatalyst(path: Expression): Expression =
+            i5.fromCatalyst(path)
+
+          @inline def toCatalyst(path: Expression): Expression =
+            i5.toCatalyst(path)
+
+          override def toString: String = s"RecordFieldEncoder.valueClass[${cls.getName}]('${fieldName}', ${i5})"
+        },
+        jvmRepr = FramelessInternals.objectTypeFor[F],
+        fromCatalyst = { expr: Expression =>
+          NewInstance(
+            i6.runtimeClass,
+            i5.fromCatalyst(expr) :: Nil,
+            ObjectType(i6.runtimeClass))
+        },
+        toCatalyst = { expr: Expression =>
+          i5.toCatalyst(Invoke(expr, fieldName, jvmr))
+        }
+      )
+    }
 }
 
 private[frameless] sealed trait RecordFieldEncoderLowPriority {
-  implicit def apply[T](implicit e: TypedEncoder[T]): RecordFieldEncoder[T] = new RecordFieldEncoder[T](e)
+  implicit def apply[T](implicit e: TypedEncoder[T]): RecordFieldEncoder[T] = new RecordFieldEncoder[T](e, e.jvmRepr, e.fromCatalyst, e.toCatalyst)
 }
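[Review note — illustrative, not part of the patch] RecordFieldEncoder now carries its own jvmRepr/fromCatalyst/toCatalyst next to the wrapped TypedEncoder, so a field backed by a value class can be stored as its underlying column type while staying a wrapper on the JVM side. A minimal sketch of the intended effect, reusing the Name/Person fixtures from the tests further down; the asserted schema is an assumption based on the valueClass derivation above:

    import frameless.TypedEncoder
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    final class Name(val value: String) extends AnyVal with Serializable

    final case class Person(name: Name, age: Int)

    object ValueClassFieldSchema extends App {
      // RecordEncoder resolves the `name` field through
      // RecordFieldEncoder.valueClass, so the value class should be flattened
      // to its underlying String column instead of a nested struct.
      val encoder = TypedEncoder[Person]

      assert(encoder.catalystRepr == StructType(Seq(
        StructField("name", StringType, nullable = false),
        StructField("age", IntegerType, nullable = false))))
    }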
diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala
index eaba6439..73842e81 100644
--- a/dataset/src/main/scala/frameless/TypedEncoder.scala
+++ b/dataset/src/main/scala/frameless/TypedEncoder.scala
@@ -1,5 +1,7 @@
 package frameless
 
+import scala.reflect.ClassTag
+
 import org.apache.spark.sql.FramelessInternals
 import org.apache.spark.sql.FramelessInternals.UserDefinedType
 import org.apache.spark.sql.catalyst.ScalaReflection
@@ -8,11 +10,10 @@ import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
+
 import shapeless._
 import shapeless.ops.hlist.IsHCons
 
-import scala.reflect.ClassTag
-
 abstract class TypedEncoder[T](implicit val classTag: ClassTag[T]) extends Serializable {
   def nullable: Boolean
@@ -54,15 +55,16 @@ object TypedEncoder {
       Invoke(path, "toString", jvmRepr)
   }
 
-  implicit val booleanEncoder: TypedEncoder[Boolean] = new TypedEncoder[Boolean] {
-    def nullable: Boolean = false
+  implicit val booleanEncoder: TypedEncoder[Boolean] =
+    new TypedEncoder[Boolean] {
+      def nullable: Boolean = false
 
-    def jvmRepr: DataType = BooleanType
-    def catalystRepr: DataType = BooleanType
+      def jvmRepr: DataType = BooleanType
+      def catalystRepr: DataType = BooleanType
 
-    def toCatalyst(path: Expression): Expression = path
-    def fromCatalyst(path: Expression): Expression = path
-  }
+      def toCatalyst(path: Expression): Expression = path
+      def fromCatalyst(path: Expression): Expression = path
+    }
 
   implicit val intEncoder: TypedEncoder[Int] = new TypedEncoder[Int] {
     def nullable: Boolean = false
@@ -96,24 +98,30 @@ object TypedEncoder {
   implicit val charEncoder: TypedEncoder[Char] = new TypedEncoder[Char] {
     // tricky because while Char is primitive type, Spark doesn't support it
-    implicit val charAsString: Injection[java.lang.Character, String] = new Injection[java.lang.Character, String] {
-      def apply(a: java.lang.Character): String = String.valueOf(a)
-      def invert(b: String): java.lang.Character = {
-        require(b.length == 1)
-        b.charAt(0)
+    implicit val charAsString: Injection[java.lang.Character, String] =
+      new Injection[java.lang.Character, String] {
+        def apply(a: java.lang.Character): String = String.valueOf(a)
+
+        def invert(b: String): java.lang.Character = {
+          require(b.length == 1)
+          b.charAt(0)
+        }
       }
-    }
 
     val underlying = usingInjection[java.lang.Character, String]
 
     def nullable: Boolean = false
 
     // this line fixes underlying encoder
-    def jvmRepr: DataType = FramelessInternals.objectTypeFor[java.lang.Character]
+    def jvmRepr: DataType =
+      FramelessInternals.objectTypeFor[java.lang.Character]
+
     def catalystRepr: DataType = StringType
 
     def toCatalyst(path: Expression): Expression = underlying.toCatalyst(path)
-    def fromCatalyst(path: Expression): Expression = underlying.fromCatalyst(path)
+
+    def fromCatalyst(path: Expression): Expression =
+      underlying.fromCatalyst(path)
   }
 
   implicit val byteEncoder: TypedEncoder[Byte] = new TypedEncoder[Byte] {
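[Review note — illustrative, not part of the patch] The Char encoder above goes through the Injection machinery: usingInjection derives a TypedEncoder[A] from any invertible A => B once B already has an encoder. The same pattern applies to user-defined types Spark has no native representation for; a sketch with a made-up Temperature ADT:

    import frameless.{Injection, TypedEncoder}

    sealed trait Temperature
    case object Hot extends Temperature
    case object Cold extends Temperature

    object TemperatureEncoding {
      // Same shape as `charAsString` above: `invert` must undo `apply`.
      implicit val temperatureAsString: Injection[Temperature, String] =
        new Injection[Temperature, String] {
          def apply(t: Temperature): String = t match {
            case Hot => "hot"
            case Cold => "cold"
          }

          def invert(s: String): Temperature = s match {
            case "hot" => Hot
            case _ => Cold
          }
        }

      // Resolves via TypedEncoder.usingInjection and the String encoder.
      val encoder: TypedEncoder[Temperature] = TypedEncoder[Temperature]
    }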
@@ -146,31 +154,39 @@ object TypedEncoder {
     def fromCatalyst(path: Expression): Expression = path
   }
 
-  implicit val bigDecimalEncoder: TypedEncoder[BigDecimal] = new TypedEncoder[BigDecimal] {
-    def nullable: Boolean = false
+  implicit val bigDecimalEncoder: TypedEncoder[BigDecimal] =
+    new TypedEncoder[BigDecimal] {
+      def nullable: Boolean = false
 
-    def jvmRepr: DataType = ScalaReflection.dataTypeFor[BigDecimal]
-    def catalystRepr: DataType = DecimalType.SYSTEM_DEFAULT
+      def jvmRepr: DataType = ScalaReflection.dataTypeFor[BigDecimal]
+      def catalystRepr: DataType = DecimalType.SYSTEM_DEFAULT
 
-    def toCatalyst(path: Expression): Expression =
-      StaticInvoke(Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", path :: Nil)
+      def toCatalyst(path: Expression): Expression =
+        StaticInvoke(
+          Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", path :: Nil)
 
-    def fromCatalyst(path: Expression): Expression =
-      Invoke(path, "toBigDecimal", jvmRepr)
-  }
+      def fromCatalyst(path: Expression): Expression =
+        Invoke(path, "toBigDecimal", jvmRepr)
 
-  implicit val javaBigDecimalEncoder: TypedEncoder[java.math.BigDecimal] = new TypedEncoder[java.math.BigDecimal] {
-    def nullable: Boolean = false
+      override def toString: String = "bigDecimalEncoder"
+    }
 
-    def jvmRepr: DataType = ScalaReflection.dataTypeFor[java.math.BigDecimal]
-    def catalystRepr: DataType = DecimalType.SYSTEM_DEFAULT
+  implicit val javaBigDecimalEncoder: TypedEncoder[java.math.BigDecimal] =
+    new TypedEncoder[java.math.BigDecimal] {
+      def nullable: Boolean = false
 
-    def toCatalyst(path: Expression): Expression =
-      StaticInvoke(Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", path :: Nil)
+      def jvmRepr: DataType = ScalaReflection.dataTypeFor[java.math.BigDecimal]
+      def catalystRepr: DataType = DecimalType.SYSTEM_DEFAULT
 
-    def fromCatalyst(path: Expression): Expression =
-      Invoke(path, "toJavaBigDecimal", jvmRepr)
-  }
+      def toCatalyst(path: Expression): Expression =
+        StaticInvoke(
+          Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", path :: Nil)
+
+      def fromCatalyst(path: Expression): Expression =
+        Invoke(path, "toJavaBigDecimal", jvmRepr)
+
+      override def toString: String = "javaBigDecimalEncoder"
+    }
 
   implicit val sqlDate: TypedEncoder[SQLDate] = new TypedEncoder[SQLDate] {
     def nullable: Boolean = false
@@ -210,29 +226,39 @@ object TypedEncoder {
     )
   }
 
-  implicit def arrayEncoder[T: ClassTag](implicit encodeT: TypedEncoder[T]): TypedEncoder[Array[T]] =
+  implicit def arrayEncoder[T: ClassTag](
+    implicit i0: Lazy[RecordFieldEncoder[T]]): TypedEncoder[Array[T]] =
     new TypedEncoder[Array[T]] {
+      private lazy val encodeT = i0.value.encoder
+
       def nullable: Boolean = false
 
-      def jvmRepr: DataType = encodeT.jvmRepr match {
+      lazy val jvmRepr: DataType = i0.value.jvmRepr match {
         case ByteType => BinaryType
         case _ => FramelessInternals.objectTypeFor[Array[T]]
       }
 
-      def catalystRepr: DataType = encodeT.jvmRepr match {
+      lazy val catalystRepr: DataType = i0.value.jvmRepr match {
         case ByteType => BinaryType
         case _ => ArrayType(encodeT.catalystRepr, encodeT.nullable)
       }
 
-      def toCatalyst(path: Expression): Expression =
-        encodeT.jvmRepr match {
-          case IntegerType | LongType | DoubleType | FloatType | ShortType | BooleanType =>
-            StaticInvoke(classOf[UnsafeArrayData], catalystRepr, "fromPrimitiveArray", path :: Nil)
+      def toCatalyst(path: Expression): Expression = {
+        val enc = i0.value
+
+        enc.jvmRepr match {
+          case IntegerType | LongType | DoubleType | FloatType |
+              ShortType | BooleanType =>
+            StaticInvoke(
+              classOf[UnsafeArrayData],
+              catalystRepr, "fromPrimitiveArray", path :: Nil)
 
           case ByteType => path
 
-          case _ => MapObjects(encodeT.toCatalyst _, path, encodeT.jvmRepr, encodeT.nullable)
+          case _ => MapObjects(
+            enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable)
         }
+      }
 
       def fromCatalyst(path: Expression): Expression =
         encodeT.jvmRepr match {
@@ -246,53 +272,76 @@ object TypedEncoder {
           case ByteType => path
 
           case _ =>
-            Invoke(MapObjects(encodeT.fromCatalyst _, path, encodeT.catalystRepr, encodeT.nullable), "array", jvmRepr)
+            Invoke(MapObjects(
+              i0.value.fromCatalyst, path,
+              encodeT.catalystRepr, encodeT.nullable), "array", jvmRepr)
         }
+
+      override def toString: String = s"arrayEncoder($jvmRepr)"
     }
 
   implicit def collectionEncoder[C[X] <: Seq[X], T]
     (implicit
-      encodeT: Lazy[TypedEncoder[T]],
-      CT: ClassTag[C[T]]
-    ): TypedEncoder[C[T]] =
-    new TypedEncoder[C[T]] {
-      def nullable: Boolean = false
+      i0: Lazy[RecordFieldEncoder[T]],
+      i1: ClassTag[C[T]]): TypedEncoder[C[T]] = new TypedEncoder[C[T]] {
+    private lazy val encodeT = i0.value.encoder
+
+    def nullable: Boolean = false
 
-      def jvmRepr: DataType = FramelessInternals.objectTypeFor[C[T]](CT)
+    def jvmRepr: DataType = FramelessInternals.objectTypeFor[C[T]](i1)
 
-      def catalystRepr: DataType = ArrayType(encodeT.value.catalystRepr, encodeT.value.nullable)
+    def catalystRepr: DataType =
+      ArrayType(encodeT.catalystRepr, encodeT.nullable)
 
-      def toCatalyst(path: Expression): Expression =
-        if (ScalaReflection.isNativeType(encodeT.value.jvmRepr))
-          NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
-        else MapObjects(encodeT.value.toCatalyst _, path, encodeT.value.jvmRepr, encodeT.value.nullable)
+    def toCatalyst(path: Expression): Expression = {
+      val enc = i0.value
 
-      def fromCatalyst(path: Expression): Expression =
-        MapObjects(
-          encodeT.value.fromCatalyst,
-          path,
-          encodeT.value.catalystRepr,
-          encodeT.value.nullable,
-          Some(CT.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly
-        )
+      if (ScalaReflection.isNativeType(enc.jvmRepr)) {
+        NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
+      } else {
+        MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable)
+      }
     }
 
+    def fromCatalyst(path: Expression): Expression =
+      MapObjects(
+        i0.value.fromCatalyst,
+        path,
+        encodeT.catalystRepr,
+        encodeT.nullable,
+        Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly
+      )
+
+    override def toString: String = s"collectionEncoder($jvmRepr)"
+  }
+
+  /**
+   * @tparam A the key type
+   * @tparam B the value type
+   * @param i0 the keys encoder
+   * @param i1 the values encoder
+   */
   implicit def mapEncoder[A: NotCatalystNullable, B]
     (implicit
-      encodeA: TypedEncoder[A],
-      encodeB: TypedEncoder[B]
+      i0: Lazy[RecordFieldEncoder[A]],
+      i1: Lazy[RecordFieldEncoder[B]],
    ): TypedEncoder[Map[A, B]] = new TypedEncoder[Map[A, B]] {
     def nullable: Boolean = false
 
     def jvmRepr: DataType = FramelessInternals.objectTypeFor[Map[A, B]]
 
-    def catalystRepr: DataType = MapType(encodeA.catalystRepr, encodeB.catalystRepr, encodeB.nullable)
+    private lazy val encodeA = i0.value.encoder
+    private lazy val encodeB = i1.value.encoder
+
+    lazy val catalystRepr: DataType = MapType(
+      encodeA.catalystRepr, encodeB.catalystRepr, encodeB.nullable)
 
     def fromCatalyst(path: Expression): Expression = {
       val keyArrayType = ArrayType(encodeA.catalystRepr, containsNull = false)
+
       val keyData = Invoke(
         MapObjects(
-          encodeA.fromCatalyst,
+          i0.value.fromCatalyst,
           Invoke(path, "keyArray", keyArrayType),
           encodeA.catalystRepr
         ),
@@ -301,9 +350,10 @@ object TypedEncoder {
       )
 
       val valueArrayType = ArrayType(encodeB.catalystRepr, encodeB.nullable)
+
       val valueData = Invoke(
         MapObjects(
-          encodeB.fromCatalyst,
+          i1.value.fromCatalyst,
           Invoke(path, "valueArray", valueArrayType),
           encodeB.catalystRepr
         ),
@@ -318,21 +368,30 @@ object TypedEncoder {
         keyData :: valueData :: Nil)
     }
 
-    def toCatalyst(path: Expression): Expression = ExternalMapToCatalyst(
-      path,
-      encodeA.jvmRepr,
-      encodeA.toCatalyst,
-      encodeA.nullable,
-      encodeB.jvmRepr,
-      encodeB.toCatalyst,
-      encodeB.nullable)
+    def toCatalyst(path: Expression): Expression = {
+      val encA = i0.value
+      val encB = i1.value
+
+      ExternalMapToCatalyst(
+        path,
+        encA.jvmRepr,
+        encA.toCatalyst,
+        false,
+        encB.jvmRepr,
+        encB.toCatalyst,
+        encodeB.nullable)
+    }
+
+    override def toString = s"mapEncoder($jvmRepr)"
   }
 
   implicit def optionEncoder[A](implicit underlying: TypedEncoder[A]): TypedEncoder[Option[A]] =
     new TypedEncoder[Option[A]] {
       def nullable: Boolean = true
 
-      def jvmRepr: DataType = FramelessInternals.objectTypeFor[Option[A]](classTag)
+      def jvmRepr: DataType =
+        FramelessInternals.objectTypeFor[Option[A]](classTag)
+
       def catalystRepr: DataType = underlying.catalystRepr
 
       def toCatalyst(path: Expression): Expression = {
@@ -340,41 +399,50 @@ object TypedEncoder {
         underlying.jvmRepr match {
           case IntegerType =>
             Invoke(
-              UnwrapOption(ScalaReflection.dataTypeFor[java.lang.Integer], path),
+              UnwrapOption(
+                ScalaReflection.dataTypeFor[java.lang.Integer], path),
               "intValue",
               IntegerType)
+
           case LongType =>
             Invoke(
               UnwrapOption(ScalaReflection.dataTypeFor[java.lang.Long], path),
               "longValue",
               LongType)
+
           case DoubleType =>
             Invoke(
               UnwrapOption(ScalaReflection.dataTypeFor[java.lang.Double], path),
               "doubleValue",
               DoubleType)
+
           case FloatType =>
             Invoke(
               UnwrapOption(ScalaReflection.dataTypeFor[java.lang.Float], path),
               "floatValue",
               FloatType)
+
           case ShortType =>
             Invoke(
               UnwrapOption(ScalaReflection.dataTypeFor[java.lang.Short], path),
               "shortValue",
               ShortType)
+
           case ByteType =>
             Invoke(
               UnwrapOption(ScalaReflection.dataTypeFor[java.lang.Byte], path),
               "byteValue",
               ByteType)
+
           case BooleanType =>
             Invoke(
-              UnwrapOption(ScalaReflection.dataTypeFor[java.lang.Boolean], path),
+              UnwrapOption(
+                ScalaReflection.dataTypeFor[java.lang.Boolean], path),
               "booleanValue",
               BooleanType)
 
-          case _ => underlying.toCatalyst(UnwrapOption(underlying.jvmRepr, path))
+          case _ => underlying.toCatalyst(
+            UnwrapOption(underlying.jvmRepr, path))
         }
       }
@@ -395,10 +463,9 @@ object TypedEncoder {
       Invoke(Literal.fromObject(inj), "invert", jvmRepr, Seq(bexpr))
     }
 
-    def toCatalyst(path: Expression): Expression = {
-      val invoke = Invoke(Literal.fromObject(inj), "apply", trb.jvmRepr, Seq(path))
-      trb.toCatalyst(invoke)
-    }
+    def toCatalyst(path: Expression): Expression =
+      trb.toCatalyst(Invoke(
+        Literal.fromObject(inj), "apply", trb.jvmRepr, Seq(path)))
   }
 
   /** Encodes things as records if there is no Injection defined */
@@ -415,7 +482,8 @@ object TypedEncoder {
   /** Encodes things using a Spark SQL's User Defined Type (UDT) if there is one defined in implicit */
   implicit def usingUserDefinedType[A >: Null : UserDefinedType : ClassTag]: TypedEncoder[A] = {
     val udt = implicitly[UserDefinedType[A]]
-    val udtInstance = NewInstance(udt.getClass, Nil, dataType = ObjectType(udt.getClass))
+    val udtInstance = NewInstance(
+      udt.getClass, Nil, dataType = ObjectType(udt.getClass))
 
     new TypedEncoder[A] {
       def nullable: Boolean = false
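[Review note — illustrative, not part of the patch] The substantive change in this file is that arrayEncoder, collectionEncoder and mapEncoder now resolve their element, key and value encoders as Lazy[RecordFieldEncoder[...]] rather than TypedEncoder[...] (hence the new MiMa filters for mapEncoder/arrayEncoder in build.sbt), so value classes inside collections and maps are unwrapped to their underlying type. A sketch of what this enables, mirroring the Subject fixture from the tests below:

    import frameless.TypedEncoder
    import org.apache.spark.sql.types.{ArrayType, StringType}

    final class Subject(val name: String) extends AnyVal with Serializable

    object ValueClassCollectionSchema extends App {
      // The Seq element goes through RecordFieldEncoder.valueClass, so the
      // expected Catalyst schema is a plain non-nullable string array.
      val encoder = TypedEncoder[Seq[Subject]]

      assert(encoder.catalystRepr == ArrayType(StringType, containsNull = false))
    }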
diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala
index 8ffd5665..15ff19a2 100644
--- a/dataset/src/main/scala/frameless/functions/package.scala
+++ b/dataset/src/main/scala/frameless/functions/package.scala
@@ -5,7 +5,7 @@ import scala.reflect.ClassTag
 import shapeless._
 import shapeless.labelled.FieldType
 import shapeless.ops.hlist.IsHCons
-import shapeless.ops.record.Values
+import shapeless.ops.record.{Keys, Values}
 
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.Literal
@@ -57,31 +57,33 @@ package object functions extends Udf with UnaryFunctions {
    * @tparam A the value class
    * @tparam T the row type
    */
-  def litValue[A : IsValueClass, T, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], V, VS <: HList](value: A)(
+  def litValue[A : IsValueClass, T, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil], VS <: HList](value: A)(
     implicit
       i0: LabelledGeneric.Aux[A, G],
       i1: DropUnitValues.Aux[G, H],
-      i2: IsHCons.Aux[H, _ <: FieldType[_, V], HNil],
-      i3: Values.Aux[H, VS],
-      i4: IsHCons.Aux[VS, V, HNil],
-      i5: TypedEncoder[V],
-      i6: ClassTag[A]
+      i2: IsHCons.Aux[H, _ <: FieldType[K, V], HNil],
+      i3: Keys.Aux[H, KS],
+      i4: Values.Aux[H, VS],
+      i5: IsHCons.Aux[KS, K, HNil],
+      i6: IsHCons.Aux[VS, V, HNil],
+      i7: TypedEncoder[V],
+      i8: ClassTag[A]
   ): TypedColumn[T, A] = {
     val expr = {
       val field: H = i1(i0.to(value))
-      val v: V = i4.head(i3(field))
+      val v: V = i6.head(i4(field))
 
-      new Literal(v, i5.jvmRepr)
+      new Literal(v, i7.jvmRepr)
     }
 
     implicit val enc: TypedEncoder[A] =
-      RecordFieldEncoder.valueClass[A, G, H, V].encoder
+      RecordFieldEncoder.valueClass[A, G, H, K, V, KS].encoder
 
     new TypedColumn[T, A](
       Lit(
-        dataType = i5.catalystRepr,
-        nullable = i5.nullable,
-        toCatalyst = i5.toCatalyst(expr).genCode(_),
+        dataType = i7.catalystRepr,
+        nullable = i7.nullable,
+        toCatalyst = i7.toCatalyst(expr).genCode(_),
         show = value.toString
       )
     )
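[Review note — illustrative, not part of the patch] litValue picks up the extra Keys/IsHCons evidence (K, KS) required by the updated RecordFieldEncoder.valueClass; call sites are unchanged. A hedged usage sketch (the anonymize helper is hypothetical, and a type ascription may be needed if inference does not pin T at the call site):

    import frameless.{TypedDataset, functions}

    final class Name(val value: String) extends AnyVal with Serializable

    final case class Person(name: Name, age: Int)

    object LitValueUsage {
      // Replaces the `name` column with the same value-class literal for
      // every row; the literal is encoded as its underlying String.
      def anonymize(ds: TypedDataset[Person]): TypedDataset[Person] =
        ds.withColumnReplaced('name, functions.litValue(new Name("xxx")))
    }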
diff --git a/dataset/src/test/resources/log4j.properties b/dataset/src/test/resources/log4j.properties
index d3d35c98..9bd87dc0 100644
--- a/dataset/src/test/resources/log4j.properties
+++ b/dataset/src/test/resources/log4j.properties
@@ -147,4 +147,3 @@ log4j.logger.Remoting=ERROR
 
 # To debug expressions:
 #log4j.logger.org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator=DEBUG
-
diff --git a/dataset/src/test/scala/frameless/CreateTests.scala b/dataset/src/test/scala/frameless/CreateTests.scala
index 9282aa8b..4d9b5547 100644
--- a/dataset/src/test/scala/frameless/CreateTests.scala
+++ b/dataset/src/test/scala/frameless/CreateTests.scala
@@ -97,8 +97,8 @@ class CreateTests extends TypedDatasetSuite with Matchers {
     check(prop[String])
   }
 
-  test("map fields (scala.Predef.Map / scala.collection.immutable.Map)") {
-    def prop[A: Arbitrary: TypedEncoder, B: Arbitrary: TypedEncoder] = forAll {
+  test("Map fields (scala.Predef.Map / scala.collection.immutable.Map)") {
+    def prop[A: Arbitrary: NotCatalystNullable: TypedEncoder, B: Arbitrary: NotCatalystNullable: TypedEncoder] = forAll {
       (d1: Map[A, B], d2: Map[B, A], d3: Map[A, Option[B]],
        d4: Map[A, X1[B]], d5: Map[X1[A], B], d6: Map[X1[A], X1[B]]) =>
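[Review note — illustrative, not part of the patch] The Map property test now also requires NotCatalystNullable for both type parameters, matching mapEncoder's constraint that key columns cannot be nullable. A sketch of what that constraint rules out — the illTyped claim is an assumption based on the conflicting NotCatalystNullable instances for Option:

    import frameless.TypedEncoder
    import shapeless.test.illTyped

    object MapKeyNullability {
      // Option keys would need a nullable key column, which Catalyst's
      // MapType does not allow, so no encoder should be derivable:
      illTyped("TypedEncoder[Map[Option[Int], String]]")

      // Option values are fine; MapType tracks valueContainsNull:
      val ok = TypedEncoder[Map[Int, Option[String]]]
    }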
diff --git a/dataset/src/test/scala/frameless/IsValueClassTests.scala b/dataset/src/test/scala/frameless/IsValueClassTests.scala
index 379da451..b2d63b1d 100644
--- a/dataset/src/test/scala/frameless/IsValueClassTests.scala
+++ b/dataset/src/test/scala/frameless/IsValueClassTests.scala
@@ -22,6 +22,7 @@ final class IsValueClassTests extends AnyFunSuite with Matchers {
     illTyped("implicitly[IsValueClass[Byte]]")
     illTyped("implicitly[IsValueClass[Unit]]")
     illTyped("implicitly[IsValueClass[Boolean]]")
+    illTyped("implicitly[IsValueClass[BigDecimal]]")
   }
 
   test("Value class evidence") {
diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala
index 2a75b42e..fd925bc6 100644
--- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala
+++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala
@@ -2,7 +2,16 @@ package frameless
 
 import org.apache.spark.sql.{Row, functions => F}
 import org.apache.spark.sql.types.{
-  IntegerType, LongType, ObjectType, StringType, StructField, StructType
+  ArrayType,
+  BinaryType,
+  DecimalType,
+  IntegerType,
+  LongType,
+  MapType,
+  ObjectType,
+  StringType,
+  StructField,
+  StructType
 }
 
 import shapeless.{HList, LabelledGeneric}
@@ -10,31 +19,7 @@ import shapeless.test.illTyped
 
 import org.scalatest.matchers.should.Matchers
 
-case class UnitsOnly(a: Unit, b: Unit)
-
-case class TupleWithUnits(u0: Unit, _1: Int, u1: Unit, u2: Unit, _2: String, u3: Unit)
-
-object TupleWithUnits {
-  def apply(_1: Int, _2: String): TupleWithUnits = TupleWithUnits((), _1, (), (), _2, ())
-}
-
-case class OptionalNesting(o: Option[TupleWithUnits])
-
-object RecordEncoderTests {
-  case class A(x: Int)
-  case class B(a: Seq[A])
-  case class C(b: B)
-
-  class Name(val value: String) extends AnyVal with Serializable {
-    override def toString = value
-  }
-
-  case class Person(name: Name, age: Int)
-
-  case class User(id: Long, name: Option[Name])
-}
-
-class RecordEncoderTests extends TypedDatasetSuite with Matchers {
+final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
   test("Unable to encode products made from units only") {
     illTyped("TypedEncoder[UnitsOnly]")
   }
@@ -69,6 +54,7 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers {
     val schema = TypedEncoder[OptionalNesting].catalystRepr.asInstanceOf[StructType]
     val df = session.createDataFrame(rdd, schema)
     val ds = TypedDataset.createUnsafe(df)(TypedEncoder[OptionalNesting])
+
     ds.firstOption.run.get.o.isEmpty shouldBe true
   }
 
@@ -226,4 +212,332 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers {
     // Safely created ds
     TypedDataset.create(expected).collect.run() shouldBe expected
   }
+
+  test("Case class with simple Map") {
+    import RecordEncoderTests._
+
+    val encoder = TypedEncoder[D]
+
+    encoder.jvmRepr shouldBe ObjectType(classOf[D])
+
+    val expectedStructType = StructType(Seq(
+      StructField("m", MapType(
+        keyType = StringType,
+        valueType = IntegerType,
+        valueContainsNull = false), false)))
+
+    encoder.catalystRepr shouldBe expectedStructType
+
+    val sqlContext = session.sqlContext
+    import sqlContext.implicits._
+
+    val ds1 = TypedDataset.createUnsafe[D] {
+      val df = Seq(
+        """{"m":{"pizza":1,"sushi":2}}""",
+        """{"m":{"red":3,"blue":4}}""",
+      ).toDF
+
+      df.withColumn(
+        "jsonValue",
+        F.from_json(df.col("value"), expectedStructType)).
+        select("jsonValue.*")
+    }
+
+    val expected = Seq(
+      D(m = Map("pizza" -> 1, "sushi" -> 2)),
+      D(m = Map("red" -> 3, "blue" -> 4)))
+
+    ds1.collect.run() shouldBe expected
+
+    val m2 = Map("updated" -> 5)
+
+    val ds2 = ds1.withColumnReplaced('m, functions.lit(m2))
+
+    ds2.collect.run() shouldBe expected.map(_.copy(m = m2))
+  }
+
+  test("Case class with Map & Value class") {
+    import RecordEncoderTests._
+
+    val encoder = TypedEncoder[Student]
+
+    encoder.jvmRepr shouldBe ObjectType(classOf[Student])
+
+    val expectedStudentStructType = StructType(Seq(
+      StructField("name", StringType, false),
+      StructField("grades", MapType(
+        keyType = StringType,
+        valueType = DecimalType.SYSTEM_DEFAULT,
+        valueContainsNull = false), false)))
+
+    encoder.catalystRepr shouldBe expectedStudentStructType
+
+    val sqlContext = session.sqlContext
+    import sqlContext.implicits._
+
+    val ds1 = TypedDataset.createUnsafe[Student] {
+      val df = Seq(
+        """{"name":"Foo","grades":{"math":1,"physics":"23.4"}}""",
+        """{"name":"Bar","grades":{"biology":18.5,"geography":4}}""",
+      ).toDF
+
+      df.withColumn(
+        "jsonValue",
+        F.from_json(df.col("value"), expectedStudentStructType)).
+        select("jsonValue.*")
+    }
+
+    val expected = Seq(
+      Student(name = "Foo", grades = Map(
+        new Subject("math") -> new Grade(BigDecimal(1)),
+        new Subject("physics") -> new Grade(BigDecimal(23.4D)))),
+      Student(name = "Bar", grades = Map(
+        new Subject("biology") -> new Grade(BigDecimal(18.5)),
+        new Subject("geography") -> new Grade(BigDecimal(4L)))))
+
+    ds1.collect.run() shouldBe expected
+
+    val grades = Map[Subject, Grade](
+      new Subject("any") -> new Grade(BigDecimal(Long.MaxValue) + 1L))
+
+    val ds2 = ds1.withColumnReplaced('grades, functions.lit(grades))
+
+    ds2.collect.run() shouldBe Seq(
+      Student("Foo", grades), Student("Bar", grades))
+  }
+
+  test("Encode binary array") {
+    val encoder = TypedEncoder[Tuple2[String, Array[Byte]]]
+
+    encoder.jvmRepr shouldBe ObjectType(
+      classOf[Tuple2[String, Array[Byte]]])
+
+    val expectedStructType = StructType(Seq(
+      StructField("_1", StringType, false),
+      StructField("_2", BinaryType, false)))
+
+    encoder.catalystRepr shouldBe expectedStructType
+
+    val ds1: TypedDataset[(String, Array[Byte])] = {
+      val rdd = sc.parallelize(Seq(
+        Row.fromTuple("Foo" -> Array[Byte](3, 4)),
+        Row.fromTuple("Bar" -> Array[Byte](5))
+      ))
+      val df = session.createDataFrame(rdd, expectedStructType)
+
+      TypedDataset.createUnsafe(df)(encoder)
+    }
+
+    val expected = Seq("Foo" -> Seq[Byte](3, 4), "Bar" -> Seq[Byte](5))
+
+    ds1.collect.run().map {
+      case (_1, _2) => _1 -> _2.toSeq
+    } shouldBe expected
+
+    val subjects = "lorem".getBytes("UTF-8").toSeq
+
+    val ds2 = ds1.withColumnReplaced('_2, functions.lit(subjects.toArray))
+
+    ds2.collect.run().map {
+      case (_1, _2) => _1 -> _2.toSeq
+    } shouldBe expected.map(_.copy(_2 = subjects))
+  }
+
+  test("Encode simple array") {
+    val encoder = TypedEncoder[Tuple2[String, Array[Int]]]
+
+    encoder.jvmRepr shouldBe ObjectType(
+      classOf[Tuple2[String, Array[Int]]])
+
+    val expectedStructType = StructType(Seq(
+      StructField("_1", StringType, false),
+      StructField("_2", ArrayType(IntegerType, false), false)))
+
+    encoder.catalystRepr shouldBe expectedStructType
+
+    val sqlContext = session.sqlContext
+    import sqlContext.implicits._
+
+    val ds1 = TypedDataset.createUnsafe[(String, Array[Int])] {
+      val df = Seq(
+        """{"_1":"Foo", "_2":[3, 4]}""",
+        """{"_1":"Bar", "_2":[5]}""",
+      ).toDF
+
+      df.withColumn(
+        "jsonValue",
+        F.from_json(df.col("value"), expectedStructType)).
+        select("jsonValue.*")
+    }
+
+    val expected = Seq("Foo" -> Seq(3, 4), "Bar" -> Seq(5))
+
+    ds1.collect.run().map {
+      case (_1, _2) => _1 -> _2.toSeq
+    } shouldBe expected
+
+    val subjects = Seq(6, 6, 7)
+
+    val ds2 = ds1.withColumnReplaced('_2, functions.lit(subjects.toArray))
+
+    ds2.collect.run().map {
+      case (_1, _2) => _1 -> _2.toSeq
+    } shouldBe expected.map(_.copy(_2 = subjects))
+  }
+
+  test("Encode array of Value class") {
+    import RecordEncoderTests._
+
+    val encoder = TypedEncoder[Tuple2[String, Array[Subject]]]
+
+    encoder.jvmRepr shouldBe ObjectType(
+      classOf[Tuple2[String, Array[Subject]]])
+
+    val expectedStructType = StructType(Seq(
+      StructField("_1", StringType, false),
+      StructField("_2", ArrayType(StringType, false), false)))
+
+    encoder.catalystRepr shouldBe expectedStructType
+
+    val sqlContext = session.sqlContext
+    import sqlContext.implicits._
+
+    val ds1 = TypedDataset.createUnsafe[(String, Array[Subject])] {
+      val df = Seq(
+        """{"_1":"Foo", "_2":["math","physics"]}""",
+        """{"_1":"Bar", "_2":["biology","geography"]}""",
+      ).toDF
+
+      df.withColumn(
+        "jsonValue",
+        F.from_json(df.col("value"), expectedStructType)).
+        select("jsonValue.*")
+    }
+
+    val expected = Seq(
+      "Foo" -> Seq(new Subject("math"), new Subject("physics")),
+      "Bar" -> Seq(new Subject("biology"), new Subject("geography")))
+
+    ds1.collect.run().map {
+      case (_1, _2) => _1 -> _2.toSeq
+    } shouldBe expected
+
+    val subjects = Seq(new Subject("lorem"), new Subject("ipsum"))
+
+    val ds2 = ds1.withColumnReplaced('_2, functions.lit(subjects.toArray))
+
+    ds2.collect.run().map {
+      case (_1, _2) => _1 -> _2.toSeq
+    } shouldBe expected.map(_.copy(_2 = subjects))
+  }
+
+  test("Encode case class with simple Seq") {
+    import RecordEncoderTests._
+
+    val encoder = TypedEncoder[B]
+
+    encoder.jvmRepr shouldBe ObjectType(classOf[B])
+
+    val expectedStructType = StructType(Seq(
+      StructField("a", ArrayType(StructType(Seq(
+        StructField("x", IntegerType, false))), false), false)))
+
+    encoder.catalystRepr shouldBe expectedStructType
+
+    val ds1: TypedDataset[B] = {
+      val rdd = sc.parallelize(Seq(
+        Row.fromTuple(Tuple1(Seq(
+          Row.fromTuple(Tuple1[Int](1)),
+          Row.fromTuple(Tuple1[Int](3))
+        ))),
+        Row.fromTuple(Tuple1(Seq(
+          Row.fromTuple(Tuple1[Int](2))
+        )))
+      ))
+      val df = session.createDataFrame(rdd, expectedStructType)
+
+      TypedDataset.createUnsafe(df)(encoder)
+    }
+
+    val expected = Seq(B(Seq(A(1), A(3))), B(Seq(A(2))))
+
+    ds1.collect.run() shouldBe expected
+
+    val as = Seq(A(5), A(6))
+
+    val ds2 = ds1.withColumnReplaced('a, functions.lit(as))
+
+    ds2.collect.run() shouldBe expected.map(_.copy(a = as))
+  }
+
+  test("Encode case class with Value class") {
+    import RecordEncoderTests._
+
+    val encoder = TypedEncoder[Tuple2[Int, Seq[Name]]]
+
+    encoder.jvmRepr shouldBe ObjectType(classOf[Tuple2[Int, Seq[Name]]])
+
+    val expectedStructType = StructType(Seq(
+      StructField("_1", IntegerType, false),
+      StructField("_2", ArrayType(StringType, false), false)))
+
+    encoder.catalystRepr shouldBe expectedStructType
+
+    val ds1 = TypedDataset.createUnsafe[(Int, Seq[Name])] {
+      val sqlContext = session.sqlContext
+      import sqlContext.implicits._
+
+      val df = Seq(
+        """{"_1":1, "_2":["foo", "bar"]}""",
+        """{"_1":2, "_2":["lorem"]}""",
+      ).toDF
+
+      df.withColumn(
+        "jsonValue",
+        F.from_json(df.col("value"), expectedStructType)).
+        select("jsonValue.*")
+    }
+
+    val expected = Seq(
+      1 -> Seq(new Name("foo"), new Name("bar")),
+      2 -> Seq(new Name("lorem")))
+
+    ds1.collect.run() shouldBe expected
+  }
+}
+
+// ---
+
+case class UnitsOnly(a: Unit, b: Unit)
+
+case class TupleWithUnits(
+  u0: Unit, _1: Int, u1: Unit, u2: Unit, _2: String, u3: Unit)
+
+object TupleWithUnits {
+  def apply(_1: Int, _2: String): TupleWithUnits =
+    TupleWithUnits((), _1, (), (), _2, ())
+}
+
+case class OptionalNesting(o: Option[TupleWithUnits])
+
+object RecordEncoderTests {
+  case class A(x: Int)
+  case class B(a: Seq[A])
+  case class C(b: B)
+
+  class Name(val value: String) extends AnyVal with Serializable {
+    override def toString = s"Name($value)"
+  }
+
+  case class Person(name: Name, age: Int)
+
+  case class User(id: Long, name: Option[Name])
+
+  case class D(m: Map[String, Int])
+
+  final class Subject(val name: String) extends AnyVal with Serializable
+
+  final class Grade(val value: BigDecimal) extends AnyVal with Serializable
+
+  case class Student(name: String, grades: Map[Subject, Grade])
 }
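[Review note — illustrative, not part of the patch] Taken together, the new tests pin down the end-to-end behaviour: value classes may now appear as collection elements, map keys and map values, always flattened to their underlying Catalyst type. A condensed sketch of the Student case, asserting the same schema as the "Case class with Map & Value class" test above:

    import frameless.TypedEncoder
    import org.apache.spark.sql.types.{DecimalType, MapType, StringType, StructField, StructType}

    final class Subject(val name: String) extends AnyVal with Serializable
    final class Grade(val value: BigDecimal) extends AnyVal with Serializable

    final case class Student(name: String, grades: Map[Subject, Grade])

    object StudentSchema extends App {
      // Both key and value are value classes; each should be flattened to
      // its underlying type (String / BigDecimal) in the map schema.
      val encoder = TypedEncoder[Student]

      assert(encoder.catalystRepr == StructType(Seq(
        StructField("name", StringType, nullable = false),
        StructField("grades", MapType(
          keyType = StringType,
          valueType = DecimalType.SYSTEM_DEFAULT,
          valueContainsNull = false), nullable = false))))
    }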