diff --git a/core/src/main/scala/frameless/CatalystOrdered.scala b/core/src/main/scala/frameless/CatalystOrdered.scala
index 4943e09f..e7360490 100644
--- a/core/src/main/scala/frameless/CatalystOrdered.scala
+++ b/core/src/main/scala/frameless/CatalystOrdered.scala
@@ -3,6 +3,7 @@ package frameless
 import scala.annotation.implicitNotFound
 import shapeless.{Generic, HList, Lazy}
 import shapeless.ops.hlist.LiftAll
+import java.time.{Duration, Instant, Period}
 
 /** Types that can be ordered/compared by Catalyst. */
 @implicitNotFound("Cannot compare columns of type ${A}.")
@@ -23,6 +24,9 @@ object CatalystOrdered {
   implicit val framelessSQLDateOrdered     : CatalystOrdered[SQLDate]      = of[SQLDate]
   implicit val framelessSQLTimestampOrdered: CatalystOrdered[SQLTimestamp] = of[SQLTimestamp]
   implicit val framelessStringOrdered      : CatalystOrdered[String]       = of[String]
+  implicit val framelessInstantOrdered     : CatalystOrdered[Instant]      = of[Instant]
+  implicit val framelessDurationOrdered    : CatalystOrdered[Duration]     = of[Duration]
+  implicit val framelessPeriodOrdered     : CatalystOrdered[Period]        = of[Period]
 
   implicit def injectionOrdered[A, B]
     (implicit
diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala
index 40a17ed4..9da7d84b 100644
--- a/dataset/src/main/scala/frameless/TypedEncoder.scala
+++ b/dataset/src/main/scala/frameless/TypedEncoder.scala
@@ -1,5 +1,6 @@
 package frameless
 
+import java.time.{Duration, Instant, Period}
 import scala.reflect.ClassTag
 
 import org.apache.spark.sql.FramelessInternals
@@ -7,7 +8,7 @@ import org.apache.spark.sql.FramelessInternals.UserDefinedType
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -230,6 +231,45 @@ object TypedEncoder {
     )
   }
 
+  /** java.time encoders; Spark uses https://github.com/apache/spark/blob/v3.2.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala for encoding / decoding. */
+  implicit val timeInstant: TypedEncoder[Instant] = new TypedEncoder[Instant] {
+    def nullable: Boolean = false
+
+    def jvmRepr: DataType = ScalaReflection.dataTypeFor[Instant]
+    def catalystRepr: DataType = TimestampType
+
+    def toCatalyst(path: Expression): Expression =
+      StaticInvoke(
+        DateTimeUtils.getClass,
+        TimestampType,
+        "instantToMicros",
+        path :: Nil,
+        returnNullable = false)
+
+    def fromCatalyst(path: Expression): Expression =
+      StaticInvoke(
+        staticObject = DateTimeUtils.getClass,
+        dataType = jvmRepr,
+        functionName = "microsToInstant",
+        arguments = path :: Nil,
+        propagateNull = true
+      )
+  }
+
+  /**
+   * DayTimeIntervalType and YearMonthIntervalType were introduced in Spark 3.2.0.
+   * We maintain Spark 3.x cross compilation and handle Duration and Period as injections to stay compatible with Spark versions < 3.2.
+   * See
+   * * https://github.com/apache/spark/blob/v3.2.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala#L1031-L1047
+   * * https://github.com/apache/spark/blob/v3.2.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala#L1075-L1087
+   */
+  // DayTimeIntervalType
+  implicit val timeDurationInjection: Injection[Duration, Long] = Injection(_.toMillis, Duration.ofMillis)
+  // YearMonthIntervalType
+  implicit val timePeriodInjection: Injection[Period, Int] = Injection(_.getDays, Period.ofDays)
+  implicit val timePeriodEncoder: TypedEncoder[Period] = TypedEncoder.usingInjection
+  implicit val timeDurationEncoder: TypedEncoder[Duration] = TypedEncoder.usingInjection
+
   implicit def arrayEncoder[T: ClassTag](
     implicit i0: Lazy[RecordFieldEncoder[T]]): TypedEncoder[Array[T]] =
       new TypedEncoder[Array[T]] {
diff --git a/dataset/src/test/scala/frameless/ColumnTests.scala b/dataset/src/test/scala/frameless/ColumnTests.scala
index 638ff7b8..3dea3e56 100644
--- a/dataset/src/test/scala/frameless/ColumnTests.scala
+++ b/dataset/src/test/scala/frameless/ColumnTests.scala
@@ -1,7 +1,6 @@
 package frameless
 
-import java.time.Instant
-
+import java.time.{Instant, Period, Duration}
 import org.scalacheck.Prop._
 import org.scalacheck.{Arbitrary, Gen, Prop}, Arbitrary.arbitrary
 import org.scalatest.matchers.should.Matchers
@@ -14,11 +13,17 @@ final class ColumnTests extends TypedDatasetSuite with Matchers {
   private implicit object OrderingImplicits {
     implicit val sqlDateOrdering: Ordering[SQLDate] = Ordering.by(_.days)
     implicit val sqlTimestmapOrdering: Ordering[SQLTimestamp] = Ordering.by(_.us)
-    implicit val arbInstant: Arbitrary[Instant] = Arbitrary(
-      Gen.chooseNum(0L, Instant.MAX.getEpochSecond)
-        .map(Instant.ofEpochSecond))
-    implicit val instantAsLongInjection: Injection[Instant, Long] =
-      Injection(_.getEpochSecond, Instant.ofEpochSecond)
+    implicit val periodOrdering: Ordering[Period] = Ordering.by(p => (p.getYears, p.getMonths, p.getDays))
+    /**
+     * DateTimeUtils.instantToMicros supports dates starting at 1970-01-01T00:00:00Z, which is Instant.EPOCH.
+     * It also overflows on Instant.MAX; to be sure it never overflows, we use Instant.MAX / 4.
+     * For implementation details, check org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros.
+     */
+
+    val genInstant = Gen.choose[Instant](Instant.EPOCH, Instant.ofEpochMilli(Instant.MAX.getEpochSecond / 4))
+    implicit val arbInstant: Arbitrary[Instant] = Arbitrary(genInstant)
+    implicit val arbDuration: Arbitrary[Duration] = Arbitrary(genInstant.map(i => Duration.ofMillis(i.toEpochMilli)))
+    implicit val arbPeriod: Arbitrary[Period] = Arbitrary(Gen.chooseNum(0, Int.MaxValue).map(l => Period.of(l, l, l)))
   }
 
   test("select('a < 'b, 'a <= 'b, 'a > 'b, 'a >= 'b)") {
@@ -49,6 +54,8 @@
     check(forAll(prop[SQLTimestamp] _))
     check(forAll(prop[String] _))
     check(forAll(prop[Instant] _))
+    check(forAll(prop[Duration] _))
+    check(forAll(prop[Period] _))
   }
 
   test("between") {
@@ -76,6 +83,8 @@
     check(forAll(prop[SQLTimestamp] _))
     check(forAll(prop[String] _))
     check(forAll(prop[Instant] _))
+    check(forAll(prop[Duration] _))
+    check(forAll(prop[Period] _))
   }
 
   test("toString") {
diff --git a/dataset/src/test/scala/frameless/EncoderTests.scala b/dataset/src/test/scala/frameless/EncoderTests.scala
index dc8a47a6..51687f73 100644
--- a/dataset/src/test/scala/frameless/EncoderTests.scala
+++ b/dataset/src/test/scala/frameless/EncoderTests.scala
@@ -4,6 +4,9 @@ import org.scalatest.matchers.should.Matchers
 
 object EncoderTests {
   case class Foo(s: Seq[(Int, Int)])
+  case class InstantRow(i: java.time.Instant)
+  case class DurationRow(d: java.time.Duration)
+  case class PeriodRow(p: java.time.Period)
 }
 
 class EncoderTests extends TypedDatasetSuite with Matchers {
@@ -12,4 +15,16 @@
   test("It should encode deeply nested collections") {
     implicitly[TypedEncoder[Seq[Foo]]]
   }
+
+  test("It should encode java.time.Instant") {
+    implicitly[TypedEncoder[InstantRow]]
+  }
+
+  test("It should encode java.time.Duration") {
+    implicitly[TypedEncoder[DurationRow]]
+  }
+
+  test("It should encode java.time.Period") {
+    implicitly[TypedEncoder[PeriodRow]]
+  }
 }
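
Usage sketch (an editor's illustration, not part of the patch): with the instances added above in scope, java.time columns order and filter like any other comparable column. The Event case class, its field names, and the session setup are hypothetical.

    import java.time.{Duration, Instant, Period}
    import frameless.TypedDataset

    case class Event(at: Instant, took: Duration, retention: Period)

    // assuming an implicit SparkSession is in scope, as in the test suites:
    val events = TypedDataset.create(Seq(
      Event(Instant.EPOCH, Duration.ofMillis(150), Period.ofDays(30))))

    // compiles because CatalystOrdered[Instant] now exists:
    val afterEpoch = events.filter(events('at) >= Instant.EPOCH)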
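A note on round-trip fidelity, read directly off the injections in TypedEncoder.scala above: timeDurationInjection persists a Duration as epoch millis and timePeriodInjection persists a Period as its day component only, so both round-trips are lossy outside those components. A minimal sketch of what survives:

    import java.time.{Duration, Period}

    // timeDurationInjection stores _.toMillis: sub-millisecond precision is dropped.
    val d = Duration.ofNanos(1500000)  // 1.5 ms
    Duration.ofMillis(d.toMillis)      // PT0.001S -- 1 ms after the round-trip

    // timePeriodInjection stores _.getDays: years and months are not preserved.
    val p = Period.of(1, 2, 3)         // P1Y2M3D
    Period.ofDays(p.getDays)           // P3D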