Skip to content

Commit

Permalink
Merge pull request #2376 from softwaremill/docs-class-used-directly-a…
Browse files Browse the repository at this point in the history
…nd-in-coproduct

Properly support classes which are both used directly and as a member of a coproduct
  • Loading branch information
adamw authored Aug 20, 2022
2 parents 9fa97be + 26128cb commit a424da3
Show file tree
Hide file tree
Showing 17 changed files with 273 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ private[docs] object SecuritySchemesForEndpoints {
nameSecuritySchemes(tail, takenNames, acc + (scheme -> name))
case Some(((None, scheme), tail)) =>
val baseName = scheme.`type` + "Auth"
val name = uniqueName(baseName, !takenNames.contains(_))
val name = uniqueString(baseName, !takenNames.contains(_))
nameSecuritySchemes(tail, takenNames + name, acc + (scheme -> name))
case None => acc
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package object apispec {
(shortName +: info.typeParameterShortNames).mkString("_")
}

private[docs] def uniqueName(base: String, isUnique: String => Boolean): String = {
private[docs] def uniqueString(base: String, isUnique: String => Boolean): String = {
var i = 0
var result = base
while (!isUnique(result)) {
Expand All @@ -26,15 +26,15 @@ package object apispec {
result
}

private def rawToString[T](v: Any): String = v match {
private def rawToString(v: Any): String = v match {
case a: Array[Byte] => new String(a, "UTF-8")
case b: ByteBuffer => Charset.forName("UTF-8").decode(b).toString
case _ => v.toString
}

private[docs] def exampleValue[T](v: String): ExampleValue = ExampleSingleValue(v)
private[docs] def exampleValue(v: String): ExampleValue = ExampleSingleValue(v)
private[docs] def exampleValue[T](codec: Codec[_, T, _], e: T): Option[ExampleValue] = exampleValue(codec.schema, codec.encode(e))
private[docs] def exampleValue[T](schema: Schema[_], raw: Any): Option[ExampleValue] = {
private[docs] def exampleValue(schema: Schema[_], raw: Any): Option[ExampleValue] = {
(raw, schema.schemaType) match {
case (it: Iterable[_], SchemaType.SArray(_)) => Some(ExampleMultipleValue(it.map(rawToString).toList))
case (it: Iterable[_], _) => it.headOption.map(v => ExampleSingleValue(rawToString(v)))
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package sttp.tapir.docs.apispec.schema

import sttp.tapir.{SchemaType, Schema => TSchema}

/** A schema key consists of both the name, and the schema's fields, in case this is a product. This is needed as the same schema name can
* have two different sets of fields, in case the class is a member of an inheritance hierarchy, and a discriminator field is used (#2358).
*/
private[docs] case class SchemaKey(name: TSchema.SName, fields: Set[String])

private[docs] object SchemaKey {
def apply(schema: TSchema[_]): Option[SchemaKey] = schema.name.map(apply(schema, _))

def apply(schema: TSchema[_], name: TSchema.SName): SchemaKey = {
val fields = schema.schemaType match {
case SchemaType.SProduct(fields) => fields.map(_.name.name).toSet
case SchemaType.SOpenProduct(fields, _) => fields.map(_.name.name).toSet
case _ => Set.empty[String]
}

SchemaKey(name, fields)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ import sttp.tapir.{Codec, Schema => TSchema, SchemaType => TSchemaType}
/** Converts a tapir schema to an OpenAPI/AsyncAPI reference (if the schema is named), or to the appropriate schema. */
class Schemas(
tschemaToASchema: TSchemaToASchema,
nameToSchemaReference: NameToSchemaReference,
toSchemaReference: ToSchemaReference,
markOptionsAsNullable: Boolean
) {
def apply[T](codec: Codec[T, _, _]): ReferenceOr[ASchema] = apply(codec.schema)

def apply(schema: TSchema[_]): ReferenceOr[ASchema] = {
schema.name match {
case Some(name) => Left(nameToSchemaReference.map(name))
SchemaKey(schema) match {
case Some(key) => Left(toSchemaReference.map(key))
case None =>
schema.schemaType match {
case TSchemaType.SArray(TSchema(_, Some(name), isOptional, _, _, _, _, _, _, _, _)) =>
Right(ASchema(SchemaType.Array).copy(items = Some(Left(nameToSchemaReference.map(name)))))
case TSchemaType.SArray(nested @ TSchema(_, Some(name), isOptional, _, _, _, _, _, _, _, _)) =>
Right(ASchema(SchemaType.Array).copy(items = Some(Left(toSchemaReference.map(SchemaKey(nested, name))))))
.map(s => if (isOptional && markOptionsAsNullable) s.copy(nullable = Some(true)) else s)
case TSchemaType.SOption(ts) => apply(ts)
case _ => tschemaToASchema(schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,33 @@ import scala.collection.immutable.ListMap
class SchemasForEndpoints(
es: Iterable[AnyEndpoint],
schemaName: SName => String,
toNamedSchemas: ToNamedSchemas,
toKeyedSchemas: ToKeyedSchemas,
markOptionsAsNullable: Boolean
) {

def apply(): (ListMap[ObjectKey, ReferenceOr[ASchema]], Schemas) = {
val sObjects = ToNamedSchemas.unique(
def apply(): (ListMap[SchemaId, ReferenceOr[ASchema]], Schemas) = {
val keyedSchemas = ToKeyedSchemas.unique(
es.flatMap(e => forInput(e.securityInput) ++ forInput(e.input) ++ forOutput(e.errorOutput) ++ forOutput(e.output))
)
val infoToKey = calculateUniqueKeys(sObjects.map(_._1), schemaName)
val keysToIds = calculateUniqueIds(keyedSchemas.map(_._1), (key: SchemaKey) => schemaName(key.name))

val objectToSchemaReference = new NameToSchemaReference(infoToKey)
val tschemaToASchema = new TSchemaToASchema(objectToSchemaReference, markOptionsAsNullable)
val schemas = new Schemas(tschemaToASchema, objectToSchemaReference, markOptionsAsNullable)
val infosToSchema = sObjects.map(td => (td._1, tschemaToASchema(td._2))).toListMap
val toSchemaReference = new ToSchemaReference(keysToIds)
val tschemaToASchema = new TSchemaToASchema(toSchemaReference, markOptionsAsNullable)
val schemas = new Schemas(tschemaToASchema, toSchemaReference, markOptionsAsNullable)
val keysToSchemas = keyedSchemas.map(td => (td._1, tschemaToASchema(td._2))).toListMap

val schemaKeys = infosToSchema.map { case (k, v) => k -> ((infoToKey(k), v)) }
(schemaKeys.values.toListMap, schemas)
val schemaIds = keysToSchemas.map { case (k, v) => k -> ((keysToIds(k), v)) }
(schemaIds.values.toListMap, schemas)
}

private def forInput(input: EndpointInput[_]): List[NamedSchema] = {
private def forInput(input: EndpointInput[_]): List[KeyedSchema] = {
input match {
case EndpointInput.FixedMethod(_, _, _) => List.empty
case EndpointInput.FixedPath(_, _, _) => List.empty
case EndpointInput.PathCapture(_, codec, _) => toNamedSchemas(codec)
case EndpointInput.PathCapture(_, codec, _) => toKeyedSchemas(codec)
case EndpointInput.PathsCapture(_, _) => List.empty
case EndpointInput.Query(_, _, codec, _) => toNamedSchemas(codec)
case EndpointInput.Cookie(_, codec, _) => toNamedSchemas(codec)
case EndpointInput.Query(_, _, codec, _) => toKeyedSchemas(codec)
case EndpointInput.Cookie(_, codec, _) => toKeyedSchemas(codec)
case EndpointInput.QueryParams(_, _) => List.empty
case _: EndpointInput.Auth[_, _] => List.empty
case _: EndpointInput.ExtractFromRequest[_] => List.empty
Expand All @@ -45,7 +45,7 @@ class SchemasForEndpoints(
case op: EndpointIO[_] => forIO(op)
}
}
private def forOutput(output: EndpointOutput[_]): List[NamedSchema] = {
private def forOutput(output: EndpointOutput[_]): List[KeyedSchema] = {
output match {
case EndpointOutput.OneOf(variants, _) => variants.flatMap(variant => forOutput(variant.output)).toList
case EndpointOutput.StatusCode(_, _, _) => List.empty
Expand All @@ -54,19 +54,19 @@ class SchemasForEndpoints(
case EndpointOutput.Void() => List.empty
case EndpointOutput.Pair(left, right, _, _) => forOutput(left) ++ forOutput(right)
case EndpointOutput.WebSocketBodyWrapper(wrapped) =>
toNamedSchemas(wrapped.codec) ++ toNamedSchemas(wrapped.requests) ++ toNamedSchemas(wrapped.responses)
toKeyedSchemas(wrapped.codec) ++ toKeyedSchemas(wrapped.requests) ++ toKeyedSchemas(wrapped.responses)
case op: EndpointIO[_] => forIO(op)
}
}

private def forIO(io: EndpointIO[_]): List[NamedSchema] = {
private def forIO(io: EndpointIO[_]): List[KeyedSchema] = {
io match {
case EndpointIO.Pair(left, right, _, _) => forIO(left) ++ forIO(right)
case EndpointIO.Header(_, codec, _) => toNamedSchemas(codec)
case EndpointIO.Header(_, codec, _) => toKeyedSchemas(codec)
case EndpointIO.Headers(_, _) => List.empty
case EndpointIO.Body(_, codec, _) => toNamedSchemas(codec)
case EndpointIO.Body(_, codec, _) => toKeyedSchemas(codec)
case EndpointIO.OneOfBody(variants, _) => variants.flatMap(v => forIO(v.bodyAsAtom))
case EndpointIO.StreamBodyWrapper(StreamBodyIO(_, codec, _, _, _)) => toNamedSchemas(codec.schema)
case EndpointIO.StreamBodyWrapper(StreamBodyIO(_, codec, _, _, _)) => toKeyedSchemas(codec.schema)
case EndpointIO.MappedPair(wrapped, _) => forIO(wrapped)
case EndpointIO.FixedHeader(_, _, _) => List.empty
case EndpointIO.Empty(_, _) => List.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import sttp.tapir.internal.{IterableToListMap, _}
import sttp.tapir.{Validator, Schema => TSchema, SchemaType => TSchemaType}

/** Converts a tapir schema to an OpenAPI/AsyncAPI schema, using the given map to resolve nested references. */
private[schema] class TSchemaToASchema(nameToSchemaReference: NameToSchemaReference, markOptionsAsNullable: Boolean) {
private[schema] class TSchemaToASchema(toSchemaReference: ToSchemaReference, markOptionsAsNullable: Boolean) {
def apply[T](schema: TSchema[T], isOptionElement: Boolean = false): ReferenceOr[ASchema] = {
val nullable = markOptionsAsNullable && isOptionElement
val result = schema.schemaType match {
Expand All @@ -23,24 +23,25 @@ private[schema] class TSchemaToASchema(nameToSchemaReference: NameToSchemaRefere
properties = extractProperties(fields)
)
)
case TSchemaType.SArray(TSchema(_, Some(name), _, _, _, _, _, _, _, _, _)) =>
Right(ASchema(SchemaType.Array).copy(items = Some(Left(nameToSchemaReference.map(name)))))
case TSchemaType.SArray(nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _)) =>
Right(ASchema(SchemaType.Array).copy(items = Some(Left(toSchemaReference.map(SchemaKey(nested, name))))))
case TSchemaType.SArray(el) => Right(ASchema(SchemaType.Array).copy(items = Some(apply(el))))
case TSchemaType.SOption(TSchema(_, Some(name), _, _, _, _, _, _, _, _, _)) => Left(nameToSchemaReference.map(name))
case TSchemaType.SOption(el) => apply(el, isOptionElement = true)
case TSchemaType.SOption(nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _)) =>
Left(toSchemaReference.map(SchemaKey(nested, name)))
case TSchemaType.SOption(el) => apply(el, isOptionElement = true)
case TSchemaType.SBinary() => Right(ASchema(SchemaType.String).copy(format = SchemaFormat.Binary))
case TSchemaType.SDate() => Right(ASchema(SchemaType.String).copy(format = SchemaFormat.Date))
case TSchemaType.SDateTime() => Right(ASchema(SchemaType.String).copy(format = SchemaFormat.DateTime))
case TSchemaType.SRef(fullName) => Left(nameToSchemaReference.map(fullName))
case TSchemaType.SRef(fullName) => Left(toSchemaReference.mapDirect(fullName))
case TSchemaType.SCoproduct(schemas, d) =>
Right(
ASchema
.apply(
schemas
.filterNot(_.hidden)
.map {
case TSchema(_, Some(name), _, _, _, _, _, _, _, _, _) => Left(nameToSchemaReference.map(name))
case t => apply(t)
case nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _) => Left(toSchemaReference.map(SchemaKey(nested, name)))
case t => apply(t)
}
.sortBy {
case Left(Reference(ref)) => ref
Expand All @@ -54,9 +55,9 @@ private[schema] class TSchemaToASchema(nameToSchemaReference: NameToSchemaRefere
ASchema(SchemaType.Object).copy(
required = p.required.map(_.encodedName),
properties = extractProperties(fields),
additionalProperties = Some(valueSchema.name match {
case Some(name) => Left(nameToSchemaReference.map(name))
case _ => apply(valueSchema)
additionalProperties = Some(SchemaKey(valueSchema) match {
case Some(key) => Left(toSchemaReference.map(key))
case _ => apply(valueSchema)
}).filterNot(_ => valueSchema.hidden)
)
)
Expand All @@ -78,9 +79,9 @@ private[schema] class TSchemaToASchema(nameToSchemaReference: NameToSchemaRefere
fields
.filterNot(_.schema.hidden)
.map { f =>
f.schema match {
case TSchema(_, Some(name), _, _, _, _, _, _, _, _, _) => f.name.encodedName -> Left(nameToSchemaReference.map(name))
case schema => f.name.encodedName -> apply(schema)
SchemaKey(f.schema) match {
case Some(key) => f.name.encodedName -> Left(toSchemaReference.map(key))
case None => f.name.encodedName -> apply(f.schema)
}
}
.toListMap
Expand Down Expand Up @@ -148,9 +149,13 @@ private[schema] class TSchemaToASchema(nameToSchemaReference: NameToSchemaRefere

private def tDiscriminatorToADiscriminator(discriminator: TSchemaType.SDiscriminator): Discriminator = {
val schemas = Some(
discriminator.mapping.map { case (k, TSchemaType.SRef(fullName)) =>
k -> nameToSchemaReference.map(fullName).$ref
}.toListMap
discriminator.mapping
.map { case (k, TSchemaType.SRef(fullName)) =>
k -> toSchemaReference.mapDiscriminator(fullName).$ref
}
.toList
.sortBy(_._1)
.toListMap
)
Discriminator(discriminator.name.encodedName, schemas)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@ import sttp.tapir.{Codec, Schema => TSchema, SchemaType => TSchemaType}

import scala.collection.mutable.ListBuffer

class ToNamedSchemas {
def apply[T](codec: Codec[_, T, _]): List[NamedSchema] = apply(codec.schema)
class ToKeyedSchemas {
def apply[T](codec: Codec[_, T, _]): List[KeyedSchema] = apply(codec.schema)

def apply(schema: TSchema[_]): List[NamedSchema] = {
val thisSchema = schema.name match {
case Some(name) => List(name -> schema)
case None => Nil
}
def apply(schema: TSchema[_]): List[KeyedSchema] = {
val thisSchema = SchemaKey(schema).map(_ -> schema).toList
val nestedSchemas = schema match {
case TSchema(TSchemaType.SArray(o), _, _, _, _, _, _, _, _, _, _) => apply(o)
case t @ TSchema(o: TSchemaType.SOption[_, _], _, _, _, _, _, _, _, _, _, _) =>
Expand All @@ -27,19 +24,20 @@ class ToNamedSchemas {
thisSchema ++ nestedSchemas
}

private def productSchemas[T](st: TSchemaType.SProduct[T]): List[NamedSchema] = st.fields.flatMap(a => apply(a.schema))
private def productSchemas[T](st: TSchemaType.SProduct[T]): List[KeyedSchema] = st.fields.flatMap(a => apply(a.schema))

private def coproductSchemas[T](st: TSchemaType.SCoproduct[T]): List[NamedSchema] = st.subtypes.flatMap(apply)
private def coproductSchemas[T](st: TSchemaType.SCoproduct[T]): List[KeyedSchema] = st.subtypes.flatMap(apply)
}

object ToNamedSchemas {
object ToKeyedSchemas {

/** Keeps only the first object data for each `SName`. In case of recursive objects, the first one is the most complete as it contains the
* built-up structure, unlike subsequent ones, which only represent leaves (#354).
/** Keeps only the first object data for each [[SchemaKey]]. In case of recursive objects, the first one is the most complete as it
* contains the built-up structure, unlike subsequent ones, which only represent leaves (#354, later extended for #2358, so that the
* schemas have a secondary key - the product fields (if any)).
*/
def unique(objs: Iterable[NamedSchema]): Iterable[NamedSchema] = {
val seen: collection.mutable.Set[TSchema.SName] = collection.mutable.Set()
val result: ListBuffer[NamedSchema] = ListBuffer()
def unique(objs: Iterable[KeyedSchema]): Iterable[KeyedSchema] = {
val seen: collection.mutable.Set[SchemaKey] = collection.mutable.Set()
val result: ListBuffer[KeyedSchema] = ListBuffer()
objs.foreach { obj =>
if (!seen.contains(obj._1)) {
seen.add(obj._1)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package sttp.tapir.docs.apispec.schema

import sttp.apispec.Reference
import sttp.tapir.{Schema => TSchema}

private[schema] class ToSchemaReference(keyToId: Map[SchemaKey, SchemaId]) {

def map(key: SchemaKey): Reference = map(key.name, keyToId.get(key))

/** When mapping a reference used directly (which contains a name only), in case there are multiple schema keys with that name, we use the
* one with a smaller number of fields - as these duplicates must be because one is a member of an inheritance hierarchy; then, we choose
* the variant without the extra discriminator field (#2358).
*
* When mapping a referenced used in a discriminator, we choose the variant with the higher number of fields in [[mapDiscriminator]].
*/
def mapDirect(name: TSchema.SName): Reference =
map(name, keyToId.filter(_._1.name == name).toList.sortBy(_._1.fields.size).headOption.map(_._2))

def mapDiscriminator(name: TSchema.SName): Reference =
map(name, keyToId.filter(_._1.name == name).toList.sortBy(-_._1.fields.size).headOption.map(_._2))

private def map(name: TSchema.SName, maybeId: Option[SchemaId]): Reference = maybeId match {
case Some(id) =>
Reference.to("#/components/schemas/", id)
case None =>
// no reference to internal model found. assuming external reference
Reference(name.fullName)
}
}
Loading

0 comments on commit a424da3

Please sign in to comment.