Implement java.time.{Duration, Instant, Period} type encoders (#581)
* Implement java.time.{Duration, Instant, Period} type encoders

Co-authored-by: Grigory Pomadchin <gr.pomadchin@gmail.com>
jgoday and pomadchin authored Mar 5, 2022
1 parent 69a1b59 commit 18fc6d2
Showing 4 changed files with 76 additions and 8 deletions.
4 changes: 4 additions & 0 deletions core/src/main/scala/frameless/CatalystOrdered.scala
@@ -3,6 +3,7 @@ package frameless
import scala.annotation.implicitNotFound
import shapeless.{Generic, HList, Lazy}
import shapeless.ops.hlist.LiftAll
import java.time.{Duration, Instant, Period}

/** Types that can be ordered/compared by Catalyst. */
@implicitNotFound("Cannot compare columns of type ${A}.")
@@ -23,6 +24,9 @@ object CatalystOrdered {
implicit val framelessSQLDateOrdered : CatalystOrdered[SQLDate] = of[SQLDate]
implicit val framelessSQLTimestampOrdered: CatalystOrdered[SQLTimestamp] = of[SQLTimestamp]
implicit val framelessStringOrdered : CatalystOrdered[String] = of[String]
implicit val framelessInstantOrdered : CatalystOrdered[Instant] = of[Instant]
implicit val framelessDurationOrdered : CatalystOrdered[Duration] = of[Duration]
implicit val framelessPeriodOrdered : CatalystOrdered[Period] = of[Period]

implicit def injectionOrdered[A, B]
(implicit
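With these instances in scope, columns of Instant, Duration, and Period can participate in Catalyst comparisons. A minimal sketch of what this enables (the Event case class and the startedAfter helper are illustrative, not part of the commit):

import java.time.Instant
import frameless.TypedDataset

case class Event(id: Long, start: Instant)

// `>` on the 'start column resolves through framelessInstantOrdered
def startedAfter(events: TypedDataset[Event], cutoff: Instant): TypedDataset[Event] =
  events.filter(events('start) > cutoff)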
42 changes: 41 additions & 1 deletion dataset/src/main/scala/frameless/TypedEncoder.scala
@@ -1,13 +1,14 @@
package frameless

import java.time.{Duration, Instant, Period}
import scala.reflect.ClassTag

import org.apache.spark.sql.FramelessInternals
import org.apache.spark.sql.FramelessInternals.UserDefinedType
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -230,6 +231,45 @@ object TypedEncoder {
)
}

/** java.time encoders. Spark uses https://github.com/apache/spark/blob/v3.2.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala for encoding / decoding. */
implicit val timeInstant: TypedEncoder[Instant] = new TypedEncoder[Instant] {
def nullable: Boolean = false

def jvmRepr: DataType = ScalaReflection.dataTypeFor[Instant]
def catalystRepr: DataType = TimestampType

def toCatalyst(path: Expression): Expression =
StaticInvoke(
DateTimeUtils.getClass,
TimestampType,
"instantToMicros",
path :: Nil,
returnNullable = false)

def fromCatalyst(path: Expression): Expression =
StaticInvoke(
staticObject = DateTimeUtils.getClass,
dataType = jvmRepr,
functionName = "microsToInstant",
arguments = path :: Nil,
propagateNull = true
)
}
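For reference, the encoder above delegates to two concrete DateTimeUtils methods at runtime. A minimal round-trip sketch, assuming a Spark 3.x catalyst artifact on the classpath:

import java.time.Instant
import java.time.temporal.ChronoUnit
import org.apache.spark.sql.catalyst.util.DateTimeUtils

val now = Instant.now()
val micros = DateTimeUtils.instantToMicros(now) // Instant -> Long, microseconds since the epoch
val back = DateTimeUtils.microsToInstant(micros) // Long -> Instant
assert(back == now.truncatedTo(ChronoUnit.MICROS)) // lossless down to microsecond precision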

/**
* DayTimeIntervalType and YearMonthIntervalType were introduced in Spark 3.2.0.
* We maintain Spark 3.x cross compilation and handle Duration and Period as injections to stay compatible with Spark versions < 3.2.
* See
*  * https://github.com/apache/spark/blob/v3.2.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala#L1031-L1047
*  * https://github.com/apache/spark/blob/v3.2.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala#L1075-L1087
*/
// DayTimeIntervalType
implicit val timeDurationInjection: Injection[Duration, Long] = Injection(_.toMillis, Duration.ofMillis)
// YearMonthIntervalType
implicit val timePeriodInjection: Injection[Period, Int] = Injection(_.getDays, Period.ofDays)
implicit val timePeriodEncoder: TypedEncoder[Period] = TypedEncoder.usingInjection
implicit val timeDurationEncoder: TypedEncoder[Duration] = TypedEncoder.usingInjection
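Note what the two injections preserve: Duration round-trips through milliseconds, so sub-millisecond precision is dropped, and Period round-trips only its days component, so years and months are dropped. A small illustration using plain java.time (not part of the commit):

import java.time.{Duration, Period}

Duration.ofMillis(Duration.ofNanos(1500000).toMillis) // PT0.001S: nanoseconds below a millisecond are lost
Period.ofDays(Period.of(1, 2, 3).getDays)             // P3D: the years and months components are dropped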

implicit def arrayEncoder[T: ClassTag](
implicit i0: Lazy[RecordFieldEncoder[T]]): TypedEncoder[Array[T]] =
new TypedEncoder[Array[T]] {
23 changes: 16 additions & 7 deletions dataset/src/test/scala/frameless/ColumnTests.scala
@@ -1,7 +1,6 @@
package frameless

import java.time.Instant

import java.time.{Instant, Period, Duration}
import org.scalacheck.Prop._
import org.scalacheck.{Arbitrary, Gen, Prop}, Arbitrary.arbitrary
import org.scalatest.matchers.should.Matchers
@@ -14,11 +13,17 @@ final class ColumnTests extends TypedDatasetSuite with Matchers {
private implicit object OrderingImplicits {
implicit val sqlDateOrdering: Ordering[SQLDate] = Ordering.by(_.days)
implicit val sqlTimestmapOrdering: Ordering[SQLTimestamp] = Ordering.by(_.us)
implicit val arbInstant: Arbitrary[Instant] = Arbitrary(
Gen.chooseNum(0L, Instant.MAX.getEpochSecond)
.map(Instant.ofEpochSecond))
implicit val instantAsLongInjection: Injection[Instant, Long] =
Injection(_.getEpochSecond, Instant.ofEpochSecond)
implicit val periodOrdering: Ordering[Period] = Ordering.by(p => (p.getYears, p.getMonths, p.getDays))
/**
* DateTimeUtils.instantToMicros supports instants starting at 1970-01-01T00:00:00Z, which is Instant.EPOCH.
* It also overflows on Instant.MAX; to be sure it never overflows, the generated range is capped at Instant.MAX / 4.
* For implementation details, see org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros.
*/
val genInstant = Gen.choose[Instant](Instant.EPOCH, Instant.ofEpochMilli(Instant.MAX.getEpochSecond / 4))
implicit val arbInstant: Arbitrary[Instant] = Arbitrary(genInstant)
implicit val arbDuration: Arbitrary[Duration] = Arbitrary(genInstant.map(i => Duration.ofMillis(i.toEpochMilli)))
implicit val arbPeriod: Arbitrary[Period] = Arbitrary(Gen.chooseNum(0, Int.MaxValue).map(l => Period.of(l, l, l)))
}

test("select('a < 'b, 'a <= 'b, 'a > 'b, 'a >= 'b)") {
@@ -49,6 +54,8 @@ final class ColumnTests extends TypedDatasetSuite with Matchers {
check(forAll(prop[SQLTimestamp] _))
check(forAll(prop[String] _))
check(forAll(prop[Instant] _))
check(forAll(prop[Duration] _))
check(forAll(prop[Period] _))
}

test("between") {
@@ -76,6 +83,8 @@ final class ColumnTests extends TypedDatasetSuite with Matchers {
check(forAll(prop[SQLTimestamp] _))
check(forAll(prop[String] _))
check(forAll(prop[Instant] _))
check(forAll(prop[Duration] _))
check(forAll(prop[Period] _))
}

test("toString") {
15 changes: 15 additions & 0 deletions dataset/src/test/scala/frameless/EncoderTests.scala
@@ -4,6 +4,9 @@ import org.scalatest.matchers.should.Matchers

object EncoderTests {
case class Foo(s: Seq[(Int, Int)])
case class InstantRow(i: java.time.Instant)
case class DurationRow(d: java.time.Duration)
case class PeriodRow(p: java.time.Period)
}

class EncoderTests extends TypedDatasetSuite with Matchers {
@@ -12,4 +15,16 @@ class EncoderTests extends TypedDatasetSuite with Matchers {
test("It should encode deeply nested collections") {
implicitly[TypedEncoder[Seq[Foo]]]
}

test("It should encode java.time.Instant") {
implicitly[TypedEncoder[InstantRow]]
}

test("It should encode java.time.Duration") {
implicitly[TypedEncoder[DurationRow]]
}

test("It should encode java.time.Period") {
implicitly[TypedEncoder[PeriodRow]]
}
}
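With the new encoders in scope, case classes like InstantRow above can back a TypedDataset directly. A hedged round-trip sketch, assuming an implicit SparkSession is available when the dataset is created:

import java.time.Instant
import frameless.TypedDataset
import org.apache.spark.sql.SparkSession

def roundTrip(implicit spark: SparkSession): Seq[InstantRow] =
  TypedDataset.create(Seq(InstantRow(Instant.EPOCH))).collect().run()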
