Skip to content

Commit

Permalink
Simplified and refactored Netty-ReactiveStreams integration (#3636)
Browse files Browse the repository at this point in the history
  • Loading branch information
ghik authored Mar 26, 2024
1 parent 3001e6f commit 758d521
Show file tree
Hide file tree
Showing 12 changed files with 96 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](

override implicit val monad: MonadError[F] = new CatsMonadError()

override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): F[Array[Byte]] =
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]] =
streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte])

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ private[netty] class NettyIdRequestBody(val createFile: ServerRequest => TapirFi
override implicit val monad: MonadError[Id] = idMonad
override val streams: capabilities.Streams[NoStreams] = NoStreams

override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Array[Byte] =
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Array[Byte] =
SimpleSubscriber.processAllBlocking(publisher, contentLength, maxBytes)

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,9 @@ trait NettyIdServerInterpreter {
ses,
nettyServerOptions.interceptors,
new NettyIdRequestBody(nettyServerOptions.createFile),
new NettyToResponseBody[Id],
new NettyToResponseBody[Id](RunAsync.Id),
nettyServerOptions.deleteFile,
new RunAsync[Id] {
override def apply(f: => Id[Unit]): Unit = {
val _ = f
()
}
}
RunAsync.Id
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package sttp.tapir.server.netty

import sttp.monad.FutureMonad
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.netty.NettyFutureServerInterpreter.FutureRunAsync
import sttp.tapir.server.netty.internal.{NettyFutureRequestBody, NettyServerInterpreter, NettyToResponseBody, RunAsync}

import scala.concurrent.{ExecutionContext, Future}
Expand All @@ -22,9 +21,9 @@ trait NettyFutureServerInterpreter {
ses,
nettyServerOptions.interceptors,
new NettyFutureRequestBody(nettyServerOptions.createFile),
new NettyToResponseBody[Future](),
new NettyToResponseBody[Future](RunAsync.Future),
nettyServerOptions.deleteFile,
FutureRunAsync
RunAsync.Future
)
}
}
Expand All @@ -35,8 +34,4 @@ object NettyFutureServerInterpreter {
override def nettyServerOptions: NettyFutureServerOptions = serverOptions
}
}

private object FutureRunAsync extends RunAsync[Future] {
override def apply(f: => Future[Unit]): Unit = f
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Fut
override val streams: capabilities.Streams[NoStreams] = NoStreams
override implicit val monad: MonadError[Future] = new FutureMonad()

override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Future[Array[Byte]] =
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Future[Array[Byte]] =
SimpleSubscriber.processAll(publisher, contentLength, maxBytes)

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Future[Unit] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
* @return
* An effect which finishes with a single array of all collected bytes.
*/
def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): F[Array[Byte]]
def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]]

/** Backend-specific way to process all elements emitted by a Publisher[HttpContent] and write their bytes into a file.
*
Expand All @@ -52,35 +52,39 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
*/
def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit]

override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): F[RawValue[RAW]] = {
bodyType match {
case RawBodyType.StringBody(charset) => readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new String(bs, charset)))
case RawBodyType.ByteArrayBody =>
readAllBytes(serverRequest, maxBytes).map(RawValue(_))
case RawBodyType.ByteBufferBody =>
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs)))
case RawBodyType.InputStreamBody =>
// Possibly can be optimized to avoid loading all data eagerly into memory
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs)))
case RawBodyType.InputStreamRangeBody =>
// Possibly can be optimized to avoid loading all data eagerly into memory
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs))))
case RawBodyType.FileBody =>
for {
file <- createFile(serverRequest)
_ <- writeToFile(serverRequest, file, maxBytes)
} yield RawValue(FileRange(file), Seq(FileRange(file)))
case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException())
}
override def toRaw[RAW](
serverRequest: ServerRequest,
bodyType: RawBodyType[RAW],
maxBytes: Option[Long]
): F[RawValue[RAW]] = bodyType match {
case RawBodyType.StringBody(charset) =>
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new String(bs, charset)))
case RawBodyType.ByteArrayBody =>
readAllBytes(serverRequest, maxBytes).map(RawValue(_))
case RawBodyType.ByteBufferBody =>
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs)))
case RawBodyType.InputStreamBody =>
// Possibly can be optimized to avoid loading all data eagerly into memory
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs)))
case RawBodyType.InputStreamRangeBody =>
// Possibly can be optimized to avoid loading all data eagerly into memory
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs))))
case RawBodyType.FileBody =>
for {
file <- createFile(serverRequest)
_ <- writeToFile(serverRequest, file, maxBytes)
} yield RawValue(FileRange(file), Seq(FileRange(file)))
case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException)
}

