Skip to content

Commit

Permalink
Merge pull request #3892 from TimWSpence/3891-fix-future-leak
Browse files Browse the repository at this point in the history
3891 fix future leak
  • Loading branch information
djspiewak authored Nov 23, 2023
2 parents f6a6f18 + 77c404a commit 5794542
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 8 deletions.
18 changes: 11 additions & 7 deletions kernel/shared/src/main/scala/cats/effect/kernel/Async.scala
Original file line number Diff line number Diff line change
Expand Up @@ -212,21 +212,25 @@ trait Async[F[_]] extends AsyncPlatform[F] with Sync[F] with Temporal[F] {
* [[fromFutureCancelable]] for a cancelable version
*/
def fromFuture[A](fut: F[Future[A]]): F[A] =
flatMap(fut) { f =>
flatMap(executionContext) { implicit ec =>
async_[A](cb => f.onComplete(t => cb(t.toEither)))
flatMap(executionContext) { implicit ec =>
uncancelable { poll =>
flatMap(poll(fut)) { f => async_[A](cb => f.onComplete(t => cb(t.toEither))) }
}
}

/**
* Like [[fromFuture]], but is cancelable via the provided finalizer.
*/
def fromFutureCancelable[A](futCancel: F[(Future[A], F[Unit])]): F[A] =
flatMap(futCancel) {
case (fut, fin) =>
flatMap(executionContext) { implicit ec =>
async[A](cb => as(delay(fut.onComplete(t => cb(t.toEither))), Some(fin)))
flatMap(executionContext) { implicit ec =>
uncancelable { poll =>
flatMap(poll(futCancel)) {
case (fut, fin) =>
onCancel(
poll(async[A](cb => as(delay(fut.onComplete(t => cb(t.toEither))), Some(unit)))),
fin)
}
}
}

/**
Expand Down
79 changes: 78 additions & 1 deletion tests/shared/src/test/scala/cats/effect/kernel/AsyncSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@ package kernel
import cats.{Eq, Order, StackSafeMonad}
import cats.arrow.FunctionK
import cats.effect.laws.AsyncTests
import cats.effect.testkit.TestControl
import cats.effect.unsafe.IORuntimeConfig
import cats.laws.discipline.arbitrary._

import org.scalacheck.{Arbitrary, Cogen, Prop}
import org.scalacheck.Arbitrary.arbitrary
import org.typelevel.discipline.specs2.mutable.Discipline

import scala.concurrent.ExecutionContext
import scala.concurrent.{ExecutionContext, Promise}
import scala.concurrent.duration._

import java.util.concurrent.atomic.AtomicBoolean

class AsyncSpec extends BaseSpec with Discipline {

// we just need this because of the laws testing, since the prop runs can interfere with each other
Expand All @@ -43,6 +47,79 @@ class AsyncSpec extends BaseSpec with Discipline {
) /*(Parameters(seed = Some(Seed.fromBase64("ZxDXpm7_3Pdkl-Fvt8M90Cxfam9wKuzcifQ1QsIJxND=").get)))*/
}

"fromFuture" should {
"backpressure on cancelation" in real {
// a non-cancelable, never-completing Future
def mkf() = Promise[Unit]().future

def go = for {
started <- IO(new AtomicBoolean)
fiber <- IO.fromFuture {
IO {
started.set(true)
mkf()
}
}.start
_ <- IO.cede.whileM_(IO(!started.get))
_ <- fiber.cancel
} yield ()

TestControl
.executeEmbed(go, IORuntimeConfig(1, 2))
.as(false)
.recover { case _: TestControl.NonTerminationException => true }
.replicateA(1000)
.map(_.forall(identity(_)))
}

}

"fromFutureCancelable" should {

"cancel on fiber cancelation" in real {
val smallDelay: IO[Unit] = IO.sleep(10.millis)
def mkf() = Promise[Unit]().future

val go = for {
canceled <- IO(new AtomicBoolean)
fiber <- IO.fromFutureCancelable {
IO(mkf()).map(f => f -> IO(canceled.set(true)))
}.start
_ <- smallDelay
_ <- fiber.cancel
res <- IO(canceled.get() mustEqual true)
} yield res

TestControl.executeEmbed(go, IORuntimeConfig(1, 2)).replicateA(1000)

}

"backpressure on cancelation" in real {
// a non-cancelable, never-completing Future
def mkf() = Promise[Unit]().future

val go = for {
started <- IO(new AtomicBoolean)
fiber <- IO.fromFutureCancelable {
IO {
started.set(true)
mkf()
}.map(f => f -> IO.never)
}.start
_ <- IO.cede.whileM_(IO(!started.get))
_ <- fiber.cancel
} yield ()

TestControl
.executeEmbed(go, IORuntimeConfig(1, 2))
.as(false)
.recover { case _: TestControl.NonTerminationException => true }
.replicateA(1000)
.map(_.forall(identity(_)))
}

}

final class AsyncIO[A](val io: IO[A])

implicit def asyncForAsyncIO: Async[AsyncIO] = new Async[AsyncIO]
Expand Down

0 comments on commit 5794542

Please sign in to comment.