Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed handling of CORS within Vert.x #4232

Merged
merged 5 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2112,6 +2112,7 @@ lazy val examples: ProjectMatrix = (projectMatrix in file("examples"))
sttpStubServer,
swaggerUiBundle,
redocBundle,
vertxServer,
zioHttpServer,
zioJson,
zioMetrics
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// {cat=Security; effects=Future; server=Vert.x}: CORS interceptor

//> using dep com.softwaremill.sttp.tapir::tapir-vertx-server:1.11.11
//> using dep com.softwaremill.sttp.client3::core:3.10.2

package sttp.tapir.examples.security

import io.vertx.core.Vertx
import io.vertx.ext.web.*
import sttp.client3.*
import sttp.model.headers.Origin
import sttp.model.{Header, HeaderNames, Method, StatusCode}
import sttp.tapir.*
import sttp.tapir.server.interceptor.cors.{CORSConfig, CORSInterceptor}
import sttp.tapir.server.vertx.VertxFutureServerInterpreter.*
import sttp.tapir.server.vertx.{VertxFutureServerInterpreter, VertxFutureServerOptions}

import scala.concurrent.duration.*
import scala.concurrent.{Await, ExecutionContext, Future}

@main def corsInterceptorVertxServer() =
given ExecutionContext = scala.concurrent.ExecutionContext.Implicits.global
val vertx = Vertx.vertx()

val server = vertx.createHttpServer()
val router = Router.router(vertx)

val myEndpoint = endpoint.get
.in("path")
.out(plainBody[String])
.serverLogic(_ => Future(Right("OK")))

val corsInterceptor = VertxFutureServerOptions.customiseInterceptors
.corsInterceptor(
CORSInterceptor.customOrThrow(
CORSConfig.default
.allowOrigin(Origin.Host("http", "my.origin"))
.allowMethods(Method.GET)
)
)
.options

val attach = VertxFutureServerInterpreter(corsInterceptor).route(myEndpoint)
attach(router)

// starting the server
val bindAndCheck = server.requestHandler(router).listen(9000).asScala.map { binding =>
val backend = HttpClientSyncBackend()

// Sending preflight request with allowed origin
val preflightResponse = basicRequest
.options(uri"http://localhost:9000/path")
.headers(
Header.origin(Origin.Host("http", "my.origin")),
Header.accessControlRequestMethod(Method.GET)
)
.send(backend)

assert(preflightResponse.code == StatusCode.NoContent)
assert(preflightResponse.headers.contains(Header.accessControlAllowOrigin("http://my.origin")))
assert(preflightResponse.headers.contains(Header.accessControlAllowMethods(Method.GET)))

println("Got expected response for preflight request")

// Sending preflight request with disallowed origin
val preflightResponseForDisallowedOrigin = basicRequest
.options(uri"http://localhost:9000/path")
.headers(
Header.origin(Origin.Host("http", "disallowed.com")),
Header.accessControlRequestMethod(Method.GET)
)
.send(backend)

// Check response does not contain allowed origin header
assert(preflightResponseForDisallowedOrigin.code == StatusCode.NoContent)
assert(!preflightResponseForDisallowedOrigin.headers.contains(Header.accessControlAllowOrigin("http://example.com")))

println("Got expected response for preflight request for wrong origin. No allowed origin header in response")

// Sending regular request from allowed origin
val requestResponse = basicRequest
.response(asStringAlways)
.get(uri"http://localhost:9000/path")
.headers(Header.origin(Origin.Host("http", "my.origin")))
.send(backend)

assert(requestResponse.code == StatusCode.Ok)
assert(requestResponse.body == "OK")
assert(requestResponse.headers.contains(Header.vary(HeaderNames.Origin)))
assert(requestResponse.headers.contains(Header.accessControlAllowOrigin("http://my.origin")))

println("Got expected response for regular request")

binding
}