private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] =
serverRequest.underlying match {
case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // Empty request
monad.unit(Array.empty[Byte])
case req: StreamedHttpRequest =>
val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toInt)
val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toLong)
publisherToBytes(req, contentLength, maxBytes)
case other => monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}"))
case other =>
monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import java.nio.charset.Charset
* Publishers to integrate responses like InputStreamBody, InputStreamRangeBody or FileBody with Netty reactive extensions. Other kinds of
* raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher.
*/
private[netty] class NettyToResponseBody[F[_]](implicit me: MonadError[F]) extends ToResponseBody[NettyResponse, NoStreams] {
private[netty] class NettyToResponseBody[F[_]](runAsync: RunAsync[F])(implicit me: MonadError[F])
extends ToResponseBody[NettyResponse, NoStreams] {

override val streams: capabilities.Streams[NoStreams] = NoStreams

override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = {
Expand Down Expand Up @@ -54,7 +56,7 @@ private[netty] class NettyToResponseBody[F[_]](implicit me: MonadError[F]) exten
}

private def wrap(streamRange: InputStreamRange): Publisher[HttpContent] = {
new InputStreamPublisher[F](streamRange, DefaultChunkSize)
new InputStreamPublisher[F](streamRange, DefaultChunkSize, runAsync)
}

private def wrap(fileRange: FileRange): Publisher[HttpContent] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
package sttp.tapir.server.netty.internal

import scala.concurrent.Future

trait RunAsync[F[_]] {
def apply(f: => F[Unit]): Unit
}
object RunAsync {
type Id[A] = A

final val Id: RunAsync[Id] = new RunAsync[Id] {
override def apply(f: => Id[Unit]): Unit = f
}

final val Future: RunAsync[Future] = new RunAsync[Future] {
override def apply(f: => Future[Unit]): Unit =
f: Unit
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import org.reactivestreams.{Publisher, Subscription}

import java.nio.channels.AsynchronousFileChannel
import java.nio.file.{Path, StandardOpenOption}
import scala.concurrent.{Future, Promise}
import java.util.concurrent.LinkedBlockingQueue
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future, Promise}

/** A Reactive Streams subscriber which receives chunks of bytes and writes them to a file.
*/
Expand All @@ -22,11 +23,7 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon
/** Used to signal completion, so that external code can represent writing to a file as Future[Unit] */
private val resultPromise = Promise[Unit]()

/** An alternative way to signal completion, so that non-effectful servers can await on the response (like netty-loom) */
private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Unit]]()

override def future: Future[Unit] = resultPromise.future
private def waitForResultBlocking(): Either[Throwable, Unit] = resultBlockingQueue.take()

override def onSubscribe(s: Subscription): Unit = {
this.subscription = s
Expand Down Expand Up @@ -58,13 +55,11 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon

override def onError(t: Throwable): Unit = {
fileChannel.close()
resultBlockingQueue.add(Left(t))
resultPromise.failure(t)
}

override def onComplete(): Unit = {
fileChannel.close()
val _ = resultBlockingQueue.add(Right(()))
resultPromise.success(())
}
}
Expand All @@ -76,9 +71,6 @@ object FileWriterSubscriber {
subscriber.future
}

def processAllBlocking(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Unit = {
val subscriber = new FileWriterSubscriber(path)
publisher.subscribe(maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber))
subscriber.waitForResultBlocking().left.foreach(e => throw e)
}
def processAllBlocking(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Unit =
Await.result(processAll(publisher, path, maxBytes), Duration.Inf)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@ package sttp.tapir.server.netty.internal.reactivestreams
import io.netty.buffer.Unpooled
import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent}
import org.reactivestreams.{Publisher, Subscriber, Subscription}
import sttp.monad.MonadError
import sttp.monad.syntax._
import sttp.tapir.InputStreamRange
import sttp.tapir.server.netty.internal.RunAsync

import java.io.InputStream
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
import scala.util.Try
import sttp.monad.MonadError
import sttp.monad.syntax._

class InputStreamPublisher[F[_]](range: InputStreamRange, chunkSize: Int)(implicit monad: MonadError[F]) extends Publisher[HttpContent] {
class InputStreamPublisher[F[_]](
range: InputStreamRange,
chunkSize: Int,
runAsync: RunAsync[F]
)(implicit
monad: MonadError[F]
) extends Publisher[HttpContent] {
override def subscribe(subscriber: Subscriber[_ >: HttpContent]): Unit = {
if (subscriber == null) throw new NullPointerException("Subscriber cannot be null")
val subscription = new InputStreamSubscription(subscriber, range, chunkSize)
Expand Down Expand Up @@ -46,7 +53,9 @@ class InputStreamPublisher[F[_]](range: InputStreamRange, chunkSize: Int)(implic
case _ => chunkSize
}

val _ = monad
// Note: the effect F may be Id, in which case everything here will be synchronous and blocking
// (which technically is against the reactive streams spec).
runAsync(monad
.blocking(
stream.readNBytes(expectedBytes)
)
Expand All @@ -69,11 +78,10 @@ class InputStreamPublisher[F[_]](range: InputStreamRange, chunkSize: Int)(implic
}
}
.handleError {
case e => {
case e =>
val _ = Try(stream.close())
monad.unit(subscriber.onError(e))
}
}
})
}
}

Expand Down
Loading

0 comments on commit 758d521

Please sign in to comment.