From b4fd74ca0af7a8065a73476ba62b78dfbd127af2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Abecasis?= Date: Fri, 19 Apr 2024 01:45:23 +0100 Subject: [PATCH] Make `timeout`/`timeoutTo` always return the outcome of the effect `timeout*` methods are implemented in terms of a race between a desired effect and the timeout. In the case that both effects complete simultaneously, it could happen that the timeout would win the race, a `TimeoutException` be raised, and the outcome of the desired effect lost. As is noted in #3456, this is a general problem with the `race*` methods, and can't be addressed in the general case without breaking the current interfaces. This change is a more narrow take on the problem specifically focusing on the `timeout` and `timeoutTo` methods. As these methods inherently wait for both racing effects to complete, the implementation is changed to always take into account the outcome of the desired effect, only raising a `TimeoutException` if the timeout won the race *and* the desired effect was effectively canceled. Similarly, errors from the desired effect are preferentially propagated over the generic `TimeoutException`. The `timeoutAndForget` methods are left unchanged, as they explicitly avoid waiting for the losing effect to finish. This change allows for `timeout` and `timeoutTo` methods to be safely used on effects that acquire resources, such as `Semaphore.acquire`, ensuring that successful outcomes are always propagated back to the user. --- .../src/main/scala/cats/effect/IO.scala | 9 +++++--- .../cats/effect/kernel/GenTemporal.scala | 23 ++++++++++++++----- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/core/shared/src/main/scala/cats/effect/IO.scala b/core/shared/src/main/scala/cats/effect/IO.scala index 7138a0b915f..0c51d9da36f 100644 --- a/core/shared/src/main/scala/cats/effect/IO.scala +++ b/core/shared/src/main/scala/cats/effect/IO.scala @@ -819,9 +819,12 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] { */ def timeoutTo[A2 >: A](duration: Duration, fallback: IO[A2]): IO[A2] = { handleDuration[IO[A2]](duration, this) { finiteDuration => - race(IO.sleep(finiteDuration)).flatMap { - case Right(_) => fallback - case Left(value) => IO.pure(value) + IO.uncancelable { poll => + poll(racePair(IO.sleep(finiteDuration))) flatMap { + case Left((oc, f)) => f.cancel *> oc.embed(poll(IO.canceled) *> IO.never) + case Right((f, _)) => + f.cancel *> f.join.flatMap { oc => oc.fold(fallback, IO.raiseError, identity) } + } } } } diff --git a/kernel/shared/src/main/scala/cats/effect/kernel/GenTemporal.scala b/kernel/shared/src/main/scala/cats/effect/kernel/GenTemporal.scala index 56d6d0d2702..5ed9f8d4518 100644 --- a/kernel/shared/src/main/scala/cats/effect/kernel/GenTemporal.scala +++ b/kernel/shared/src/main/scala/cats/effect/kernel/GenTemporal.scala @@ -91,9 +91,13 @@ trait GenTemporal[F[_], E] extends GenConcurrent[F, E] with Clock[F] { handleDuration(duration, fa)(timeoutTo(fa, _, fallback)) protected def timeoutTo[A](fa: F[A], duration: FiniteDuration, fallback: F[A]): F[A] = - flatMap(race(fa, sleep(duration))) { - case Left(a) => pure(a) - case Right(_) => fallback + uncancelable { poll => + implicit val F: GenTemporal[F, E] = this + + poll(racePair(fa, sleep(duration))) flatMap { + case Left((oc, f)) => f.cancel *> oc.embed(poll(F.canceled) *> F.never) + case Right((f, _)) => f.cancel *> f.join.flatMap { oc => oc.embed(fallback) } + } } /** @@ -115,9 +119,16 @@ trait GenTemporal[F[_], E] extends GenConcurrent[F, E] with Clock[F] { protected def timeout[A](fa: F[A], duration: FiniteDuration)( implicit ev: TimeoutException <:< E): F[A] = { - flatMap(race(fa, sleep(duration))) { - case Left(a) => pure(a) - case Right(_) => raiseError[A](ev(new TimeoutException(duration.toString()))) + uncancelable { poll => + implicit val F: GenTemporal[F, E] = this + + poll(racePair(fa, sleep(duration))) flatMap { + case Left((oc, f)) => f.cancel *> oc.embed(poll(F.canceled) *> F.never) + case Right((f, _)) => + f.cancel *> f.join.flatMap { oc => + oc.embed(raiseError[A](ev(new TimeoutException(duration.toString())))) + } + } } }