Skip to content

Commit

Permalink
Generate exact zio patterns for endpoints, taking into account lack o…
Browse files Browse the repository at this point in the history
…f zio's support for no-trailing-slashes (#3949)
  • Loading branch information
adamw authored Jul 23, 2024
1 parent bc9b89f commit 251ae15
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 39 deletions.
12 changes: 9 additions & 3 deletions core/src/main/scala/sttp/tapir/Tapir.scala
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,14 @@ trait TapirComputedInputs { this: Tapir =>

/** An input which matches if the request URI ends with a trailing slash, otherwise the result is a decode failure on the path. Has no
* effect when used by documentation or client interpreters.
*
* The input has the [[NoTrailingSlash.Attribute]] attribute set to `true`, which might be useful for server interpreters.
*/
val noTrailingSlash: EndpointInput[Unit] = extractFromRequest(_.uri.path).mapDecode(ps =>
if (ps.lastOption.contains("")) DecodeResult.Mismatch("", "/") else DecodeResult.Value(())
)(_ => Nil)
val noTrailingSlash: EndpointInput[Unit] = extractFromRequest(_.uri.path)
.mapDecode(ps => if (ps.lastOption.contains("")) DecodeResult.Mismatch("", "/") else DecodeResult.Value(()))(_ => Nil)
.attribute(NoTrailingSlash.Attribute, true)

object NoTrailingSlash {
val Attribute: AttributeKey[Boolean] = new AttributeKey[Boolean]("sttp.tapir.NoTrailingSlash")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import sttp.tapir.EndpointIO.{Body, StreamBodyWrapper}
import sttp.tapir.EndpointInput.{FixedPath, PathCapture, PathsCapture}
import sttp.tapir.RawBodyType.FileBody
import sttp.tapir.internal.{RichEndpoint, RichEndpointInput, RichEndpointOutput}
import sttp.tapir.{AnyEndpoint, EndpointInput, EndpointTransput, RawBodyType, noTrailingSlash}
import sttp.tapir._

private[armeria] object RouteMapping {

Expand All @@ -25,10 +25,11 @@ private[armeria] object RouteMapping {

val hasNoTrailingSlash = e.securityInput
.and(e.input)
.traverseInputs {
case i if i == noTrailingSlash => Vector(())
.asVectorOfBasicInputs()
.exists {
case i: EndpointInput.ExtractFromRequest[_] if i.attribute(NoTrailingSlash.Attribute).getOrElse(false) => true
case _ => false
}
.nonEmpty

toPathPatterns(inputs, hasNoTrailingSlash).map { path =>
// Allows all HTTP method to handle invalid requests by RejectInterceptor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.finatra.FinatraServerInterpreter.FutureMonadError
import sttp.tapir.server.interceptor.RequestResult
import sttp.tapir.server.interpreter.ServerInterpreter
import sttp.tapir.{AnyEndpoint, EndpointInput, noTrailingSlash}
import sttp.tapir._

trait FinatraServerInterpreter extends Logging {

Expand Down Expand Up @@ -61,8 +61,8 @@ trait FinatraServerInterpreter extends Logging {
}

private[finatra] def path(input: EndpointInput[_]): String = {
val p = input
.asVectorOfBasicInputs()
val basicInputs = input.asVectorOfBasicInputs()
val p = basicInputs
.collect {
case segment: EndpointInput.FixedPath[_] => segment.show
case PathCapture(Some(name), _, _) => s"/:$name"
Expand All @@ -73,9 +73,10 @@ trait FinatraServerInterpreter extends Logging {
if (p.isEmpty) "/:*"
// checking if there's an input which rejects trailing slashes; otherwise the default behavior is to accept them
else if (
input.traverseInputs {
case i if i == noTrailingSlash => Vector(())
}.isEmpty
basicInputs.exists {
case i: EndpointInput.ExtractFromRequest[_] if i.attribute(NoTrailingSlash.Attribute).getOrElse(false) => true
case _ => false
}
) p + "/?"
else p
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,30 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](
basicStringRequest.get(uri"$baseUri/p1").send(backend).map(_.body shouldBe "e1") >>
basicStringRequest.get(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe "e2")
},
testServer(
"two endpoints with increasingly specific path inputs (w/ path capture): should match path exactly",
NonEmptyList.of(
route(
List[ServerEndpoint[Any, F]](
endpoint.get.in("p1").out(stringBody).serverLogic((_: Unit) => pureResult("e1".asRight[Unit])),
endpoint.get.in("p1" / path[String]).out(stringBody).serverLogic((_: String) => pureResult("e2".asRight[Unit]))
)
)
)
) { (backend, baseUri) =>
basicStringRequest.get(uri"$baseUri/p1").send(backend).map(_.body shouldBe "e1") >>
basicStringRequest.get(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe "e2")
},
testServer(
"two endpoints with decreasingly specific path inputs: should match path exactly",
NonEmptyList.of(
route(endpoint.get.in("p1" / "p2").out(stringBody).serverLogic((_: Unit) => pureResult("e2".asRight[Unit]))),
route(endpoint.get.in("p1").out(stringBody).serverLogic((_: Unit) => pureResult("e1".asRight[Unit])))
)
) { (backend, baseUri) =>
basicStringRequest.get(uri"$baseUri/p1").send(backend).map(_.body shouldBe "e1") >>
basicStringRequest.get(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe "e2")
},
testServer(
"two endpoints with a body defined as the first input: should only consume body when the path matches",
NonEmptyList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,54 +58,85 @@ trait ZioHttpInterpreter[R] {
)
}

// Grouping the endpoints by path prefix template (fixed path components & single path captures). This way, if
// there are multiple endpoints - with/without trailing slash, with from-request extraction, or with path wildcards,
// they will be interpreted and disambiguated by the tapir logic, instead of ZIO HTTP's routing. Also, this covers
// multiple endpoints with different methods, and allows us to handle invalid methods.
val widenedSesGroupedByPathPrefixTemplate = widenedSes.groupBy { se =>
// here we'll keep the endpoint together with the meta-data needed to create the zio-http routing information
case class ServerEndpointWithPattern(
index: Int,
pathTemplate: String,
routePattern: RoutePattern[_],
endpoint: ZServerEndpoint[R & R2, ZioStreams with WebSockets]
)

def toPattern(se: ZServerEndpoint[R & R2, ZioStreams with WebSockets], index: Int): ServerEndpointWithPattern = {
val e = se.endpoint
val inputs = e.securityInput.and(e.input).asVectorOfBasicInputs()
val x = inputs.foldLeft("") { case (p, component) =>

// Creating the path template - no-trailing-slash inputs are treated as wildcard inputs, as they are usually
// accompanied by endpoints which handle wildcard path inputs, when the `/` is present (to serve files). They
// need to end up in the same group (see below), so that they are disambiguated by tapir's logic.
val pathTemplate = inputs.foldLeft("") { case (p, component) =>
component match {
case _: EndpointInput.PathCapture[_] => p + "/?"
case i: EndpointInput.FixedPath[_] => p + "/" + i.s
case _ => p
case _: EndpointInput.PathCapture[_] => p + "/?"
case _: EndpointInput.PathsCapture[_] => p + "/..."
case i: EndpointInput.ExtractFromRequest[_] if i.attribute(NoTrailingSlash.Attribute).getOrElse(false) => p + "/..."
case i: EndpointInput.FixedPath[_] => p + "/" + i.s
case _ => p
}
}
x
}

val handlers: List[Route[R & R2, Response]] = widenedSesGroupedByPathPrefixTemplate.toList.map { case (_, sesForPathTemplate) =>
// The pattern that we generate should be the same for all endpoints in a group
val e = sesForPathTemplate.head.endpoint
val inputs = e.securityInput.and(e.input).asVectorOfBasicInputs()

val hasPath = inputs.exists {
case _: EndpointInput.PathCapture[_] => true
case _: EndpointInput.PathsCapture[_] => true
case _: EndpointInput.FixedPath[_] => true
case _ => false
}
val hasNoTrailingSlash = inputs.exists {
case i: EndpointInput.ExtractFromRequest[_] if i.attribute(NoTrailingSlash.Attribute).getOrElse(false) => true
case _ => false
}

val pathPrefixPattern = if (hasPath) {
val routePattern = if (hasPath) {
val initialPattern = RoutePattern(Method.ANY, PathCodec.empty).asInstanceOf[RoutePattern[Any]]
val base = inputs
.foldLeft(initialPattern) { case (p, component) =>
// The second tuple parameter specifies if PathCodec.trailing should be added to the route's pattern. It can
// be added either because of a PathsCapture, or because of an noTrailingSlash input.
val (p, addTrailing) = inputs
.foldLeft((initialPattern, hasNoTrailingSlash)) { case ((p, addTrailing), component) =>
component match {
case i: EndpointInput.PathCapture[_] => (p / PathCodec.string(i.name.getOrElse("?"))).asInstanceOf[RoutePattern[Any]]
case i: EndpointInput.FixedPath[_] => p / PathCodec.literal(i.s)
case _ => p
case i: EndpointInput.PathCapture[_] =>
((p / PathCodec.string(i.name.getOrElse("?"))).asInstanceOf[RoutePattern[Any]], addTrailing)
case _: EndpointInput.PathsCapture[_] => (p, true)
case i: EndpointInput.FixedPath[_] => (p / PathCodec.literal(i.s), addTrailing)
case _ => (p, addTrailing)
}
}
// because we capture the path prefix, we add a matcher for arbitrary other path components (which might be
// handled by tapir's `paths` or `extractFromRequest`)
base / PathCodec.trailing

if (addTrailing) p / PathCodec.trailing else p
} else {
// if there are no path inputs, we return a catch-all
RoutePattern(Method.ANY, PathCodec.trailing).asInstanceOf[RoutePattern[Any]]
}

Route.handled(pathPrefixPattern)(Handler.fromFunctionHandler { (request: Request) => handleRequest(request, sesForPathTemplate) })
ServerEndpointWithPattern(index, pathTemplate, routePattern, se)
}

// Grouping the endpoints by path template. This way, if there are multiple endpoints with/without trailing slash or
// with path wildcards, they will end up in the same group, and they will be disambiguated by the tapir logic.
// That's because there's not way currently to create a zio-http route pattern which would match on
// no-trailing-slashes. A group also includes multiple endpoints with different methods, but same path.
val widenedSesGroupedByPathPrefixTemplate = widenedSes.zipWithIndex
.map { case (se, index) => toPattern(se, index) }
.groupBy(_.pathTemplate)
.toList
.map(_._2)
// we try to maintain the order of endpoints as passed by the user; this order might be changed if there are
// endpoints with/without trailing slashes, or with different methods, which are not passed as subsequent
// values in the original `ses` list
.sortBy(_.map(_.index).min)

val handlers: List[Route[R & R2, Response]] = widenedSesGroupedByPathPrefixTemplate.map { sesWithPattern =>
val pattern = sesWithPattern.head.routePattern
val endpoints = sesWithPattern.sortBy(_.index).map(_.endpoint)
// The pattern that we generate should be the same for all endpoints in a group
Route.handled(pattern)(Handler.fromFunctionHandler { (request: Request) => handleRequest(request, endpoints) })
}

Routes(Chunk.fromIterable(handlers))
Expand Down

0 comments on commit 251ae15

Please sign in to comment.