Skip to content

Commit

Permalink
feat: implement scala 3 enumeration support
Browse files Browse the repository at this point in the history
  • Loading branch information
ThijsBroersen committed Feb 21, 2024
1 parent 008e5b1 commit 6bca314
Showing 1 changed file with 81 additions and 24 deletions.
105 changes: 81 additions & 24 deletions zio-json/shared/src/main/scala-3/zio/json/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -193,21 +193,7 @@ final class jsonNoExtraFields extends Annotation
*/
final class jsonExclude extends Annotation

// TODO: implement same configuration as for Scala 2 once this issue is resolved: https://github.com/softwaremill/magnolia/issues/296
object DeriveJsonDecoder extends Derivation[JsonDecoder] { self =>
def join[A](ctx: CaseClass[Typeclass, A]): JsonDecoder[A] = {
val (transformNames, nameTransform): (Boolean, String => String) =
ctx.annotations.collectFirst { case jsonMemberNames(format) => format }
.map(true -> _)
.getOrElse(false -> identity)

val no_extra = ctx
.annotations
.collectFirst { case _: jsonNoExtraFields => () }
.isDefined

if (ctx.params.isEmpty) {
new JsonDecoder[A] {
private class CaseObjectDecoder[Typeclass[*], A](val ctx: CaseClass[Typeclass, A], no_extra: Boolean) extends JsonDecoder[A] {
def unsafeDecode(trace: List[JsonError], in: RetractReader): A = {
if (no_extra) {
Lexer.char(trace, in, '{')
Expand All @@ -225,6 +211,22 @@ object DeriveJsonDecoder extends Derivation[JsonDecoder] { self =>
case _ => throw UnsafeJson(JsonError.Message("Not an object") :: trace)
}
}

// TODO: implement same configuration as for Scala 2 once this issue is resolved: https://github.com/softwaremill/magnolia/issues/296
object DeriveJsonDecoder extends Derivation[JsonDecoder] { self =>
def join[A](ctx: CaseClass[Typeclass, A]): JsonDecoder[A] = {
val (transformNames, nameTransform): (Boolean, String => String) =
ctx.annotations.collectFirst { case jsonMemberNames(format) => format }
.map(true -> _)
.getOrElse(false -> identity)

val no_extra = ctx
.annotations
.collectFirst { case _: jsonNoExtraFields => () }
.isDefined

if (ctx.params.isEmpty) {
new CaseObjectDecoder(ctx, no_extra)
} else {
new JsonDecoder[A] {
val (names, aliases): (Array[String], Array[(String, Int)]) = {
Expand Down Expand Up @@ -384,9 +386,32 @@ object DeriveJsonDecoder extends Derivation[JsonDecoder] { self =>
lazy val namesMap: Map[String, Int] =
names.zipWithIndex.toMap

def isEnumeration = ctx.subtypes.forall(_.typeclass.isInstanceOf[CaseObjectDecoder[?, ?]])

def discrim = ctx.annotations.collectFirst { case jsonDiscriminator(n) => n }

if (discrim.isEmpty) {
if (isEnumeration) {
new JsonDecoder[A] {
def unsafeDecode(trace: List[JsonError], in: RetractReader): A = {
val typeName = Lexer.string(trace, in).toString()
namesMap.find(_._1 == typeName) match {
case Some((_, idx)) => tcs(idx).asInstanceOf[CaseObjectDecoder[JsonDecoder, A]].ctx.rawConstruct(Nil)
case None => throw UnsafeJson(JsonError.Message(s"Invalid enumeration value $typeName") :: trace)
}
}

override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A = {
json match {
case Json.Str(typeName) =>
ctx.subtypes.find(_.typeInfo.short == typeName) match {
case Some(sub) => sub.typeclass.asInstanceOf[CaseObjectDecoder[JsonDecoder, A]].ctx.rawConstruct(Nil)
case None => throw UnsafeJson(JsonError.Message(s"Invalid enumeration value $typeName") :: trace)
}
case _ => throw UnsafeJson(JsonError.Message("Not a string") :: trace)
}
}
}
} else if (discrim.isEmpty) {
// We're not allowing extra fields in this encoding
new JsonDecoder[A] {
val spans: Array[JsonError] = names.map(JsonError.ObjectAccess(_))
Expand Down Expand Up @@ -490,16 +515,18 @@ object DeriveJsonDecoder extends Derivation[JsonDecoder] { self =>
}
}

private lazy val caseObjectEncoder = new JsonEncoder[Any] {
def unsafeEncode(a: Any, indent: Option[Int], out: Write): Unit =
out.write("{}")

override final def toJsonAST(a: Any): Either[String, Json] =
Right(Json.Obj(Chunk.empty))
}

object DeriveJsonEncoder extends Derivation[JsonEncoder] { self =>
def join[A](ctx: CaseClass[Typeclass, A]): JsonEncoder[A] =
if (ctx.params.isEmpty) {
new JsonEncoder[A] {
def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit =
out.write("{}")

override final def toJsonAST(a: A): Either[String, Json] =
Right(Json.Obj(Chunk.empty))
}
caseObjectEncoder.narrow[A]
} else {
new JsonEncoder[A] {
val (transformNames, nameTransform): (Boolean, String => String) =
Expand Down Expand Up @@ -595,13 +622,43 @@ object DeriveJsonEncoder extends Derivation[JsonEncoder] { self =>
}

def split[A](ctx: SealedTrait[JsonEncoder, A]): JsonEncoder[A] = {
val isEnumeration = ctx.subtypes.forall(_.typeclass == caseObjectEncoder)

val discrim = ctx
.annotations
.collectFirst {
case jsonDiscriminator(n) => n
}

if (discrim.isEmpty) {
if (isEnumeration) {
new JsonEncoder[A] {
def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit = {
val typeName = ctx.choose(a) { sub =>
sub
.annotations
.collectFirst {
case jsonHint(name) => name
}.getOrElse(sub.typeInfo.short)
}

JsonEncoder.string.unsafeEncode(typeName, indent, out)
}

override final def toJsonAST(a: A): Either[String, Json] = {
ctx.choose(a) { sub =>
Right(
Json.Str(
sub
.annotations
.collectFirst {
case jsonHint(name) => name
}.getOrElse(sub.typeInfo.short)
)
)
}
}
}
} else if (discrim.isEmpty) {
new JsonEncoder[A] {
def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit = {
ctx.choose(a) { sub =>
Expand Down

0 comments on commit 6bca314

Please sign in to comment.