Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplified and refactored Netty-ReactiveStreams integration #3636

Merged
merged 7 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](

override implicit val monad: MonadError[F] = new CatsMonadError()

override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): F[Array[Byte]] =
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]] =
streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte])

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ private[netty] class NettyIdRequestBody(val createFile: ServerRequest => TapirFi
override implicit val monad: MonadError[Id] = idMonad
override val streams: capabilities.Streams[NoStreams] = NoStreams

override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Array[Byte] =
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Array[Byte] =
SimpleSubscriber.processAllBlocking(publisher, contentLength, maxBytes)

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,9 @@ trait NettyIdServerInterpreter {
ses,
nettyServerOptions.interceptors,
new NettyIdRequestBody(nettyServerOptions.createFile),
new NettyToResponseBody[Id],
new NettyToResponseBody[Id](RunAsync.Id),
nettyServerOptions.deleteFile,
new RunAsync[Id] {
override def apply(f: => Id[Unit]): Unit = {
val _ = f
()
}
}
RunAsync.Id
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package sttp.tapir.server.netty

import sttp.monad.FutureMonad
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.netty.NettyFutureServerInterpreter.FutureRunAsync
import sttp.tapir.server.netty.internal.{NettyFutureRequestBody, NettyServerInterpreter, NettyToResponseBody, RunAsync}

import scala.concurrent.{ExecutionContext, Future}
Expand All @@ -22,9 +21,9 @@ trait NettyFutureServerInterpreter {
ses,
nettyServerOptions.interceptors,
new NettyFutureRequestBody(nettyServerOptions.createFile),
new NettyToResponseBody[Future](),
new NettyToResponseBody[Future](RunAsync.Future),
nettyServerOptions.deleteFile,
FutureRunAsync
RunAsync.Future
)
}
}
Expand All @@ -35,8 +34,4 @@ object NettyFutureServerInterpreter {
override def nettyServerOptions: NettyFutureServerOptions = serverOptions
}
}

private object FutureRunAsync extends RunAsync[Future] {
override def apply(f: => Future[Unit]): Unit = f
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Fut
override val streams: capabilities.Streams[NoStreams] = NoStreams
override implicit val monad: MonadError[Future] = new FutureMonad()

override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): Future[Array[Byte]] =
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Future[Array[Byte]] =
SimpleSubscriber.processAll(publisher, contentLength, maxBytes)

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Future[Unit] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
* @return
* An effect which finishes with a single array of all collected bytes.
*/
def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Int], maxBytes: Option[Long]): F[Array[Byte]]
def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]]

/** Backend-specific way to process all elements emitted by a Publisher[HttpContent] and write their bytes into a file.
*
Expand All @@ -52,35 +52,39 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
*/
def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit]

override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): F[RawValue[RAW]] = {
bodyType match {
case RawBodyType.StringBody(charset) => readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new String(bs, charset)))
case RawBodyType.ByteArrayBody =>
readAllBytes(serverRequest, maxBytes).map(RawValue(_))
case RawBodyType.ByteBufferBody =>
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs)))
case RawBodyType.InputStreamBody =>
// Possibly can be optimized to avoid loading all data eagerly into memory
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs)))
case RawBodyType.InputStreamRangeBody =>
// Possibly can be optimized to avoid loading all data eagerly into memory
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs))))
case RawBodyType.FileBody =>
for {
file <- createFile(serverRequest)
_ <- writeToFile(serverRequest, file, maxBytes)
} yield RawValue(FileRange(file), Seq(FileRange(file)))
case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException())
}
override def toRaw[RAW](
serverRequest: ServerRequest,
bodyType: RawBodyType[RAW],
maxBytes: Option[Long]
): F[RawValue[RAW]] = bodyType match {
case RawBodyType.StringBody(charset) =>
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new String(bs, charset)))
case RawBodyType.ByteArrayBody =>
readAllBytes(serverRequest, maxBytes).map(RawValue(_))
case RawBodyType.ByteBufferBody =>
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(ByteBuffer.wrap(bs)))
case RawBodyType.InputStreamBody =>
// Possibly can be optimized to avoid loading all data eagerly into memory
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(new ByteArrayInputStream(bs)))
case RawBodyType.InputStreamRangeBody =>
// Possibly can be optimized to avoid loading all data eagerly into memory
readAllBytes(serverRequest, maxBytes).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs))))
case RawBodyType.FileBody =>
for {
file <- createFile(serverRequest)
_ <- writeToFile(serverRequest, file, maxBytes)
} yield RawValue(FileRange(file), Seq(FileRange(file)))
case _: RawBodyType.MultipartBody => monad.error(new UnsupportedOperationException)
}

