Skip to content

Commit

Permalink
Merge pull request #2653 from vasilmkd/concurrently-transformers-25x
Browse files Browse the repository at this point in the history
Fix `Stream#concurrently` for short-circuiting monad transformers - `2.5.x`
  • Loading branch information
mpilquist authored Sep 30, 2021
2 parents 7c6a7b9 + 47bc084 commit 24282e2
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 16 deletions.
29 changes: 14 additions & 15 deletions core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -545,25 +545,24 @@ final class Stream[+F[_], +O] private[fs2] (private val free: FreeC[F, O, Unit])
)(implicit F: Concurrent[F2]): Stream[F2, O] = {
val fstream: F2[Stream[F2, O]] = for {
interrupt <- Deferred[F2, Unit]
doneR <- Deferred[F2, Either[Throwable, Unit]]
backResult <- Deferred[F2, Either[Throwable, Unit]]
} yield {
def runR: F2[Unit] =
that.interruptWhen(interrupt.get.attempt).compile.drain.attempt.flatMap { r =>
doneR.complete(r) >> {
if (r.isLeft)
interrupt
.complete(())
.attempt
.void // interrupt only if this failed otherwise give change to `this` to finalize
else F.unit
}
}
def watch[A](str: Stream[F2, A]) = str.interruptWhen(interrupt.get.attempt)

val compileBack: F2[Unit] = watch(that).compile.drain.guaranteeCase {
// Pass the result of backstream completion in the backResult deferred.
// If result of back-stream was failed, interrupt fore. Otherwise, let it be
case ExitCase.Error(t) =>
backResult.complete(Left(t)).attempt >> interrupt.complete(()).handleError(_ => ())
case _ => backResult.complete(Right(())).handleError(_ => ())
}

// stop background process but await for it to finalise with a result
val stopBack: F2[Unit] = interrupt.complete(()).attempt >> doneR.get.flatMap(F.fromEither)
// We use F.fromEither to bring errors from the back into the fore
val stopBack: F2[Unit] =
interrupt.complete(()).attempt >> backResult.get.flatMap(F.fromEither)

Stream.bracket(F.start(runR))(_ => stopBack) >>
this.interruptWhen(interrupt.get.attempt)
Stream.bracket(F.start(compileBack))(_ => stopBack) >> watch(this)
}

Stream.eval(fstream).flatten
Expand Down
19 changes: 19 additions & 0 deletions core/shared/src/test/scala/fs2/StreamConcurrentlySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package fs2

import scala.concurrent.duration._

import cats.data.EitherT
import cats.effect.IO
import cats.effect.concurrent.{Deferred, Ref, Semaphore}
import cats.syntax.all._
Expand Down Expand Up @@ -146,4 +147,22 @@ class StreamConcurrentlySuite extends Fs2Suite {
}
}

test("background stream completes with short-circuiting transformers") {
Stream(1, 2, 3)
.concurrently(Stream.eval(EitherT.leftT[IO, Int]("left")))
.compile
.lastOrError
.value
.assertEquals(Right(3))
}

test("foreground stream short-circuits") {
Stream(1, 2, 3)
.evalMap(n => EitherT.cond[IO](n % 2 == 0, n, "left"))
.concurrently(Stream.eval(EitherT.rightT[IO, String](42)))
.compile
.lastOrError
.value
.assertEquals(Left("left"))
}
}
14 changes: 13 additions & 1 deletion io/src/test/scala/fs2/io/IoSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package fs2
package io

import cats.data.EitherT
import cats.effect.{Blocker, IO, Resource}
import fs2.Fs2Suite
import org.scalacheck.{Arbitrary, Gen, Shrink}
Expand Down Expand Up @@ -215,7 +216,8 @@ class IoSuite extends Fs2Suite {
}

test("can copy more than Int.MaxValue bytes") {
// Unit test adapted from the original issue reproduction at https://github.com/mrdziuban/fs2-writeOutputStream.
// Unit test adapted from the original issue reproduction at
// https://github.com/mrdziuban/fs2-writeOutputStream.

val byteStream =
Stream
Expand All @@ -238,6 +240,16 @@ class IoSuite extends Fs2Suite {
.drain
}
}

test("works with short-circuiting monad transformers") {
// Unit test adapted from the original issue reproduction at
// https://github.com/mrdziuban/fs2-readOutputStream-EitherT.

Blocker[IO].use { blocker =>
readOutputStream(blocker, 1)(_ => EitherT.left[Unit](IO.unit)).compile.drain.value
.timeout(5.seconds)
}
}
}

group("unsafeReadInputStream") {
Expand Down

0 comments on commit 24282e2

Please sign in to comment.