Skip to content

Commit

Permalink
feat: implement Scala 3 Constant and Union support for string-based l…
Browse files Browse the repository at this point in the history
…iterals as enums (#3846)

Co-authored-by: adamw <adam@warski.org>
  • Loading branch information
ThijsBroersen and adamw authored Jun 17, 2024
1 parent 4904700 commit 74aa169
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 0 deletions.
14 changes: 14 additions & 0 deletions core/src/main/scala-3/sttp/tapir/macros/CodecMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package sttp.tapir.macros
import sttp.tapir.CodecFormat.TextPlain
import sttp.tapir.{Codec, SchemaAnnotations, Validator}
import sttp.tapir.internal.CodecValueClassMacro
import sttp.tapir.Mapping
import sttp.tapir.DecodeResult
import sttp.tapir.DecodeResult.Value
import sttp.tapir.Schema

trait CodecMacros {

Expand Down Expand Up @@ -36,6 +40,16 @@ trait CodecMacros {
inline def derivedEnumerationValueCustomise[L, T <: scala.Enumeration#Value]: CreateDerivedEnumerationCodec[L, T] =
new CreateDerivedEnumerationCodec(derivedEnumerationValueValidator[T], SchemaAnnotations.derived[T])

/** Creates a codec for a string-based union of constant values, where the validator is derived using
* [[sttp.tapir.Validator.derivedStringBasedUnionEnumeration]]. This requires that the union is a union of string literals.
*
* @tparam T
* The type of the union.
*/
inline given derivedStringBasedUnionEnumeration[T](using IsUnionOf[String, T]): Codec[String, T, TextPlain] =
lazy val validator = Validator.derivedStringBasedUnionEnumeration[T]
Codec.string.validate(validator.asInstanceOf[Validator[String]]).map(_.asInstanceOf[T])(_.asInstanceOf[String])

/** A default codec for enumerations, which returns a string-based enumeration codec, using the enum's `.toString` to encode values, and
* performing a case-insensitive search through the possible values, converted to strings using `.toString`.
*
Expand Down
7 changes: 7 additions & 0 deletions core/src/main/scala-3/sttp/tapir/macros/SchemaMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@ trait SchemaCompanionMacros extends SchemaMagnoliaDerivation {
*/
inline def derivedEnumeration[T]: CreateDerivedEnumerationSchema[T] =
new CreateDerivedEnumerationSchema(Validator.derivedEnumeration[T], SchemaAnnotations.derived[T])

inline given derivedStringBasedUnionEnumeration[S](using IsUnionOf[String, S]): Schema[S] =
lazy val validator = Validator.derivedStringBasedUnionEnumeration[S]
Schema
.string[S]
.name(SName(validator.possibleValues.toList.mkString("_or_")))
.validate(validator)
}

private[tapir] object SchemaCompanionMacros {
Expand Down
5 changes: 5 additions & 0 deletions core/src/main/scala-3/sttp/tapir/macros/ValidatorMacros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ trait ValidatorMacros {
* first place (the decoder has no other option than to fail).
*/
inline def derivedEnumeration[T]: Validator.Enumeration[T] = ${ ValidatorMacros.derivedEnumerationImpl[T] }

inline def derivedStringBasedUnionEnumeration[T](using IsUnionOf[String, T]): Validator.Enumeration[T] = {
lazy val values = UnionDerivation.constValueUnionTuple[String, T]
Validator.enumeration(values.toList.asInstanceOf[List[T]])
}
}

private[tapir] object ValidatorMacros {
Expand Down
60 changes: 60 additions & 0 deletions core/src/main/scala-3/sttp/tapir/macros/union_derivation.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package sttp.tapir.macros

import scala.compiletime.*
import scala.deriving.*
import scala.quoted.*

@scala.annotation.implicitNotFound("${A} is not a union of ${T}")
private[tapir] sealed trait IsUnionOf[T, A]

private[tapir] object IsUnionOf:

private val singleton: IsUnionOf[Any, Any] = new IsUnionOf[Any, Any] {}

transparent inline given derived[T, A]: IsUnionOf[T, A] = ${ deriveImpl[T, A] }

private def deriveImpl[T, A](using quotes: Quotes, t: Type[T], a: Type[A]): Expr[IsUnionOf[T, A]] =
import quotes.reflect.*
val tpe: TypeRepr = TypeRepr.of[A]
val bound: TypeRepr = TypeRepr.of[T]

def validateTypes(tpe: TypeRepr): Unit =
tpe.dealias match
case o: OrType =>
validateTypes(o.left)
validateTypes(o.right)
case o =>
if o <:< bound then ()
else report.errorAndAbort(s"${o.show} is not a subtype of ${bound.show}")

tpe.dealias match
case o: OrType =>
validateTypes(o)
('{ IsUnionOf.singleton.asInstanceOf[IsUnionOf[T, A]] }).asExprOf[IsUnionOf[T, A]]
case o =>
if o <:< bound then ('{ IsUnionOf.singleton.asInstanceOf[IsUnionOf[T, A]] }).asExprOf[IsUnionOf[T, A]]
else report.errorAndAbort(s"${tpe.show} is not a Union")

private[tapir] object UnionDerivation:
transparent inline def constValueUnionTuple[T, A](using IsUnionOf[T, A]): Tuple = ${ constValueUnionTupleImpl[T, A] }

private def constValueUnionTupleImpl[T: Type, A: Type](using Quotes): Expr[Tuple] =
Expr.ofTupleFromSeq(constTypes[T, A])

private def constTypes[T: Type, A: Type](using Quotes): List[Expr[Any]] =
import quotes.reflect.*
val tpe: TypeRepr = TypeRepr.of[A]
val bound: TypeRepr = TypeRepr.of[T]

def transformTypes(tpe: TypeRepr): List[TypeRepr] =
tpe.dealias match
case o: OrType =>
transformTypes(o.left) ::: transformTypes(o.right)
case o: Constant if o <:< bound && o.isSingleton =>
o :: Nil
case o =>
report.errorAndAbort(s"${o.show} is not a subtype of ${bound.show}")

transformTypes(tpe).distinct.map(_.asType match
case '[t] => '{ constValue[t] }
)
25 changes: 25 additions & 0 deletions core/src/test/scala-3/sttp/tapir/CodecScala3Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package sttp.tapir

import org.scalatest.{Assertion, Inside}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.scalacheck.Checkers
import sttp.tapir.CodecFormat.TextPlain
import sttp.tapir.DecodeResult.Value

import sttp.tapir.DecodeResult.InvalidValue

class CodecScala3Test extends AnyFlatSpec with Matchers with Checkers with Inside {

it should "derive a codec for a string-based union type" in {
// given
val codec = summon[Codec[String, "Apple" | "Banana", TextPlain]]

// then
codec.encode("Apple") shouldBe "Apple"
codec.encode("Banana") shouldBe "Banana"
codec.decode("Apple") shouldBe Value("Apple")
codec.decode("Banana") shouldBe Value("Banana")
codec.decode("Orange") should matchPattern { case DecodeResult.InvalidValue(List(ValidationError(_, "Orange", _, _))) => }
}
}
36 changes: 36 additions & 0 deletions core/src/test/scala-3/sttp/tapir/SchemaMacroScala3Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,42 @@ class SchemaMacroScala3Test extends AnyFlatSpec with Matchers:
coproduct.subtypeSchema(true).map(_.schema.schemaType) shouldBe Some(SchemaType.SBoolean())
}

it should "derive schema for a string-based union type" in {
// when
val s: Schema["a" | "b"] = Schema.derivedStringBasedUnionEnumeration

// then
s.name.map(_.show) shouldBe Some("a_or_b")

s.schemaType should matchPattern { case SchemaType.SString() => }
s.validator should matchPattern { case Validator.Enumeration(List("a", "b"), _, _) => }
}

it should "derive schema for a const as a string-based union type" in {
// when
val s: Schema["a"] = Schema.derivedStringBasedUnionEnumeration

// then
s.name.map(_.show) shouldBe Some("a")

s.schemaType should matchPattern { case SchemaType.SString() => }
s.validator should matchPattern { case Validator.Enumeration(List("a"), _, _) => }
}

it should "derive a schema for a union of unions when all are string-based constants" in {
// when
type AorB = "a" | "b"
type C = "c"
type AorBorC = AorB | C
val s: Schema[AorBorC] = Schema.derivedStringBasedUnionEnumeration[AorBorC]

// then
s.name.map(_.show) shouldBe Some("a_or_b_or_c")

s.schemaType should matchPattern { case SchemaType.SString() => }
s.validator should matchPattern { case Validator.Enumeration(List("a", "b", "c"), _, _) => }
}

object SchemaMacroScala3Test:
enum Fruit:
case Apple, Banana
Expand Down
10 changes: 10 additions & 0 deletions core/src/test/scala-3/sttp/tapir/ValidatorScala3EnumTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ class ValidatorScala3EnumTest extends AnyFlatSpec with Matchers {
""")
}

it should "derive a validator for a string-based union type" in {
// given
val validator = Validator.derivedStringBasedUnionEnumeration["Apple" | "Banana"].asInstanceOf[Validator.Primitive[String]]

// then
validator.doValidate("Apple") shouldBe ValidationResult.Valid
validator.doValidate("Banana") shouldBe ValidationResult.Valid
validator.doValidate("Orange") shouldBe ValidationResult.Invalid()
}

}

enum ColorEnum {
Expand Down
17 changes: 17 additions & 0 deletions doc/endpoint/enumerations.md
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,23 @@ enum ColorEnum {
given Schema[ColorEnum] = Schema.derivedEnumeration.defaultStringBased
```

### Scala 3 string-based constant union types to enum

If a union type is a string-based constant union type, it can be auto-derived as field type or manually derived by using the `Schema.derivedStringBasedUnionEnumeration[T]` method.

Constant strings can be derived by using the `Schema.constStringToEnum[T]` method.

Examples:
```scala
val aOrB: Schema["a" | "b"] = Schema.derivedStringBasedUnionEnumeration
```
```scala
val a: Schema["a"] = Schema.constStringToEnum
```
```scala
case class Foo(aOrB: "a" | "b", optA: Option["a"]) derives Schema
```

### Creating an enum schema by hand

Creating an enumeration [schema](schema.md) by hand is exactly the same as for any other type. The only difference
Expand Down
5 changes: 5 additions & 0 deletions doc/endpoint/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ val s: Schema[StringOrInt] = Schema.derivedUnion[StringOrInt]
If any of the components of the union type is a generic type, any of its validations will be skipped when validating
the union type, as it's not possible to generate a runtime check for the generic type.

### Derivation for string-based constant union types
e.g. `type AorB = "a" | "b"`

See [enumerations](enumerations.md#scala-3-string-based-constant-union-types-to-enum) on how to use string-based unions of constant types as enums.

## Configuring derivation

It is possible to configure Magnolia's automatic derivation to use `snake_case`, `kebab-case` or a custom field naming
Expand Down

0 comments on commit 74aa169

Please sign in to comment.