From e7afc47ce83e381c3f4aed2b2974e3b3d86a2340 Mon Sep 17 00:00:00 2001 From: Tim Date: Tue, 22 Oct 2024 16:38:51 +1300 Subject: [PATCH] ensure HttpMiddleware is only initialized once (#3808) --- .changeset/large-singers-wait.md | 5 + packages/platform/src/HttpApp.ts | 110 +++++++++++------- .../platform/src/internal/httpMiddleware.ts | 9 +- 3 files changed, 75 insertions(+), 49 deletions(-) create mode 100644 .changeset/large-singers-wait.md diff --git a/.changeset/large-singers-wait.md b/.changeset/large-singers-wait.md new file mode 100644 index 0000000000..c1451c4d22 --- /dev/null +++ b/.changeset/large-singers-wait.md @@ -0,0 +1,5 @@ +--- +"@effect/platform": patch +--- + +ensure HttpMiddleware is only initialized once diff --git a/packages/platform/src/HttpApp.ts b/packages/platform/src/HttpApp.ts index 51806b9416..d5f38522cf 100644 --- a/packages/platform/src/HttpApp.ts +++ b/packages/platform/src/HttpApp.ts @@ -4,11 +4,12 @@ import * as Context from "effect/Context" import * as Effect from "effect/Effect" import * as Exit from "effect/Exit" -import * as FiberRef from "effect/FiberRef" +import type * as FiberRef from "effect/FiberRef" import * as Layer from "effect/Layer" import type * as Option from "effect/Option" import * as Runtime from "effect/Runtime" import * as Scope from "effect/Scope" +import { unify } from "effect/Unify" import type { HttpMiddleware } from "./HttpMiddleware.js" import * as ServerError from "./HttpServerError.js" import * as ServerRequest from "./HttpServerRequest.js" @@ -32,6 +33,8 @@ export type HttpApp */ export type Default = HttpApp +const handledSymbol = Symbol.for("@effect/platform/HttpApp/handled") + /** * @since 1.0.0 * @category combinators @@ -45,53 +48,74 @@ export const toHandled = ( middleware?: HttpMiddleware | undefined ): Effect.Effect> => { const responded = Effect.withFiberRuntime< - void, - never, + ServerResponse.HttpServerResponse, + E | EH | ServerError.ResponseError, R | RH | ServerRequest.HttpServerRequest - >((fiber) => { - let handled = false - const request = Context.unsafeGet(fiber.getFiberRef(FiberRef.currentContext), ServerRequest.HttpServerRequest) - const preprocessResponse = (response: ServerResponse.HttpServerResponse) => { + >((fiber) => + Effect.flatMap(self, (response) => { + const request = Context.unsafeGet(fiber.currentContext, ServerRequest.HttpServerRequest) const handler = fiber.getFiberRef(currentPreResponseHandlers) - return handler._tag === "Some" ? handler.value(request, response) : Effect.succeed(response) - } - const responded = Effect.matchCauseEffect(self, { - onFailure: (cause) => - Effect.flatMap(ServerError.causeResponse(cause), ([response, cause]) => - preprocessResponse(response).pipe( - Effect.flatMap((response) => { - handled = true + if (handler._tag === "None") { + ;(request as any)[handledSymbol] = true + return Effect.as(handleResponse(request, response), response) + } + return Effect.tap(handler.value(request, response), (response) => { + ;(request as any)[handledSymbol] = true + return handleResponse(request, response) + }) + }) + ) + + const withErrorHandling = Effect.catchAllCause( + responded, + (cause) => + Effect.withFiberRuntime< + ServerResponse.HttpServerResponse, + E | EH | ServerError.ResponseError, + ServerRequest.HttpServerRequest | RH + >((fiber) => + Effect.flatMap(ServerError.causeResponse(cause), ([response, cause]) => { + const request = Context.unsafeGet(fiber.currentContext, ServerRequest.HttpServerRequest) + const handler = fiber.getFiberRef(currentPreResponseHandlers) + if (handler._tag === "None") { + ;(request as any)[handledSymbol] = true + return Effect.zipRight(handleResponse(request, response), Effect.failCause(cause)) + } + return Effect.zipRight( + Effect.tap(handler.value(request, response), (response) => { + ;(request as any)[handledSymbol] = true return handleResponse(request, response) }), - Effect.zipRight(Effect.failCause(cause)) - )), - onSuccess: (response) => - Effect.tap( - preprocessResponse(response), - (response) => { - handled = true - return handleResponse(request, response) - } - ) - }) - const withTracer = internalMiddleware.tracer(responded) - if (middleware === undefined) { - return withTracer as any - } - return Effect.matchCauseEffect(middleware(withTracer), { - onFailure: (cause): Effect.Effect => { - if (handled) { - return Effect.void - } - return Effect.matchCauseEffect(ServerError.causeResponse(cause), { - onFailure: (_cause) => handleResponse(request, ServerResponse.empty({ status: 500 })), - onSuccess: ([response]) => handleResponse(request, response) + Effect.failCause(cause) + ) }) - }, - onSuccess: (response): Effect.Effect => handled ? Effect.void : handleResponse(request, response) - }) - }) - return Effect.uninterruptible(Effect.scoped(responded)) + ) + ) + + const withMiddleware = unify( + middleware === undefined ? + internalMiddleware.tracer(withErrorHandling) : + Effect.matchCauseEffect(middleware(internalMiddleware.tracer(withErrorHandling)), { + onFailure: (cause): Effect.Effect => + Effect.withFiberRuntime((fiber) => { + const request = Context.unsafeGet(fiber.currentContext, ServerRequest.HttpServerRequest) + if (handledSymbol in request) { + return Effect.void + } + return Effect.matchCauseEffect(ServerError.causeResponse(cause), { + onFailure: (_cause) => handleResponse(request, ServerResponse.empty({ status: 500 })), + onSuccess: ([response]) => handleResponse(request, response) + }) + }), + onSuccess: (response): Effect.Effect => + Effect.withFiberRuntime((fiber) => { + const request = Context.unsafeGet(fiber.currentContext, ServerRequest.HttpServerRequest) + return handledSymbol in request ? Effect.void : handleResponse(request, response) + }) + }) + ) + + return Effect.uninterruptible(Effect.scoped(withMiddleware)) as any } /** diff --git a/packages/platform/src/internal/httpMiddleware.ts b/packages/platform/src/internal/httpMiddleware.ts index 90291dd11a..aa1ff93f02 100644 --- a/packages/platform/src/internal/httpMiddleware.ts +++ b/packages/platform/src/internal/httpMiddleware.ts @@ -76,8 +76,7 @@ export const withTracerDisabledForUrls = dual< export const logger = make((httpApp) => { let counter = 0 return Effect.withFiberRuntime((fiber) => { - const context = fiber.getFiberRef(FiberRef.currentContext) - const request = Context.unsafeGet(context, ServerRequest.HttpServerRequest) + const request = Context.unsafeGet(fiber.currentContext, ServerRequest.HttpServerRequest) return Effect.withLogSpan( Effect.flatMap(Effect.exit(httpApp), (exit) => { if (fiber.getFiberRef(loggerDisabled)) { @@ -110,8 +109,7 @@ export const logger = make((httpApp) => { /** @internal */ export const tracer = make((httpApp) => Effect.withFiberRuntime((fiber) => { - const context = fiber.getFiberRef(FiberRef.currentContext) - const request = Context.unsafeGet(context, ServerRequest.HttpServerRequest) + const request = Context.unsafeGet(fiber.currentContext, ServerRequest.HttpServerRequest) const disabled = fiber.getFiberRef(currentTracerDisabledWhen)(request) if (disabled) { return httpApp @@ -303,8 +301,7 @@ export const cors = (options?: { return (httpApp: App.Default): App.Default => Effect.withFiberRuntime((fiber) => { - const context = fiber.getFiberRef(FiberRef.currentContext) - const request = Context.unsafeGet(context, ServerRequest.HttpServerRequest) + const request = Context.unsafeGet(fiber.currentContext, ServerRequest.HttpServerRequest) if (request.method === "OPTIONS") { return Effect.succeed(ServerResponse.empty({ status: 204,