Skip to content

Commit

Permalink
Reduce allocations on StreamSubscription
Browse files Browse the repository at this point in the history
  • Loading branch information
BalmungSan committed Feb 26, 2023
1 parent 0c60231 commit ae7684a
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 69 deletions.
147 changes: 79 additions & 68 deletions core/jvm/src/main/scala/fs2/interop/flow/StreamSubscription.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ package fs2
package interop
package flow

import cats.effect.kernel.{Async, Deferred, Resource, Outcome}
import cats.effect.std.{Dispatcher, Queue}
import cats.effect.kernel.{Async, Outcome}
import cats.effect.syntax.all._
import cats.syntax.all._

import java.util.concurrent.Flow.{Subscription, Subscriber}
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}

/** Implementation of a [[Subscription]].
*
Expand All @@ -39,31 +39,42 @@ import java.util.concurrent.Flow.{Subscription, Subscriber}
private[flow] final class StreamSubscription[F[_], A] private (
stream: Stream[F, A],
sub: Subscriber[A],
requestDispatcher: Dispatcher[F],
requests: Queue[F, StreamSubscription.Request],
canceled: Deferred[F, Unit]
requests: AtomicLong,
resume: AtomicReference[() => Unit],
cancelToken: AtomicReference[() => Unit],
canceled: F[Unit]
)(implicit F: Async[F])
extends Subscription {
// Ensure we are on a terminal state; i.e. set `canceled`, before signaling the subscriber.
private def onError(ex: Throwable): F[Unit] =
cancelMe >> F.delay(sub.onError(ex))

private def onComplete: F[Unit] =
cancelMe >> F.delay(sub.onComplete())
// Ensure we are on a terminal state; i.e. call `cancel`, before signaling the subscriber.
private def onError(ex: Throwable): Unit = {
cancel()
sub.onError(ex)
}

private[flow] def run: F[Unit] = {
def subscriptionPipe: Pipe[F, A, A] =
in => {
def go(s: Stream[F, A]): Pull[F, A, Unit] =
Pull.eval(requests.take).flatMap {
case StreamSubscription.Request.Infinite =>
Pull.eval(F.delay(requests.get())).flatMap { n =>
if (n == Long.MaxValue)
s.pull.echo

case StreamSubscription.Request.Finite(n) =>
s.pull.take(n).flatMap {
case None => Pull.done
case Some(rem) => go(rem)
}
else if (n == 0)
Pull.eval(F.async_[Unit] { cb =>
// If there aren't more pending request,
// we will wait until the next one.
resume.set(() => cb(Either.unit))
// However, before blocking,
// we must check if it has been a concurrent request.
if (requests.get() > 0)
cb(Either.unit)
// In case it was, we abort the wait.
}) >> go(s)
else
Pull.eval(F.delay(requests.updateAndGet(r => r - n))) >>
s.pull.take(n).flatMap {
case None => Pull.done
case Some(rem) => go(rem)
}
}

go(in).stream
Expand All @@ -77,15 +88,15 @@ private[flow] final class StreamSubscription[F[_], A] private (
.drain

events
.race(canceled.get)
.race(canceled)
.guaranteeCase {
case Outcome.Succeeded(result) =>
result.flatMap {
case Left(()) => onComplete // Events finished normally.
case Left(()) => F.delay(sub.onComplete()) // Events finished normally.
case Right(()) => F.unit // Events was canceled.
}
case Outcome.Errored(ex) => onError(ex)
case Outcome.Canceled() => cancelMe
case Outcome.Errored(ex) => F.delay(onError(ex))
case Outcome.Canceled() => F.delay(cancel())
}
.void
}
Expand All @@ -96,75 +107,75 @@ private[flow] final class StreamSubscription[F[_], A] private (
// ordering is guaranteed by a sequential dispatcher.
// See https://github.com/zainab-ali/fs2-reactive-streams/issues/29
// and https://github.com/zainab-ali/fs2-reactive-streams/issues/46
private def cancelMe: F[Unit] =
canceled.complete(()).void

override def cancel(): Unit =
try
requestDispatcher.unsafeRunAndForget(cancelMe)
catch {
case _: IllegalStateException =>
// Dispatcher already shutdown, we are on terminal state, NOOP.
override def cancel(): Unit = {
var cancelMe = cancelToken.get()

if (cancelMe ne null) {
// Loop until we get the actual cancel callback.
while (cancelMe eq StreamSubscription.defaultCallback)
cancelMe = cancelToken.get()

cancelToken.set(null)
cancelMe()
}
}

override def request(n: Long): Unit = {
val prog =
canceled.tryGet.flatMap {
case None =>
if (n == java.lang.Long.MAX_VALUE)
requests.offer(StreamSubscription.Request.Infinite)
else if (n > 0)
requests.offer(StreamSubscription.Request.Finite(n))
else
onError(
ex = new IllegalArgumentException(s"Invalid number of elements [${n}]")
)

case Some(()) =>
F.unit
override def request(n: Long): Unit =
if (cancelToken.get() ne null) {
if (n <= 0)
onError(
ex = new IllegalArgumentException(s"Invalid number of elements [${n}]")
)
else {
requests.updateAndGet { r =>
val result = r + n
if (result < 0) {
// Overflow
Long.MaxValue
} else {
result
}
}

resume.get().apply()
}
try
requestDispatcher.unsafeRunAndForget(prog)
catch {
case _: IllegalStateException =>
// Dispatcher already shutdown, we are on terminal state, NOOP.
}
}
}

private[flow] object StreamSubscription {

/** Represents a downstream subscriber's request to publish elements. */
private sealed trait Request
private object Request {
case object Infinite extends Request
final case class Finite(n: Long) extends Request
}
private val defaultCallback = () => ()

// Mostly for testing purposes.
def apply[F[_], A](stream: Stream[F, A], subscriber: Subscriber[A])(implicit
F: Async[F]
): Resource[F, StreamSubscription[F, A]] =
): F[StreamSubscription[F, A]] =
(
Dispatcher.sequential[F](await = true),
Resource.eval(Queue.unbounded[F, Request]),
Resource.eval(Deferred[F, Unit])
).mapN { case (requestDispatcher, requests, canceled) =>
F.delay(new AtomicLong(0L)),
F.delay(new AtomicReference(defaultCallback)),
F.delay(new AtomicReference(defaultCallback)).map { cancelToken =>
val canceled = F.async_[Unit] { cb =>
cancelToken.set(() => cb(Either.unit))
}

cancelToken -> canceled
}
).mapN { case (requests, resume, (cancelToken, canceled)) =>
new StreamSubscription(
stream,
subscriber,
requestDispatcher,
requests,
resume,
cancelToken,
canceled
)
}.evalTap { subscription =>
}.flatTap { subscription =>
F.delay(subscriber.onSubscribe(subscription))
}

def subscribe[F[_], A](stream: Stream[F, A], subscriber: Subscriber[A])(implicit
F: Async[F]
): F[Unit] =
apply(stream, subscriber).use { subscription =>
apply(stream, subscriber).flatMap { subscription =>
subscription.run
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class CancellationSpec extends Fs2Suite {
def testStreamSubscription(clue: String)(program: Subscription => Unit): IO[Unit] =
IO(new AtomicBoolean(false))
.flatMap { flag =>
StreamSubscription(s, new Sub(flag)).use { subscription =>
StreamSubscription(s, new Sub(flag)).flatMap { subscription =>
(
subscription.run,
IO(program(subscription))
Expand Down

0 comments on commit ae7684a

Please sign in to comment.