Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cancelation leak in fromFuture, fromFutureCancelable #3892

Merged
merged 12 commits into from
Nov 23, 2023
18 changes: 11 additions & 7 deletions kernel/shared/src/main/scala/cats/effect/kernel/Async.scala
Original file line number Diff line number Diff line change
Expand Up @@ -212,21 +212,25 @@ trait Async[F[_]] extends AsyncPlatform[F] with Sync[F] with Temporal[F] {
* [[fromFutureCancelable]] for a cancelable version
*/
def fromFuture[A](fut: F[Future[A]]): F[A] =
flatMap(fut) { f =>
flatMap(executionContext) { implicit ec =>
async_[A](cb => f.onComplete(t => cb(t.toEither)))
flatMap(executionContext) { implicit ec =>
uncancelable { poll =>
flatMap(poll(fut)) { f => async_[A](cb => f.onComplete(t => cb(t.toEither))) }
}
}

/**
* Like [[fromFuture]], but is cancelable via the provided finalizer.
*/
def fromFutureCancelable[A](futCancel: F[(Future[A], F[Unit])]): F[A] =
flatMap(futCancel) {
case (fut, fin) =>
flatMap(executionContext) { implicit ec =>
async[A](cb => as(delay(fut.onComplete(t => cb(t.toEither))), Some(fin)))
flatMap(executionContext) { implicit ec =>
uncancelable { poll =>
flatMap(poll(futCancel)) {
case (fut, fin) =>
onCancel(
poll(async[A](cb => as(delay(fut.onComplete(t => cb(t.toEither))), Some(unit)))),
fin)
}
}
}

/**
Expand Down
79 changes: 78 additions & 1 deletion tests/shared/src/test/scala/cats/effect/kernel/AsyncSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@ package kernel
import cats.{Eq, Order, StackSafeMonad}
import cats.arrow.FunctionK
import cats.effect.laws.AsyncTests
import cats.effect.testkit.TestControl
import cats.effect.unsafe.IORuntimeConfig
import cats.laws.discipline.arbitrary._

import org.scalacheck.{Arbitrary, Cogen, Prop}
import org.scalacheck.Arbitrary.arbitrary
import org.typelevel.discipline.specs2.mutable.Discipline

import scala.concurrent.ExecutionContext
import scala.concurrent.{ExecutionContext, Promise}
import scala.concurrent.duration._

import java.util.concurrent.atomic.AtomicBoolean

class AsyncSpec extends BaseSpec with Discipline {

// we just need this because of the laws testing, since the prop runs can interfere with each other
Expand All @@ -43,6 +47,79 @@ class AsyncSpec extends BaseSpec with Discipline {
) /*(Parameters(seed = Some(Seed.fromBase64("ZxDXpm7_3Pdkl-Fvt8M90Cxfam9wKuzcifQ1QsIJxND=").get)))*/
}

"fromFuture" should {
"backpressure on cancelation" in real {
// a non-cancelable, never-completing Future
def mkf() = Promise[Unit]().future

def go = for {
started <- IO(new AtomicBoolean)
fiber <- IO.fromFuture {
IO {
started.set(true)
mkf()
}
}.start
_ <- IO.cede.whileM_(IO(!started.get))
_ <- fiber.cancel
} yield ()

TestControl
.executeEmbed(go, IORuntimeConfig(1, 2))
.as(false)
.recover { case _: TestControl.NonTerminationException => true }
.replicateA(1000)
.map(_.forall(identity(_)))
}

}

"fromFutureCancelable" should {

"cancel on fiber cancelation" in real {
val smallDelay: IO[Unit] = IO.sleep(10.millis)
def mkf() = Promise[Unit]().future

val go = for {
canceled <- IO(new AtomicBoolean)
fiber <- IO.fromFutureCancelable {
IO(mkf()).map(f => f -> IO(canceled.set(true)))
}.start
_ <- smallDelay
_ <- fiber.cancel
res <- IO(canceled.get() mustEqual true)
} yield res

TestControl.executeEmbed(go, IORuntimeConfig(1, 2)).replicateA(1000)

}

"backpressure on cancelation" in real {
// a non-cancelable, never-completing Future
def mkf() = Promise[Unit]().future
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def mkf() = Promise[Unit]().future
def mkf() = Future.never


val go = for {
started <- IO(new AtomicBoolean)
fiber <- IO.fromFutureCancelable {
IO {
started.set(true)
mkf()
}.map(f => f -> IO.never)
}.start
_ <- IO.cede.whileM_(IO(!started.get))
_ <- fiber.cancel
} yield ()

TestControl
.executeEmbed(go, IORuntimeConfig(1, 2))
.as(false)
.recover { case _: TestControl.NonTerminationException => true }
.replicateA(1000)
.map(_.forall(identity(_)))
}

}

final class AsyncIO[A](val io: IO[A])

implicit def asyncForAsyncIO: Async[AsyncIO] = new Async[AsyncIO]
Expand Down
Loading