Await.result(bindAndCheck.flatMap(_.close().asScala), 1.minute)
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ trait VertxCatsServerInterpreter[F[_]] extends CommonServerInterpreter with Vert
def route(
e: ServerEndpoint[Fs2Streams[F] with WebSockets, F]
): Router => Route = { router =>
val routeDef = extractRouteDefinition(e.endpoint)
val readStreamCompatible = fs2ReadStreamCompatible(vertxCatsServerOptions)
mountWithDefaultHandlers(e)(router, extractRouteDefinition(e.endpoint), vertxCatsServerOptions)
optionsRouteIfCORSDefined(e)(router, routeDef, vertxCatsServerOptions)
.foreach(_.handler(endpointHandler(e, readStreamCompatible)))
mountWithDefaultHandlers(e)(router, routeDef, vertxCatsServerOptions)
.handler(endpointHandler(e, readStreamCompatible))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package sttp.tapir.server.vertx

import io.vertx.core.{Handler, Future => VFuture}
import io.vertx.ext.web.{Route, Router, RoutingContext}
import sttp.monad.FutureMonad
import sttp.capabilities.WebSockets
import sttp.monad.FutureMonad
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.interceptor.RequestResult
import sttp.tapir.server.interpreter.{BodyListener, ServerInterpreter}
Expand All @@ -26,7 +26,10 @@ trait VertxFutureServerInterpreter extends CommonServerInterpreter with VertxErr
* A function, that given a router, will attach this endpoint to it
*/
def route[A, U, I, E, O](e: ServerEndpoint[VertxStreams with WebSockets, Future]): Router => Route = { router =>
mountWithDefaultHandlers(e)(router, extractRouteDefinition(e.endpoint), vertxFutureServerOptions)
val routeDef = extractRouteDefinition(e.endpoint)
optionsRouteIfCORSDefined(e)(router, routeDef, vertxFutureServerOptions)
.foreach(_.handler(endpointHandler(e)))
mountWithDefaultHandlers(e)(router, routeDef, vertxFutureServerOptions)
.handler(endpointHandler(e))
}

Expand All @@ -37,7 +40,10 @@ trait VertxFutureServerInterpreter extends CommonServerInterpreter with VertxErr
* A function, that given a router, will attach this endpoint to it
*/
def blockingRoute(e: ServerEndpoint[VertxStreams with WebSockets, Future]): Router => Route = { router =>
mountWithDefaultHandlers(e)(router, extractRouteDefinition(e.endpoint), vertxFutureServerOptions)
val routeDef = extractRouteDefinition(e.endpoint)
optionsRouteIfCORSDefined(e)(router, routeDef, vertxFutureServerOptions)
.foreach(_.handler(endpointHandler(e)))
mountWithDefaultHandlers(e)(router, routeDef, vertxFutureServerOptions)
.blockingHandler(endpointHandler(e))
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,45 @@
package sttp.tapir.server.vertx.interpreters

import io.vertx.core.http.HttpMethod._
import io.vertx.ext.web.{Route, Router}
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.interceptor.Interceptor
import sttp.tapir.server.interceptor.cors.CORSInterceptor
import sttp.tapir.server.vertx.VertxServerOptions
import sttp.tapir.server.vertx.handlers.attachDefaultHandlers
import sttp.tapir.server.vertx.routing.PathMapping.{RouteDefinition, createRoute}

trait CommonServerInterpreter {

/** Checks if a CORS interceptor is defined in the server options and creates an OPTIONS route if it is.
*
* Vert.x will signal a 405 error if a route matches the path, but doesn’t match the HTTP Method. So if CORS is defined, we additionally
* register OPTIONS route which accepts the preflight requests.
*
* @return
* An optional Route. If a CORS interceptor is defined, an OPTIONS route is created and returned. Otherwise, None is returned.
*/
protected def optionsRouteIfCORSDefined[C, F[_]](
e: ServerEndpoint[C, F]
)(router: Router, routeDef: RouteDefinition, serverOptions: VertxServerOptions[F]): Option[Route] = {
def isCORSInterceptorDefined(interceptors: List[Interceptor[F]]): Boolean = {
interceptors.collectFirst { case ci: CORSInterceptor[F] => ci }.nonEmpty
}

def createOptionsRoute(router: Router, route: RouteDefinition): Option[Route] =
route match {
case (Some(method), path) if Set(GET, HEAD, POST, PUT, DELETE).contains(method) =>
Some(router.options(path))
case (None, path) => Some(router.options(path))
case _ => None
}

if (isCORSInterceptorDefined(serverOptions.interceptors)) {
createOptionsRoute(router, routeDef)
} else
None
}

protected def mountWithDefaultHandlers[C, F[_]](e: ServerEndpoint[C, F])(
router: Router,
routeDef: RouteDefinition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package sttp.tapir.server.vertx.routing

import io.vertx.core.http.HttpMethod
import io.vertx.ext.web.{Route, Router}
import sttp.tapir.{AnyEndpoint, EndpointInput}
import sttp.tapir.EndpointInput.PathCapture
import sttp.tapir.internal._
import sttp.tapir.{AnyEndpoint, EndpointInput}

object PathMapping {

Expand Down Expand Up @@ -49,5 +49,4 @@ object PathMapping {
.mkString
if (path.isEmpty) "/*" else path
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ trait VertxZioServerInterpreter[R] extends CommonServerInterpreter with VertxErr
def route[R2](e: ZServerEndpoint[R2, ZioStreams with WebSockets])(implicit
runtime: Runtime[R & R2]
): Router => Route = { router =>
mountWithDefaultHandlers(e.widen)(router, extractRouteDefinition(e.endpoint), vertxZioServerOptions)
val routeDef = extractRouteDefinition(e.endpoint)
optionsRouteIfCORSDefined(e.widen)(router, routeDef, vertxZioServerOptions)
.foreach(_.handler(endpointHandler(e)))
mountWithDefaultHandlers(e.widen)(router, routeDef, vertxZioServerOptions)
.handler(endpointHandler(e))
}

Expand Down
Loading