private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] =
serverRequest.underlying match {
case r: FullHttpRequest if r.content() == Unpooled.EMPTY_BUFFER => // Empty request
monad.unit(Array.empty[Byte])
case req: StreamedHttpRequest =>
val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toInt)
val contentLength = Option(req.headers().get(HeaderNames.ContentLength)).map(_.toLong)
publisherToBytes(req, contentLength, maxBytes)
case other => monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}"))
case other =>
monad.error(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import java.nio.charset.Charset
* Publishers to integrate responses like InputStreamBody, InputStreamRangeBody or FileBody with Netty reactive extensions. Other kinds of
* raw responses like directly available String, ByteArray or ByteBuffer can be returned without wrapping into a Publisher.
*/
private[netty] class NettyToResponseBody[F[_]](implicit me: MonadError[F]) extends ToResponseBody[NettyResponse, NoStreams] {
private[netty] class NettyToResponseBody[F[_]](runAsync: RunAsync[F])(implicit me: MonadError[F])
extends ToResponseBody[NettyResponse, NoStreams] {

override val streams: capabilities.Streams[NoStreams] = NoStreams

override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = {
Expand Down Expand Up @@ -54,7 +56,7 @@ private[netty] class NettyToResponseBody[F[_]](implicit me: MonadError[F]) exten
}

private def wrap(streamRange: InputStreamRange): Publisher[HttpContent] = {
new InputStreamPublisher[F](streamRange, DefaultChunkSize)
new InputStreamPublisher[F](streamRange, DefaultChunkSize, runAsync)
}

private def wrap(fileRange: FileRange): Publisher[HttpContent] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
package sttp.tapir.server.netty.internal

import scala.concurrent.Future

trait RunAsync[F[_]] {
def apply(f: => F[Unit]): Unit
}
object RunAsync {
type Id[A] = A

final val Id: RunAsync[Id] = new RunAsync[Id] {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of curiosity: anything better in final val Id over object Id extends RunAsync[Id]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, no strong opinions.

It arguably generates somewhat simpler bytecode, i.e.

  • no lazy initialization associated with object
  • no separate class associated with object (there's an anonymous class but it will probably be compiled to a lambda so it won't have a classfile)

It also works more naturally with type inference:

  • when it's a val, type of Id will be inferred as RunAsync[Id] (unless explicitly requested to be typed as Id.type)
  • when it's an object, it will be inferred as Id.type, with API potentially extended over RunAsync[Id]

Neither the laziness nor introducing a subtype was my intention, so a val is much closer to expressing exactly what I want, which is just having a plain implementation of RunAsync.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, thanks :)

override def apply(f: => Id[Unit]): Unit = f
}

final val Future: RunAsync[Future] = new RunAsync[Future] {
override def apply(f: => Future[Unit]): Unit =
f: Unit
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import org.reactivestreams.{Publisher, Subscription}

import java.nio.channels.AsynchronousFileChannel
import java.nio.file.{Path, StandardOpenOption}
import scala.concurrent.{Future, Promise}
import java.util.concurrent.LinkedBlockingQueue
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future, Promise}

/** A Reactive Streams subscriber which receives chunks of bytes and writes them to a file.
*/
Expand All @@ -22,11 +23,7 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon
/** Used to signal completion, so that external code can represent writing to a file as Future[Unit] */
private val resultPromise = Promise[Unit]()

/** An alternative way to signal completion, so that non-effectful servers can await on the response (like netty-loom) */
private val resultBlockingQueue = new LinkedBlockingQueue[Either[Throwable, Unit]]()

override def future: Future[Unit] = resultPromise.future
private def waitForResultBlocking(): Either[Throwable, Unit] = resultBlockingQueue.take()

override def onSubscribe(s: Subscription): Unit = {
this.subscription = s
Expand Down Expand Up @@ -58,13 +55,11 @@ class FileWriterSubscriber(path: Path) extends PromisingSubscriber[Unit, HttpCon

override def onError(t: Throwable): Unit = {
fileChannel.close()
resultBlockingQueue.add(Left(t))
resultPromise.failure(t)
}

override def onComplete(): Unit = {
fileChannel.close()
val _ = resultBlockingQueue.add(Right(()))
resultPromise.success(())
}
}
Expand All @@ -76,9 +71,6 @@ object FileWriterSubscriber {
subscriber.future
}

def processAllBlocking(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Unit = {
val subscriber = new FileWriterSubscriber(path)
publisher.subscribe(maxBytes.map(new LimitedLengthSubscriber(_, subscriber)).getOrElse(subscriber))
subscriber.waitForResultBlocking().left.foreach(e => throw e)
}
def processAllBlocking(publisher: Publisher[HttpContent], path: Path, maxBytes: Option[Long]): Unit =
Await.result(processAll(publisher, path, maxBytes), Duration.Inf)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@ package sttp.tapir.server.netty.internal.reactivestreams
import io.netty.buffer.Unpooled
import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent}
import org.reactivestreams.{Publisher, Subscriber, Subscription}
import sttp.monad.MonadError
import sttp.monad.syntax._
import sttp.tapir.InputStreamRange
import sttp.tapir.server.netty.internal.RunAsync

import java.io.InputStream
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
import scala.util.Try
import sttp.monad.MonadError
import sttp.monad.syntax._

class InputStreamPublisher[F[_]](range: InputStreamRange, chunkSize: Int)(implicit monad: MonadError[F]) extends Publisher[HttpContent] {
class InputStreamPublisher[F[_]](
range: InputStreamRange,
chunkSize: Int,
runAsync: RunAsync[F]
)(implicit
monad: MonadError[F]
) extends Publisher[HttpContent] {
override def subscribe(subscriber: Subscriber[_ >: HttpContent]): Unit = {
if (subscriber == null) throw new NullPointerException("Subscriber cannot be null")
val subscription = new InputStreamSubscription(subscriber, range, chunkSize)
Expand Down Expand Up @@ -46,7 +53,7 @@ class InputStreamPublisher[F[_]](range: InputStreamRange, chunkSize: Int)(implic
case _ => chunkSize
}

val _ = monad
runAsync(monad
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: I think we still should add a comment that in case of Id this isn't really async and we are aware that this case violates reactive streams.

.blocking(
stream.readNBytes(expectedBytes)
)
Expand All @@ -69,11 +76,10 @@ class InputStreamPublisher[F[_]](range: InputStreamRange, chunkSize: Int)(implic
}
}
.handleError {
case e => {
case e =>
val _ = Try(stream.close())
monad.unit(subscriber.onError(e))
}
}
})
}
}

Expand Down
Loading
Loading