Skip to content

Commit

Permalink
Merge pull request #537 from SystemFw/memoize-cancelation
Browse files Browse the repository at this point in the history
Concurrent.memoize is well behaved under cancelation
  • Loading branch information
djspiewak committed May 26, 2019
2 parents 3abef46 + 2b88173 commit 9cb3e6c
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 17 deletions.
74 changes: 58 additions & 16 deletions core/shared/src/main/scala/cats/effect/Concurrent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import cats.effect.IO.{Delay, Pure, RaiseError}
import cats.effect.internals.Callback.{rightUnit, successUnit}
import cats.effect.internals.{CancelableF, IORunLoop}
import cats.effect.internals.TrampolineEC.immediate
import cats.effect.implicits._
import cats.syntax.all._

import scala.annotation.implicitNotFound
Expand Down Expand Up @@ -353,26 +354,67 @@ object Concurrent {
}

/**
* Lazily memoizes `f`. For every time the returned `F[F[A]]` is
* bound, the effect `f` will be performed at most once (when the
* inner `F[A]` is bound the first time).
* Lazily memoizes `f`. Assuming no cancelation happens, the effect
* `f` will be performed at most once for every time the returned
* `F[F[A]]` is bound (when the inner `F[A]` is bound the first
* time).
*
* Note: `start` can be used for eager memoization.
* If you try to cancel an inner `F[A]`, `f` is only interrupted if
* there are no other active subscribers, whereas if there are, `f`
* keeps running in the background.
*
* If `f` is successfully canceled, the next time an inner `F[A]`
* is bound `f` will be restarted again. Note that this can mean
* the effects of `f` happen more than once.
*
* You can look at `Async.memoize` for a version of this function
* which does not allow cancelation.
*/
def memoize[F[_], A](f: F[A])(implicit F: Concurrent[F]): F[F[A]] =
Ref.of[F, Option[Deferred[F, Either[Throwable, A]]]](None).map { ref =>
Deferred[F, Either[Throwable, A]].flatMap { d =>
ref
.modify {
case None =>
Some(d) -> f.attempt.flatTap(d.complete)
case s @ Some(other) =>
s -> other.get
}
.flatten
.rethrow
def memoize[F[_], A](f: F[A])(implicit F: Concurrent[F]): F[F[A]] = {
sealed trait State
case class Subs(n: Int) extends State
case object Done extends State

case class Fetch(state: State, v: Deferred[F, Either[Throwable, A]], stop: Deferred[F, F[Unit]])

Ref[F].of(Option.empty[Fetch]).map { state =>
(Deferred[F, Either[Throwable, A]] product Deferred[F, F[Unit]]).flatMap {
case (v, stop) =>
def endState(ec: ExitCase[Throwable]) =
state.modify {
case None =>
throw new AssertionError("unreachable")
case s @ Some(Fetch(Done, _, _)) =>
s -> F.unit
case Some(Fetch(Subs(n), v, stop)) =>
if (ec == ExitCase.Canceled && n == 1)
None -> stop.get.flatten
else if (ec == ExitCase.Canceled)
Fetch(Subs(n - 1), v, stop).some -> F.unit
else
Fetch(Done, v, stop).some -> F.unit
}.flatten

def fetch =
f.attempt
.flatMap(v.complete)
.start
.flatMap(fiber => stop.complete(fiber.cancel))

state
.modify {
case s @ Some(Fetch(Done, v, _)) =>
s -> v.get
case Some(Fetch(Subs(n), v, stop)) =>
Fetch(Subs(n + 1), v, stop).some -> v.get.guaranteeCase(endState)
case None =>
Fetch(Subs(1), v, stop).some -> fetch.bracketCase(_ => v.get) { case (_, ec) => endState(ec) }
}
.flatten
.rethrow
}
}
}

/**
* Returns an effect that either completes with the result of the source within
Expand Down
113 changes: 112 additions & 1 deletion laws/shared/src/test/scala/cats/effect/MemoizeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package cats
package effect

import cats.implicits._
import cats.effect.concurrent.Ref
import cats.effect.concurrent.{Ref, Deferred}
import scala.concurrent.duration._
import scala.util.{Success}

Expand Down Expand Up @@ -92,4 +92,115 @@ class MemoizeTests extends BaseTestsSuite {
Concurrent.memoize(fa).flatten <-> fa
}
}

testAsync("Memoized effects can be canceled when there are no other active subscribers (1)") { implicit ec =>
implicit val cs = ec.contextShift[IO]
implicit val timer = ec.timer[IO]

val prog = for {
completed <- Ref[IO].of(false)
action = IO.sleep(200.millis) >> completed.set(true)
memoized <- Concurrent.memoize(action)
fiber <- memoized.start
_ <- IO.sleep(100.millis)
_ <- fiber.cancel
_ <- IO.sleep(300.millis)
res <- completed.get
} yield res

val result = prog.unsafeToFuture()
ec.tick(500.millis)
result.value shouldBe Some(Success(false))
}

testAsync("Memoized effects can be canceled when there are no other active subscribers (2)") { implicit ec =>
implicit val cs = ec.contextShift[IO]
implicit val timer = ec.timer[IO]

val prog = for {
completed <- Ref[IO].of(false)
action = IO.sleep(300.millis) >> completed.set(true)
memoized <- Concurrent.memoize(action)
fiber1 <- memoized.start
_ <- IO.sleep(100.millis)
fiber2 <- memoized.start
_ <- IO.sleep(100.millis)
_ <- fiber2.cancel
_ <- fiber1.cancel
_ <- IO.sleep(400.millis)
res <- completed.get
} yield res

val result = prog.unsafeToFuture()
ec.tick(600.millis)
result.value shouldBe Some(Success(false))
}

testAsync("Memoized effects can be canceled when there are no other active subscribers (3)") { implicit ec =>
implicit val cs = ec.contextShift[IO]
implicit val timer = ec.timer[IO]

val prog = for {
completed <- Ref[IO].of(false)
action = IO.sleep(300.millis) >> completed.set(true)
memoized <- Concurrent.memoize(action)
fiber1 <- memoized.start
_ <- IO.sleep(100.millis)
fiber2 <- memoized.start
_ <- IO.sleep(100.millis)
_ <- fiber1.cancel
_ <- fiber2.cancel
_ <- IO.sleep(400.millis)
res <- completed.get
} yield res

val result = prog.unsafeToFuture()
ec.tick(600.millis)
result.value shouldBe Some(Success(false))
}

testAsync("Running a memoized effect after it was previously canceled reruns it") { implicit ec =>
implicit val cs = ec.contextShift[IO]
implicit val timer = ec.timer[IO]

val prog = for {
started <- Ref[IO].of(0)
completed <- Ref[IO].of(0)
action = started.update(_ + 1) >> timer.sleep(200.millis) >> completed.update(_ + 1)
memoized <- Concurrent.memoize(action)
fiber <- memoized.start
_ <- IO.sleep(100.millis)
_ <- fiber.cancel
_ <- memoized.timeout(1.second)
v1 <- started.get
v2 <- completed.get
} yield v1 -> v2

val result = prog.unsafeToFuture()
ec.tick(500.millis)
result.value shouldBe Some(Success(2 -> 1))
}

testAsync("Attempting to cancel a memoized effect with active subscribers is a no-op") { implicit ec =>
implicit val cs = ec.contextShift[IO]
implicit val timer = ec.timer[IO]

val prog = for {
condition <- Deferred[IO, Unit]
action = IO.sleep(200.millis) >> condition.complete(())
memoized <- Concurrent.memoize(action)
fiber1 <- memoized.start
_ <- IO.sleep(50.millis)
fiber2 <- memoized.start
_ <- IO.sleep(50.millis)
_ <- fiber1.cancel
_ <- fiber2.join // Make sure no exceptions are swallowed by start
v <- condition.get.timeout(1.second).as(true)
} yield v

val result = prog.unsafeToFuture()
ec.tick(500.millis)
result.value shouldBe Some(Success(true))
}
}

0 comments on commit 9cb3e6c

Please sign in to comment.