diff --git a/core/src/main/scala/cats/ApplicativeError.scala b/core/src/main/scala/cats/ApplicativeError.scala index 61ab4b6604..ff2fd40a46 100644 --- a/core/src/main/scala/cats/ApplicativeError.scala +++ b/core/src/main/scala/cats/ApplicativeError.scala @@ -76,6 +76,41 @@ trait ApplicativeError[F[_], E] extends Applicative[F] { def recoverWith[A](fa: F[A])(pf: PartialFunction[E, F[A]]): F[A] = handleErrorWith(fa)(e => pf applyOrElse(e, raiseError)) + + /** + * Execute a callback on certain errors, then rethrow them. + * Any non matching error is rethrown as well. + * + * In the following example, only one of the errors is logged, + * but they are both rethrown, to be possibly handled by another + * layer of the program: + * + * {{{ + * scala> import cats._, data._, implicits._ + * + * scala> case class Err(msg: String) + * + * scala> type F[A] = EitherT[State[String, ?], Err, A] + * + * scala> val action: PartialFunction[Err, F[Unit]] = { + * | case Err("one") => EitherT.liftT(State.set("one")) + * | } + * + * scala> val prog1: F[Int] = (Err("one")).raiseError[F, Int] + * scala> val prog2: F[Int] = (Err("two")).raiseError[F, Int] + * + * scala> prog1.onError(action).value.run("").value + + * res0: (String, Either[Err,Int]) = (one,Left(Err(one))) + * + * scala> prog2.onError(action).value.run("").value + * res1: (String, Either[Err,Int]) = ("",Left(Err(two))) + * }}} + */ + def onError[A](fa: F[A])(pf: PartialFunction[E, F[Unit]]): F[A] = + handleErrorWith(fa)(e => + (pf andThen (map2(_, raiseError[A](e))((_, b) => b))) applyOrElse(e, raiseError)) + /** * Often E is Throwable. Here we try to call pure or catch * and raise. diff --git a/core/src/main/scala/cats/MonadError.scala b/core/src/main/scala/cats/MonadError.scala index e0368bce30..12348002c9 100644 --- a/core/src/main/scala/cats/MonadError.scala +++ b/core/src/main/scala/cats/MonadError.scala @@ -19,6 +19,28 @@ trait MonadError[F[_], E] extends ApplicativeError[F, E] with Monad[F] { def ensureOr[A](fa: F[A])(error: A => E)(predicate: A => Boolean): F[A] = flatMap(fa)(a => if (predicate(a)) pure(a) else raiseError(error(a))) + /** + * Transform certain errors using `pf` and rethrow them. + * Non matching errors and successful values are not affected by this function. + * + * Example: + * {{{ + * scala> import cats._, implicits._ + * + * scala> def pf: PartialFunction[String, String] = { case "error" => "ERROR" } + * + * scala> "error".asLeft[Int].adaptError(pf) + * res0: Either[String,Int] = Left(ERROR) + * + * scala> "err".asLeft[Int].adaptError(pf) + * res1: Either[String,Int] = Left(err) + * + * scala> 1.asRight[String].adaptError(pf) + * res2: Either[String,Int] = Right(1) + * }}} + */ + def adaptError[A](fa: F[A])(pf: PartialFunction[E, E]): F[A] = + flatMap(attempt(fa))(_.fold(e => raiseError(pf.applyOrElse[E, E](e, _ => e)), pure)) } object MonadError { diff --git a/core/src/main/scala/cats/syntax/applicativeError.scala b/core/src/main/scala/cats/syntax/applicativeError.scala index 166adddab3..b2b63798eb 100644 --- a/core/src/main/scala/cats/syntax/applicativeError.scala +++ b/core/src/main/scala/cats/syntax/applicativeError.scala @@ -34,4 +34,7 @@ final class ApplicativeErrorOps[F[_], E, A](val fa: F[A]) extends AnyVal { def recoverWith(pf: PartialFunction[E, F[A]])(implicit F: ApplicativeError[F, E]): F[A] = F.recoverWith(fa)(pf) + + def onError(pf: PartialFunction[E, F[Unit]])(implicit F: ApplicativeError[F, E]): F[A] = + F.onError(fa)(pf) } diff --git a/core/src/main/scala/cats/syntax/monadError.scala b/core/src/main/scala/cats/syntax/monadError.scala index 00a9381efa..95a4f07d53 100644 --- a/core/src/main/scala/cats/syntax/monadError.scala +++ b/core/src/main/scala/cats/syntax/monadError.scala @@ -12,4 +12,7 @@ final class MonadErrorOps[F[_], E, A](val fa: F[A]) extends AnyVal { def ensureOr(error: A => E)(predicate: A => Boolean)(implicit F: MonadError[F, E]): F[A] = F.ensureOr(fa)(error)(predicate) + + def adaptError(pf: PartialFunction[E, E])(implicit F: MonadError[F, E]): F[A] = + F.adaptError(fa)(pf) } diff --git a/laws/src/main/scala/cats/laws/ApplicativeErrorLaws.scala b/laws/src/main/scala/cats/laws/ApplicativeErrorLaws.scala index 460aa7eadd..e2a42706f9 100644 --- a/laws/src/main/scala/cats/laws/ApplicativeErrorLaws.scala +++ b/laws/src/main/scala/cats/laws/ApplicativeErrorLaws.scala @@ -39,6 +39,12 @@ trait ApplicativeErrorLaws[F[_], E] extends ApplicativeLaws[F] { def attemptFromEitherConsistentWithPure[A](eab: Either[E, A]): IsEq[F[Either[E, A]]] = F.attempt(F.fromEither(eab)) <-> F.pure(eab) + + def onErrorPure[A](a: A, f: E => F[Unit]): IsEq[F[A]] = + F.onError(F.pure(a))(PartialFunction(f)) <-> F.pure(a) + + def onErrorRaise[A](fa: F[A], e: E, fb: F[Unit]): IsEq[F[A]] = + F.onError(F.raiseError[A](e)){case err => fb} <-> F.map2(fb, F.raiseError[A](e))((_, b) => b) } object ApplicativeErrorLaws { diff --git a/laws/src/main/scala/cats/laws/MonadErrorLaws.scala b/laws/src/main/scala/cats/laws/MonadErrorLaws.scala index 17a30554c0..763003a2bb 100644 --- a/laws/src/main/scala/cats/laws/MonadErrorLaws.scala +++ b/laws/src/main/scala/cats/laws/MonadErrorLaws.scala @@ -13,6 +13,12 @@ trait MonadErrorLaws[F[_], E] extends ApplicativeErrorLaws[F, E] with MonadLaws[ def monadErrorEnsureOrConsistency[A](fa: F[A], e: A => E, p: A => Boolean): IsEq[F[A]] = F.ensureOr(fa)(e)(p) <-> F.flatMap(fa)(a => if (p(a)) F.pure(a) else F.raiseError(e(a))) + + def adaptErrorPure[A](a: A, f: E => E): IsEq[F[A]] = + F.adaptError(F.pure(a))(PartialFunction(f)) <-> F.pure(a) + + def adaptErrorRaise[A](e: E, f: E => E): IsEq[F[A]] = + F.adaptError(F.raiseError[A](e))(PartialFunction(f)) <-> F.raiseError(f(e)) } object MonadErrorLaws { diff --git a/laws/src/main/scala/cats/laws/discipline/ApplicativeErrorTests.scala b/laws/src/main/scala/cats/laws/discipline/ApplicativeErrorTests.scala index 0c2ecbf1d6..8e52d7acaf 100644 --- a/laws/src/main/scala/cats/laws/discipline/ApplicativeErrorTests.scala +++ b/laws/src/main/scala/cats/laws/discipline/ApplicativeErrorTests.scala @@ -15,6 +15,7 @@ trait ApplicativeErrorTests[F[_], E] extends ApplicativeTests[F] { ArbFA: Arbitrary[F[A]], ArbFB: Arbitrary[F[B]], ArbFC: Arbitrary[F[C]], + ArbFU: Arbitrary[F[Unit]], ArbFAtoB: Arbitrary[F[A => B]], ArbFBtoC: Arbitrary[F[B => C]], ArbE: Arbitrary[E], @@ -47,7 +48,9 @@ trait ApplicativeErrorTests[F[_], E] extends ApplicativeTests[F] { "applicativeError handleError consistent with recover" -> forAll(laws.handleErrorConsistentWithRecover[A] _), "applicativeError recover consistent with recoverWith" -> forAll(laws.recoverConsistentWithRecoverWith[A] _), "applicativeError attempt consistent with attemptT" -> forAll(laws.attemptConsistentWithAttemptT[A] _), - "applicativeError attempt fromEither consistent with pure" -> forAll(laws.attemptFromEitherConsistentWithPure[A] _) + "applicativeError attempt fromEither consistent with pure" -> forAll(laws.attemptFromEitherConsistentWithPure[A] _), + "applicativeError onError pure" -> forAll(laws.onErrorPure[A] _), + "applicativeError onError raise" -> forAll(laws.onErrorRaise[A] _) ) } } diff --git a/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala b/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala index 2784a4ed58..c6a6a4d5c8 100644 --- a/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala +++ b/laws/src/main/scala/cats/laws/discipline/MonadErrorTests.scala @@ -14,6 +14,7 @@ trait MonadErrorTests[F[_], E] extends ApplicativeErrorTests[F, E] with MonadTes ArbFA: Arbitrary[F[A]], ArbFB: Arbitrary[F[B]], ArbFC: Arbitrary[F[C]], + ArbFU: Arbitrary[F[Unit]], ArbFAtoB: Arbitrary[F[A => B]], ArbFBtoC: Arbitrary[F[B => C]], ArbE: Arbitrary[E], @@ -39,7 +40,9 @@ trait MonadErrorTests[F[_], E] extends ApplicativeErrorTests[F, E] with MonadTes def props: Seq[(String, Prop)] = Seq( "monadError left zero" -> forAll(laws.monadErrorLeftZero[A, B] _), "monadError ensure consistency" -> forAll(laws.monadErrorEnsureConsistency[A] _), - "monadError ensureOr consistency" -> forAll(laws.monadErrorEnsureOrConsistency[A] _) + "monadError ensureOr consistency" -> forAll(laws.monadErrorEnsureOrConsistency[A] _), + "monadError adaptError pure" -> forAll(laws.adaptErrorPure[A] _), + "monadError adaptError raise" -> forAll(laws.adaptErrorRaise[A] _) ) } }