diff --git a/core/src/main/scala/cats/MonadError.scala b/core/src/main/scala/cats/MonadError.scala index e0368bce30..9a56f73f9c 100644 --- a/core/src/main/scala/cats/MonadError.scala +++ b/core/src/main/scala/cats/MonadError.scala @@ -19,6 +19,23 @@ trait MonadError[F[_], E] extends ApplicativeError[F, E] with Monad[F] { def ensureOr[A](fa: F[A])(error: A => E)(predicate: A => Boolean): F[A] = flatMap(fa)(a => if (predicate(a)) pure(a) else raiseError(error(a))) + /** + * Sequences the specified finalizer ensuring evaluation regardless of + * whether or not the target `F[A]` raises an error. + * + * If `raiseError` is analogous to `throw` and `handleErrorWith` is analogous to + * `catch`, then `guarantee` is analogous to `finally`. JVM exception semantics + * are mirrored with respect to error raised within the finalizer (i.e. errors + * raised within the finalizer will shade any errors raised by the primary action). + * + * @see [[raiseError]] + * @see [[handleErrorWith]] + */ + def guarantee[A](fa: F[A], finalizer: F[Unit]): F[A] = { + flatMap(attempt(fa)) { e => + flatMap(finalizer)(_ => e.fold(raiseError, pure)) + } + } } object MonadError { diff --git a/core/src/main/scala/cats/syntax/monadError.scala b/core/src/main/scala/cats/syntax/monadError.scala index 00a9381efa..651914af76 100644 --- a/core/src/main/scala/cats/syntax/monadError.scala +++ b/core/src/main/scala/cats/syntax/monadError.scala @@ -7,9 +7,13 @@ trait MonadErrorSyntax { } final class MonadErrorOps[F[_], E, A](val fa: F[A]) extends AnyVal { + def ensure(error: => E)(predicate: A => Boolean)(implicit F: MonadError[F, E]): F[A] = F.ensure(fa)(error)(predicate) def ensureOr(error: A => E)(predicate: A => Boolean)(implicit F: MonadError[F, E]): F[A] = F.ensureOr(fa)(error)(predicate) + + def guarantee(finalizer: F[Unit])(implicit F: MonadError[F, E]): F[A] = + F.guarantee(fa, finalizer) } diff --git a/tests/src/test/scala/cats/tests/MonadErrorSuite.scala b/tests/src/test/scala/cats/tests/MonadErrorTest.scala similarity index 50% rename from tests/src/test/scala/cats/tests/MonadErrorSuite.scala rename to tests/src/test/scala/cats/tests/MonadErrorTest.scala index 9c9c66d615..8a87d32069 100644 --- a/tests/src/test/scala/cats/tests/MonadErrorSuite.scala +++ b/tests/src/test/scala/cats/tests/MonadErrorTest.scala @@ -1,7 +1,11 @@ package cats package tests -class MonadErrorSuite extends CatsSuite { +import data.StateT + +import scala.language.postfixOps + +class MonadErrorTest extends CatsSuite { val successful: Option[Int] = 42.some val failed: Option[Int] = None @@ -32,5 +36,31 @@ class MonadErrorSuite extends CatsSuite { failed.ensureOr(_ => ())(_ => true) should === (failed) } + { + import StateTTests._ + + type Test[A] = StateT[Either[Boolean, ?], Int, A] + + val successful: Test[String] = StateT.modify[Either[Boolean, ?], Int](1 +).map(_ => "foo") + val failed1: Test[String] = StateT.lift(Left(true)) + val failed2: Test[String] = StateT.lift(Left(false)) + val finalizer: Test[Unit] = StateT.modify[Either[Boolean, ?], Int](10 *) + test("guarantee returns successful") { + successful.guarantee(finalizer) should === (successful.flatMap(a => finalizer.map(_ => a))) + } + + test("guarantee runs finalizer on fail") { + val expected: Test[String] = for { + _ <- StateT.modify[Either[Boolean, ?], Int](10 *) + _ <- failed1 + } yield "foo" + + failed1.guarantee(finalizer) should === (expected) + } + + test("guarantee returns inner errors on double failure") { + failed1.guarantee(failed2 >> ().pure[Test]) should === (failed2) + } + } }