From 586ea08dfcd476322f8a2431aae3c88b7c070ca6 Mon Sep 17 00:00:00 2001 From: Itamar Ravid Date: Sun, 14 May 2017 10:06:49 +0300 Subject: [PATCH] Implement a ReaderWriterStateT data type (#1598) * Implement a ReaderWriterStateT data type Resolves #1597. --- .../scala/cats/data/ReaderWriterStateT.scala | 554 ++++++++++++++++++ core/src/main/scala/cats/data/package.scala | 9 + .../main/scala/cats/laws/discipline/Eq.scala | 16 + .../cats/tests/ReaderWriterStateTTests.scala | 367 ++++++++++++ 4 files changed, 946 insertions(+) create mode 100644 core/src/main/scala/cats/data/ReaderWriterStateT.scala create mode 100644 tests/src/test/scala/cats/tests/ReaderWriterStateTTests.scala diff --git a/core/src/main/scala/cats/data/ReaderWriterStateT.scala b/core/src/main/scala/cats/data/ReaderWriterStateT.scala new file mode 100644 index 0000000000..ac04862f8c --- /dev/null +++ b/core/src/main/scala/cats/data/ReaderWriterStateT.scala @@ -0,0 +1,554 @@ +package cats +package data + +import cats.functor.{ Contravariant, Bifunctor, Profunctor } +import cats.syntax.either._ + +/** + * Represents a stateful computation in a context `F[_]`, over state `S`, with an initial environment `E`, + * an accumulated log `L` and a result `A`. + * + * In other words, it is a pre-baked stack of `[[ReaderT]][F, E, A]`, `[[WriterT]][F, L, A]` + * and `[[StateT]][F, S, A]`. + */ +final class ReaderWriterStateT[F[_], E, S, L, A](val runF: F[(E, S) => F[(L, S, A)]]) extends Serializable { + + /** + * Modify the initial environment using `f`. + */ + def contramap[E0](f: E0 => E)(implicit F: Functor[F]): ReaderWriterStateT[F, E0, S, L, A] = + ReaderWriterStateT.applyF { + F.map(runF) { rwsa => + (e0: E0, s: S) => rwsa(f(e0), s) + } + } + + /** + * Alias for [[contramap]]. + */ + def local[EE](f: EE => E)(implicit F: Functor[F]): ReaderWriterStateT[F, EE, S, L, A] = + contramap(f) + + /** + * Modify the result of the computation using `f`. + */ + def map[B](f: A => B)(implicit F: Functor[F]): ReaderWriterStateT[F, E, S, L, B] = + transform { (l, s, a) => (l, s, f(a)) } + + /** + * Modify the written log value using `f`. + */ + def mapWritten[LL](f: L => LL)(implicit F: Functor[F]): ReaderWriterStateT[F, E, S, LL, A] = + transform { (l, s, a) => (f(l), s, a) } + + /** + * Combine this computation with `rwsb` using `fn`. The state will be be threaded + * through the computations and the log values will be combined. + */ + def map2[B, Z](rwsb: ReaderWriterStateT[F, E, S, L, B])(fn: (A, B) => Z)( + implicit F: FlatMap[F], L: Semigroup[L]): ReaderWriterStateT[F, E, S, L, Z] = + flatMap { a => + rwsb.map { b => + fn(a, b) + } + } + + /** + * Modify the result of the computation by feeding it into `f`, threading the state + * through the resulting computation and combining the log values. + */ + def flatMap[B](f: A => ReaderWriterStateT[F, E, S, L, B])( + implicit F: FlatMap[F], L: Semigroup[L]): ReaderWriterStateT[F, E, S, L, B] = + ReaderWriterStateT.applyF { + F.map(runF) { rwsfa => + (e: E, s0: S) => + F.flatMap(rwsfa(e, s0)) { case (la, sa, a) => + F.flatMap(f(a).runF) { rwsfb => + F.map(rwsfb(e, sa)) { case (lb, sb, b) => + (L.combine(la, lb), sb, b) + } + } + } + } + } + + /** + * Like [[map]], but allows the mapping function to return an effectful value. + */ + def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): ReaderWriterStateT[F, E, S, L, B] = + ReaderWriterStateT.applyF { + F.map(runF) { rwsfa => + (e: E, s: S) => + F.flatMap(rwsfa(e, s)) { case (l, s, a) => + F.map(faf(a))((l, s, _)) + } + } + } + + /** + * Transform the resulting log, state and value using `f`. + */ + def transform[LL, B](f: (L, S, A) => (LL, S, B))(implicit F: Functor[F]): ReaderWriterStateT[F, E, S, LL, B] = + ReaderWriterStateT.applyF { + F.map(runF) { rwsfa => + (e: E, s: S) => F.map(rwsfa(e, s)) { case (l, s, a) => + val (ll, sb, b) = f(l, s, a) + (ll, sb, b) + } + } + } + + /** + * Like [[transform]], but allows the context to change from `F` to `G`. + */ + def transformF[G[_], LL, B](f: F[(L, S, A)] => G[(LL, S, B)])( + implicit F: Monad[F], G: Applicative[G]): ReaderWriterStateT[G, E, S, LL, B] = + ReaderWriterStateT.apply((e, s) => f(run(e, s))) + + + /** + * Modify the resulting state. + */ + def modify(f: S => S)(implicit F: Functor[F]): ReaderWriterStateT[F, E, S, L, A] = + transform { (l, s, a) => (l, f(s), a) } + + /** + * Inspect a value from the input state, without modifying the state. + */ + def inspect[B](f: S => B)(implicit F: Functor[F]): ReaderWriterStateT[F, E, S, L, B] = + transform { (l, s, a) => (l, s, f(s)) } + + /** + * Get the input state, without modifying it. + */ + def get(implicit F: Functor[F]): ReaderWriterStateT[F, E, S, L, S] = + inspect(identity) + + /** + * Add a value to the log. + */ + def tell(l: L)(implicit F: Functor[F], L: Semigroup[L]): ReaderWriterStateT[F, E, S, L, A] = + mapWritten(L.combine(_, l)) + + /** + * Retrieve the value written to the log. + */ + def written(implicit F: Functor[F]): ReaderWriterStateT[F, E, S, L, L] = + transform { (l, s, a) => (l, s, l) } + + /** + * Clear the log. + */ + def reset(implicit F: Functor[F], L: Monoid[L]): ReaderWriterStateT[F, E, S, L, A] = + mapWritten(_ => L.empty) + + /** + * Run the computation using the provided initial environment and state. + */ + def run(env: E, initial: S)(implicit F: Monad[F]): F[(L, S, A)] = + F.flatMap(runF)(_.apply(env, initial)) + + /** + * Run the computation using the provided environment and an empty state. + */ + def runEmpty(env: E)(implicit F: Monad[F], S: Monoid[S]): F[(L, S, A)] = + run(env, S.empty) + + /** + * Like [[run]], but discards the final state and log. + */ + def runA(env: E, initial: S)(implicit F: Monad[F]): F[A] = + F.map(run(env, initial))(_._3) + + /** + * Like [[run]], but discards the final value and log. + */ + def runS(env: E, initial: S)(implicit F: Monad[F]): F[S] = + F.map(run(env, initial))(_._2) + + /** + * Like [[run]], but discards the final state and value. + */ + def runL(env: E, initial: S)(implicit F: Monad[F]): F[L] = + F.map(run(env, initial))(_._1) + + /** + * Like [[runEmpty]], but discards the final state and log. + */ + def runEmptyA(env: E)(implicit F: Monad[F], S: Monoid[S]): F[A] = + runA(env, S.empty) + + /** + * Like [[runEmpty]], but discards the final value and log. + */ + def runEmptyS(env: E)(implicit F: Monad[F], S: Monoid[S]): F[S] = + runS(env, S.empty) + + /** + * Like [[runEmpty]], but discards the final state and value. + */ + def runEmptyL(env: E)(implicit F: Monad[F], S: Monoid[S]): F[L] = + runL(env, S.empty) +} + +object ReaderWriterStateT extends RWSTInstances { + /** + * Construct a new computation using the provided function. + */ + def apply[F[_], E, S, L, A](runF: (E, S) => F[(L, S, A)])(implicit F: Applicative[F]): ReaderWriterStateT[F, E, S, L, A] = + new ReaderWriterStateT(F.pure(runF)) + + /** + * Like [[apply]], but using a function in a context `F`. + */ + def applyF[F[_], E, S, L, A](runF: F[(E, S) => F[(L, S, A)]]): ReaderWriterStateT[F, E, S, L, A] = + new ReaderWriterStateT(runF) + + /** + * Return `a` and an empty log without modifying the input state. + */ + def pure[F[_], E, S, L, A](a: A)(implicit F: Applicative[F], L: Monoid[L]): ReaderWriterStateT[F, E, S, L, A] = + ReaderWriterStateT((_, s) => F.pure((L.empty, s, a))) + + /** + * Return an effectful `a` and an empty log without modifying the input state. + */ + def lift[F[_], E, S, L, A](fa: F[A])(implicit F: Applicative[F], L: Monoid[L]): ReaderWriterStateT[F, E, S, L, A] = + ReaderWriterStateT((_, s) => F.map(fa)((L.empty, s, _))) + + /** + * Inspect a value from the input state, without modifying the state. + */ + def inspect[F[_], E, S, L, A](f: S => A)(implicit F: Applicative[F], L: Monoid[L]): ReaderWriterStateT[F, E, S, L, A] = + ReaderWriterStateT((_, s) => F.pure((L.empty, s, f(s)))) + + /** + * Like [[inspect]], but using an effectful function. + */ + def inspectF[F[_], E, S, L, A](f: S => F[A])(implicit F: Applicative[F], L: Monoid[L]): ReaderWriterStateT[F, E, S, L, A] = + ReaderWriterStateT((_, s) => F.map(f(s))((L.empty, s, _))) + + /** + * Modify the input state using `f`. + */ + def modify[F[_], E, S, L](f: S => S)(implicit F: Applicative[F], L: Monoid[L]): ReaderWriterStateT[F, E, S, L, Unit] = + ReaderWriterStateT((_, s) => F.pure((L.empty, f(s), ()))) + + /** + * Like [[modify]], but using an effectful function. + */ + def modifyF[F[_], E, S, L](f: S => F[S])(implicit F: Applicative[F], L: Monoid[L]): ReaderWriterStateT[F, E, S, L, Unit] = + ReaderWriterStateT((_, s) => F.map(f(s))((L.empty, _, ()))) + + /** + * Return the input state without modifying it. + */ + def get[F[_], E, S, L](implicit F: Applicative[F], L: Monoid[L]): ReaderWriterStateT[F, E, S, L, S] = + ReaderWriterStateT((_, s) => F.pure((L.empty, s, s))) + + /** + * Set the state to `s`. + */ + def set[F[_], E, S, L](s: S)(implicit F: Applicative[F], L: Monoid[L]): ReaderWriterStateT[F, E, S, L, Unit] = + ReaderWriterStateT((_, _) => F.pure((L.empty, s, ()))) + + /** + * Like [[set]], but using an effectful `S` value. + */ + def setF[F[_], E, S, L](fs: F[S])(implicit F: Applicative[F], L: Monoid[L]): ReaderWriterStateT[F, E, S, L, Unit] = + ReaderWriterStateT((_, _) => F.map(fs)((L.empty, _, ()))) + + /** + * Get the provided environment, without modifying the input state. + */ + def ask[F[_], E, S, L](implicit F: Applicative[F], L: Monoid[L]): ReaderWriterStateT[F, E, S, L, E] = + ReaderWriterStateT((e, s) => F.pure((L.empty, s, e))) + + /** + * Add a value to the log, without modifying the input state. + */ + def tell[F[_], E, S, L](l: L)(implicit F: Applicative[F]): ReaderWriterStateT[F, E, S, L, Unit] = + ReaderWriterStateT((_, s) => F.pure((l, s, ()))) + + /** + * Like [[tell]], but using an effectful `L` value. + */ + def tellF[F[_], E, S, L](fl: F[L])(implicit F: Applicative[F]): ReaderWriterStateT[F, E, S, L, Unit] = + ReaderWriterStateT((_, s) => F.map(fl)((_, s, ()))) +} + +/** + * Convenience functions for ReaderWriterState. + */ +private[data] abstract class RWSFunctions { + /** + * Return `a` and an empty log without modifying the input state. + */ + def apply[E, S, L: Monoid, A](f: (E, S) => (L, S, A)): ReaderWriterState[E, S, L, A] = + ReaderWriterStateT.applyF(Now((e, s) => Now(f(e, s)))) + + /** + * Return `a` and an empty log without modifying the input state. + */ + def pure[E, S, L: Monoid, A](a: A): ReaderWriterState[E, S, L, A] = + ReaderWriterStateT.pure(a) + + /** + * Modify the input state using `f`. + */ + def modify[E, S, L: Monoid](f: S => S): ReaderWriterState[E, S, L, Unit] = + ReaderWriterStateT.modify(f) + + /** + * Inspect a value from the input state, without modifying the state. + */ + def inspect[E, S, L: Monoid, T](f: S => T): ReaderWriterState[E, S, L, T] = + ReaderWriterStateT.inspect(f) + + /** + * Return the input state without modifying it. + */ + def get[E, S, L: Monoid]: ReaderWriterState[E, S, L, S] = + ReaderWriterStateT.get + + /** + * Set the state to `s`. + */ + def set[E, S, L: Monoid](s: S): ReaderWriterState[E, S, L, Unit] = + ReaderWriterStateT.set(s) + + /** + * Get the provided environment, without modifying the input state. + */ + def ask[E, S, L](implicit L: Monoid[L]): ReaderWriterState[E, S, L, E] = + ReaderWriterStateT.ask + + /** + * Add a value to the log, without modifying the input state. + */ + def tell[E, S, L](l: L): ReaderWriterState[E, S, L, Unit] = + ReaderWriterStateT.tell(l) +} + +private[data] sealed trait RWSTInstances extends RWSTInstances1 { + implicit def catsDataMonadStateForRWST[F[_], E, S, L]( + implicit F0: Monad[F], L0: Monoid[L]): MonadState[ReaderWriterStateT[F, E, S, L, ?], S] = + new RWSTMonadState[F, E, S, L] { + implicit def F: Monad[F] = F0 + implicit def L: Monoid[L] = L0 + } + + implicit def catsDataLiftForRWST[E, S, L]( + implicit L0: Monoid[L]): TransLift.Aux[ReaderWriterStateT[?[_], E, S, L, ?], Applicative] = + new RWSTTransLift[E, S, L] { + implicit def L: Monoid[L] = L0 + } +} + +private[data] sealed trait RWSTInstances1 extends RWSTInstances2 { + implicit def catsDataMonadCombineForRWST[F[_], E, S, L]( + implicit F0: MonadCombine[F], L0: Monoid[L]): MonadCombine[ReaderWriterStateT[F, E, S, L, ?]] = + new RWSTMonadCombine[F, E, S, L] { + implicit def F: MonadCombine[F] = F0 + implicit def L: Monoid[L] = L0 + } +} + +private[data] sealed trait RWSTInstances2 extends RWSTInstances3 { + implicit def catsDataMonadErrorForRWST[F[_], E, S, L, R]( + implicit F0: MonadError[F, R], L0: Monoid[L]): MonadError[ReaderWriterStateT[F, E, S, L, ?], R] = + new RWSTMonadError[F, E, S, L, R] { + implicit def F: MonadError[F, R] = F0 + implicit def L: Monoid[L] = L0 + } + + implicit def catsDataSemigroupKForRWST[F[_], E, S, L]( + implicit F0: Monad[F], G0: SemigroupK[F]): SemigroupK[ReaderWriterStateT[F, E, S, L, ?]] = + new RWSTSemigroupK[F, E, S, L] { + implicit def F: Monad[F] = F0 + implicit def G: SemigroupK[F] = G0 + } +} + +private[data] sealed trait RWSTInstances3 extends RWSTInstances4 { + implicit def catsDataMonadReaderForRWST[F[_], E, S, L]( + implicit F0: Monad[F], L0: Monoid[L]): MonadReader[ReaderWriterStateT[F, E, S, L, ?], E] = + new RWSTMonadReader[F, E, S, L] { + implicit def F: Monad[F] = F0 + implicit def L: Monoid[L] = L0 + } +} + +private[data] sealed trait RWSTInstances4 extends RWSTInstances5 { + implicit def catsDataMonadWriterForRWST[F[_], E, S, L]( + implicit F0: Monad[F], L0: Monoid[L]): MonadWriter[ReaderWriterStateT[F, E, S, L, ?], L] = + new RWSTMonadWriter[F, E, S, L] { + implicit def F: Monad[F] = F0 + implicit def L: Monoid[L] = L0 + } +} + +private[data] sealed trait RWSTInstances5 extends RWSTInstances6 { + implicit def catsDataMonadForRWST[F[_], E, S, L](implicit F0: Monad[F], L0: Monoid[L]): Monad[ReaderWriterStateT[F, E, S, L, ?]] = + new RWSTMonad[F, E, S, L] { + implicit def F: Monad[F] = F0 + implicit def L: Monoid[L] = L0 + } +} + +private[data] sealed trait RWSTInstances6 extends RWSTInstances7 { + implicit def catsDataFunctorForRWST[F[_], E, S, L](implicit F0: Functor[F]): Functor[ReaderWriterStateT[F, E, S, L, ?]] = + new RWSTFunctor[F, E, S, L] { + implicit def F: Functor[F] = F0 + } +} + +private[data] sealed trait RWSTInstances7 extends RWSTInstances8 { + implicit def catsDataContravariantForRWST[F[_], S, L, A](implicit F0: Functor[F]): Contravariant[ReaderWriterStateT[F, ?, S, L, A]] = + new RWSTContravariant[F, S, L, A] { + implicit def F: Functor[F] = F0 + } +} + +private[data] sealed trait RWSTInstances8 extends RWSTInstances9 { + implicit def catsDataBifunctorForRWST[F[_], E, S](implicit F0: Functor[F]): Bifunctor[ReaderWriterStateT[F, E, S, ?, ?]] = + new RWSTBifunctor[F, E, S] { + implicit def F: Functor[F] = F0 + } +} + +private[data] sealed trait RWSTInstances9 { + implicit def catsDataProfunctorForRWST[F[_], S, L](implicit F0: Functor[F]): Profunctor[ReaderWriterStateT[F, ?, S, L, ?]] = + new RWSTProfunctor[F, S, L] { + implicit def F: Functor[F] = F0 + } +} + +private[data] sealed trait RWSTFunctor[F[_], E, S, L] extends Functor[ReaderWriterStateT[F, E, S, L, ?]] { + implicit def F: Functor[F] + + def map[A, B](fa: ReaderWriterStateT[F, E, S, L, A])(f: A => B): ReaderWriterStateT[F, E, S, L, B] = + fa.map(f) +} + +private[data] sealed trait RWSTContravariant[F[_], S, L, T] extends Contravariant[ReaderWriterStateT[F, ?, S, L, T]] { + implicit def F: Functor[F] + + override def contramap[A, B](fa: ReaderWriterStateT[F, A, S, L, T])(f: B => A): ReaderWriterStateT[F, B, S, L, T] = + fa.contramap(f) +} + +private[data] sealed trait RWSTBifunctor[F[_], E, S] extends Bifunctor[ReaderWriterStateT[F, E, S, ?, ?]] { + implicit def F: Functor[F] + + override def bimap[A, B, C, D](fab: ReaderWriterStateT[F, E, S, A, B])( + f: A => C, g: B => D): ReaderWriterStateT[F, E, S, C, D] = fab.mapWritten(f).map(g) +} + +private[data] sealed trait RWSTProfunctor[F[_], S, L] extends Profunctor[ReaderWriterStateT[F, ?, S, L, ?]] { + implicit def F: Functor[F] + + override def dimap[A, B, C, D](fab: ReaderWriterStateT[F, A, S, L, B])(f: C => A)(g: B => D): ReaderWriterStateT[F, C, S, L, D] = + fab.contramap(f).map(g) +} + +private[data] sealed trait RWSTMonad[F[_], E, S, L] extends Monad[ReaderWriterStateT[F, E, S, L, ?]] { + implicit def F: Monad[F] + implicit def L: Monoid[L] + + def pure[A](a: A): ReaderWriterStateT[F, E, S, L, A] = + ReaderWriterStateT.pure(a) + + def flatMap[A, B](fa: ReaderWriterStateT[F, E, S, L, A])(f: A => ReaderWriterStateT[F, E, S, L, B]): ReaderWriterStateT[F, E, S, L, B] = + fa.flatMap(f) + + def tailRecM[A, B](initA: A)(f: A => ReaderWriterStateT[F, E, S, L, Either[A, B]]): ReaderWriterStateT[F, E, S, L, B] = + ReaderWriterStateT { (e, initS) => + F.tailRecM((L.empty, initS, initA)) { case (currL, currS, currA) => + F.map(f(currA).run(e, currS)) { case (nextL, nextS, ab) => + ab.bimap((L.combine(currL, nextL), nextS, _), (L.combine(currL, nextL), nextS, _)) + } + } + } + + override def map[A, B](fa: ReaderWriterStateT[F, E, S, L, A])(f: A => B): ReaderWriterStateT[F, E, S, L, B] = + fa.map(f) + + override def map2[A, B, Z](fa: ReaderWriterStateT[F, E, S, L, A], + fb: ReaderWriterStateT[F, E, S, L, B])(f: (A, B) => Z): ReaderWriterStateT[F, E, S, L, Z] = + fa.map2(fb)(f) +} + +private[data] sealed trait RWSTMonadState[F[_], E, S, L] + extends MonadState[ReaderWriterStateT[F, E, S, L, ?], S] with RWSTMonad[F, E, S, L] { + + lazy val get: ReaderWriterStateT[F, E, S, L, S] = ReaderWriterStateT.get + + def set(s: S): ReaderWriterStateT[F, E, S, L, Unit] = ReaderWriterStateT.set(s) +} + +private[data] sealed trait RWSTTransLift[E, S, L] extends TransLift[ReaderWriterStateT[?[_], E, S, L, ?]] { + implicit def L: Monoid[L] + type TC[F[_]] = Applicative[F] + + def liftT[F[_], A](fa: F[A])(implicit F: Applicative[F]): ReaderWriterStateT[F, E, S, L, A] = + ReaderWriterStateT.lift(fa) +} + +private[data] sealed trait RWSTSemigroupK[F[_], E, S, L] extends SemigroupK[ReaderWriterStateT[F, E, S, L, ?]] { + implicit def F: Monad[F] + implicit def G: SemigroupK[F] + + def combineK[A](x: ReaderWriterStateT[F, E, S, L, A], y: ReaderWriterStateT[F, E, S, L, A]): ReaderWriterStateT[F, E, S, L, A] = + ReaderWriterStateT { (e, s) => + G.combineK(x.run(e, s), y.run(e, s)) + } +} + +private[data] sealed trait RWSTMonadCombine[F[_], E, S, L] + extends MonadCombine[ReaderWriterStateT[F, E, S, L, ?]] with RWSTMonad[F, E, S, L] + with RWSTSemigroupK[F, E, S, L] with RWSTTransLift[E, S, L] { + + implicit def F: MonadCombine[F] + override def G: MonadCombine[F] = F + + def empty[A]: ReaderWriterStateT[F, E, S, L, A] = liftT[F, A](F.empty[A]) +} + +private[data] sealed trait RWSTMonadError[F[_], E, S, L, R] + extends RWSTMonad[F, E, S, L] with MonadError[ReaderWriterStateT[F, E, S, L, ?], R] { + + implicit def F: MonadError[F, R] + + def raiseError[A](r: R): ReaderWriterStateT[F, E, S, L, A] = ReaderWriterStateT.lift(F.raiseError(r)) + + def handleErrorWith[A](fa: ReaderWriterStateT[F, E, S, L, A])(f: R => ReaderWriterStateT[F, E, S, L, A]): ReaderWriterStateT[F, E, S, L, A] = + ReaderWriterStateT { (e, s) => + F.handleErrorWith(fa.run(e, s))(r => f(r).run(e, s)) + } +} + +private[data] sealed trait RWSTMonadReader[F[_], E, S, L] + extends RWSTMonad[F, E, S, L] with MonadReader[ReaderWriterStateT[F, E, S, L, ?], E] { + + val ask: ReaderWriterStateT[F, E, S, L, E] = ReaderWriterStateT.ask + + def local[A](f: E => E)(fa: ReaderWriterStateT[F, E, S, L, A]): ReaderWriterStateT[F, E, S, L, A] = fa contramap f +} + +private[data] sealed trait RWSTMonadWriter[F[_], E, S, L] + extends RWSTMonad[F, E, S, L] with MonadWriter[ReaderWriterStateT[F, E, S, L, ?], L] { + + def writer[A](aw: (L, A)): ReaderWriterStateT[F, E, S, L, A] = + ReaderWriterStateT((_, s) => F.pure((aw._1, s, aw._2))) + + def listen[A](fa: ReaderWriterStateT[F, E, S, L, A]): ReaderWriterStateT[F, E, S, L, (L, A)] = + fa.transform { (l, s, a) => + (l, s, (l, a)) + } + + def pass[A](fa: ReaderWriterStateT[F, E, S, L, (L => L, A)]): ReaderWriterStateT[F, E, S, L, A] = + fa.transform { case (l, s, (fl, a)) => + (fl(l), s, a) + } +} diff --git a/core/src/main/scala/cats/data/package.scala b/core/src/main/scala/cats/data/package.scala index d9a348a0d2..f3a9ec9c5d 100644 --- a/core/src/main/scala/cats/data/package.scala +++ b/core/src/main/scala/cats/data/package.scala @@ -30,4 +30,13 @@ package object data { type State[S, A] = StateT[Eval, S, A] object State extends StateFunctions + + type RWST[F[_], E, S, L, A] = ReaderWriterStateT[F, E, S, L, A] + val RWST = ReaderWriterStateT + + type ReaderWriterState[E, S, L, A] = ReaderWriterStateT[Eval, E, S, L, A] + object ReaderWriterState extends RWSFunctions + + type RWS[E, S, L, A] = ReaderWriterState[E, S, L, A] + val RWS = ReaderWriterState } diff --git a/laws/src/main/scala/cats/laws/discipline/Eq.scala b/laws/src/main/scala/cats/laws/discipline/Eq.scala index 113a393b1f..dabd41ce6a 100644 --- a/laws/src/main/scala/cats/laws/discipline/Eq.scala +++ b/laws/src/main/scala/cats/laws/discipline/Eq.scala @@ -24,6 +24,22 @@ object eq { } } + /** + * Create an approximation of Eq[(A, B) => C] by generating 100 values for A and B + * and comparing the application of the two functions. + */ + implicit def catsLawsEqForFn2[A, B, C](implicit A: Arbitrary[A], B: Arbitrary[B], C: Eq[C]): Eq[(A, B) => C] = new Eq[(A, B) => C] { + val sampleCnt: Int = if (Platform.isJvm) 50 else 5 + + def eqv(f: (A, B) => C, g: (A, B) => C): Boolean = { + val samples = List.fill(sampleCnt)((A.arbitrary.sample, B.arbitrary.sample)).collect{ + case (Some(a), Some(b)) => (a, b) + case _ => sys.error("Could not generate arbitrary values to compare two functions") + } + samples.forall { case (a, b) => C.eqv(f(a, b), g(a, b)) } + } + } + /** Create an approximation of Eq[Show[A]] by using catsLawsEqForFn1[A, String] */ implicit def catsLawsEqForShow[A: Arbitrary]: Eq[Show[A]] = { Eq.by[Show[A], A => String] { showInstance => diff --git a/tests/src/test/scala/cats/tests/ReaderWriterStateTTests.scala b/tests/src/test/scala/cats/tests/ReaderWriterStateTTests.scala new file mode 100644 index 0000000000..3239ba3a0e --- /dev/null +++ b/tests/src/test/scala/cats/tests/ReaderWriterStateTTests.scala @@ -0,0 +1,367 @@ +package cats +package tests + +import cats.data.{ ReaderWriterStateT, ReaderWriterState, EitherT } +import cats.laws.discipline._ +import cats.laws.discipline.eq._ +import cats.laws.discipline.arbitrary._ +import org.scalacheck.{ Arbitrary } + +class ReaderWriterStateTTests extends CatsSuite { + import ReaderWriterStateTTests._ + + test("Basic ReaderWriterState usage") { + forAll { (context: String, initial: Int) => + val (log, state, result) = addAndLog(5).run(context, initial).value + + log should === (Vector(s"${context}: Added 5")) + state should === (initial + 5) + result should === (initial + 5) + } + } + + test("Traversing with ReaderWriterState is stack-safe") { + val ns = (0 to 100000).toList + val rws = ns.traverse(_ => addAndLog(1)) + + rws.runS("context", 0).value should === (100001) + } + + test("map2 combines logs") { + forAll { (rwsa: ReaderWriterState[String, Int, Vector[Int], Int], rwsb: ReaderWriterState[String, Int, Vector[Int], Int], c: String, s: Int) => + val logMap2 = rwsa.map2(rwsb)((_, _) => ()).runL(c, s).value + + val (logA, stateA, _) = rwsa.run(c, s).value + val logB = rwsb.runL(c, stateA).value + val combinedLog = logA |+| logB + + logMap2 should === (combinedLog) + } + } + + test("ReaderWriterState.ask provides the context") { + forAll { (context: String, initial: Int) => + ReaderWriterState.ask[String, Int, String].runA(context, initial).value should === (context) + } + } + + test("local is consistent with contramap") { + forAll { (context: Int, initial: Int, f: Int => String) => + val rwsa = ReaderWriterState.pure[String, Int, Unit, Unit](()).contramap(f).flatMap(_ => ReaderWriterState.ask) + val rwsb = ReaderWriterState.pure[String, Int, Unit, Unit](()).local(f).flatMap(_ => ReaderWriterState.ask) + + rwsa.runA(context, initial) should === (rwsb.runA(context, initial)) + } + } + + test("ReaderWriterState.pure and ReaderWriterStateT.pure are consistent") { + forAll { (value: Int) => + val rws: ReaderWriterState[String, Int, Vector[String], Int] = ReaderWriterState.pure(value) + val rwst: ReaderWriterState[String, Int, Vector[String], Int] = ReaderWriterStateT.pure(value) + + rws should === (rwst) + } + } + + test("ReaderWriterState.pure creates an ReaderWriterState with an empty log") { + forAll { (context: String, initial: Int) => + val rws: ReaderWriterState[String, Int, String, Unit] = ReaderWriterState.pure(()) + rws.run(context, initial).value should === ((Monoid[String].empty, initial, ())) + } + } + + test("ReaderWriterState.get and ReaderWriterStateT.get are consistent") { + forAll { (initial: Int) => + val rws: ReaderWriterState[String, Int, Vector[String], Int] = ReaderWriterState.get + val rwst: ReaderWriterState[String, Int, Vector[String], Int] = ReaderWriterStateT.get + + rws should === (rwst) + } + } + + test("ReaderWriterState.get and instance get are consistent") { + forAll { (initial: Int) => + val singleton = ReaderWriterState.inspect[String, Int, String, String](_.toString) + val instance = ReaderWriterState.pure[String, Int, String, Unit](()).inspect(_.toString) + + singleton should === (instance) + } + } + + test("ReaderWriterState.inspect and ReaderWriterStateT.inspect are consistent") { + forAll { (f: Int => Int) => + val rws: ReaderWriterState[String, Int, Vector[String], Int] = ReaderWriterState.inspect(f) + val rwst: ReaderWriterState[String, Int, Vector[String], Int] = ReaderWriterStateT.inspect(f) + + rws should === (rwst) + } + } + + test("ReaderWriterState.inspect and ReaderWriterStateT.inspectF are consistent") { + forAll { (f: Int => Int) => + val rws: ReaderWriterState[String, Int, String, Int] = ReaderWriterState.inspect(f) + val rwst: ReaderWriterState[String, Int, String, Int] = ReaderWriterStateT.inspectF(f.andThen(Eval.now)) + + rws should === (rwst) + } + } + + test("ReaderWriterState.modify and ReaderWriterStateT.modify are consistent") { + forAll { (f: Int => Int) => + val rws: ReaderWriterState[String, Int, Vector[String], Unit] = ReaderWriterState.modify(f) + val rwst: ReaderWriterState[String, Int, Vector[String], Unit] = ReaderWriterStateT.modify(f) + + rws should === (rwst) + } + } + + test("ReaderWriterState.modify and ReaderWriterStateT.modifyF are consistent") { + forAll { (f: Int => Int) => + val rws: ReaderWriterState[String, Int, Vector[String], Unit] = ReaderWriterState.modify(f) + val rwst: ReaderWriterState[String, Int, Vector[String], Unit] = ReaderWriterStateT.modifyF(f.andThen(Eval.now)) + + rws should === (rwst) + } + } + + test("ReaderWriterState.pure and ReaderWriterStateT.lift are consistent") { + forAll { (value: Int) => + val rws: ReaderWriterState[String, Int, Vector[String], Int] = ReaderWriterState.pure(value) + val rwst: ReaderWriterState[String, Int, Vector[String], Int] = ReaderWriterStateT.lift(Eval.now(value)) + + rws should === (rwst) + } + } + + test("ReaderWriterState.set and ReaderWriterStateT.set are consistent") { + forAll { (next: Int) => + val rws: ReaderWriterState[String, Int, Vector[String], Unit] = ReaderWriterState.set(next) + val rwst: ReaderWriterState[String, Int, Vector[String], Unit] = ReaderWriterStateT.set(next) + + rws should === (rwst) + } + } + + test("ReaderWriterState.set and ReaderWriterStateT.setF are consistent") { + forAll { (next: Int) => + val rws: ReaderWriterState[String, Int, Vector[String], Unit] = ReaderWriterState.set(next) + val rwst: ReaderWriterState[String, Int, Vector[String], Unit] = ReaderWriterStateT.setF(Eval.now(next)) + + rws should === (rwst) + } + } + + test("ReaderWriterState.tell and ReaderWriterStateT.tell are consistent") { + forAll { (log: String) => + val rws: ReaderWriterState[String, Int, String, Unit] = ReaderWriterState.tell(log) + val rwst: ReaderWriterState[String, Int, String, Unit] = ReaderWriterStateT.tell(log) + + rws should === (rwst) + } + } + + test("ReaderWriterState.tell and ReaderWriterStateT.tellF are consistent") { + forAll { (log: String) => + val rws: ReaderWriterState[String, Int, String, Unit] = ReaderWriterState.tell(log) + val rwst: ReaderWriterState[String, Int, String, Unit] = ReaderWriterStateT.tellF(Eval.now(log)) + + rws should === (rwst) + } + } + + test("ReaderWriterState.tell + written is identity") { + forAll { (context: String, initial: Int, log: String) => + ReaderWriterState.tell[String, Int, String](log).written.runA(context, initial).value should === (log) + } + } + + test("Cartesian syntax is usable on ReaderWriterState") { + val rws = addAndLog(5) *> addAndLog(10) + val (log, state, result) = rws.run("context", 0).value + + log should === (Vector("context: Added 5", "context: Added 10")) + state should === (15) + result should === (15) + } + + test("flatMap and flatMapF+tell are consistent") { + forAll { + (rwst: ReaderWriterStateT[Option, String, String, String, Int], f: Int => Option[Int], + initial: String, context: String, log: String) => + + val flatMap = rwst.flatMap { a => + ReaderWriterStateT { (e, s) => + f(a).map((log, s, _)) + } + } + + val flatMapF = rwst.flatMapF(f).tell(log) + + flatMap.run(context, initial) should === (flatMapF.run(context, initial)) + } + } + + test("runEmpty, runEmptyS, runEmptyA and runEmptyL are consistent") { + forAll { (f: ReaderWriterStateT[Option, String, String, String, Int], c: String) => + (f.runEmptyL(c) |@| f.runEmptyS(c) |@| f.runEmptyA(c)).tupled should === (f.runEmpty(c)) + } + } + + test("reset on pure is a noop") { + forAll { (c: String, s: Int, a: Int) => + val pure = ReaderWriterState.pure[String, Int, String, Int](a) + pure.reset should === (pure) + } + } + + test("modify identity is a noop") { + forAll { (f: ReaderWriterStateT[Option, String, String, String, Int], c: String, initial: String) => + f.modify(identity).run(c, initial) should === (f.run(c, initial)) + } + } + + test("modify modifies only the state") { + forAll { (rws: ReaderWriterStateT[Option, String, Long, String, Long], c: String, f: Long => Long, initial: Long) => + rws.modify(f).runS(c, initial) should === (rws.runS(c, initial).map(f)) + rws.modify(f).runA(c, initial) should === (rws.runA(c, initial)) + } + } + + test("reset modifies only the log") { + forAll { (rws: ReaderWriterState[String, Int, String, Int], c: String, s: Int) => + rws.reset.runA(c, s) should === (rws.runA(c, s)) + rws.reset.runS(c, s) should === (rws.runS(c, s)) + } + } + + test("modify is equivalent to get and set") { + forAll { (c: String, f: Long => Long, initial: Long) => + val s1 = ReaderWriterStateT.modify[Option, String, Long, String](f) + val s2 = for { + l <- ReaderWriterStateT.get[Option, String, Long, String] + _ <- ReaderWriterStateT.set[Option, String, Long, String](f(l)) + } yield () + + s1.run(c, initial) should === (s2.run(c, initial)) + } + } + + test("ReaderWriterStateT.set is equivalent to modify ignoring first param") { + forAll { (c: String, initial: Long, s: Long) => + val s1 = ReaderWriterStateT.set[Option, String, Long, String](s) + val s2 = ReaderWriterStateT.modify[Option, String, Long, String](_ => s) + + s1.run(c, initial) should === (s2.run(c, initial)) + } + } + + test("ReaderWriterStateT.setF is equivalent to modifyF ignoring first param") { + forAll { (c: String, initial: Long, s: Option[Long]) => + val s1 = ReaderWriterStateT.setF[Option, String, Long, String](s) + val s2 = ReaderWriterStateT.modifyF[Option, String, Long, String](_ => s) + + s1.run(c, initial) should === (s2.run(c, initial)) + } + } + + test(".get and then .run produces the same state as value") { + forAll { (c: String, initial: Long, rws: ReaderWriterState[String, Long, String, Long]) => + val (_, state, value) = rws.get.run(c, initial).value + + state should === (value) + } + } + + test(".get and .flatMap with .get are equivalent") { + forAll { (c: String, initial: Long, rws: ReaderWriterState[String, Long, String, Long]) => + rws.get.run(c, initial) should === (rws.flatMap(_ => ReaderWriterState.get).run(c, initial)) + } + } + + implicit val iso = CartesianTests.Isomorphisms + .invariant[ReaderWriterStateT[ListWrapper, String, Int, String, ?]](ReaderWriterStateT.catsDataFunctorForRWST(ListWrapper.monad)) + + { + implicit val F: Monad[ListWrapper] = ListWrapper.monad + checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]", + FunctorTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?]].functor[Int, Int, Int]) + checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]", + ContravariantTests[ReaderWriterStateT[ListWrapper, ?, Int, String, Int]].contravariant[String, String, String]) + checkAll("ReaderWriterStateT[ListWrapper, Int, Int, String, Int]", + ProfunctorTests[ReaderWriterStateT[ListWrapper, ?, Int, String, ?]].profunctor[Int, Int, Int, Int, Int, Int]) + checkAll("ReaderWriterStateT[ListWrapper, Int, Int, Int, Int]", + BifunctorTests[ReaderWriterStateT[ListWrapper, String, Int, ?, ?]].bifunctor[Int, Int, Int, Int, Int, Int]) + } + + { + implicit val LWM: Monad[ListWrapper] = ListWrapper.monad + + checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]", + MonadTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?]].monad[Int, Int, Int]) + } + + { + implicit val LWM: Monad[ListWrapper] = ListWrapper.monad + + checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]", + MonadStateTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?], Int].monadState[Int, Int, Int]) + } + + { + implicit val LWM: MonadCombine[ListWrapper] = ListWrapper.monadCombine + + checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]", + MonadCombineTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?]].monadCombine[Int, Int, Int]) + } + + { + implicit val LWM: Monad[ListWrapper] = ListWrapper.monad + + checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]", + MonadReaderTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?], String].monadReader[String, String, String]) + } + + { + implicit val LWM: Monad[ListWrapper] = ListWrapper.monad + + checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]", + MonadWriterTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?], String].monadWriter[String, String, String]) + } + + { + implicit val iso = CartesianTests.Isomorphisms.invariant[ReaderWriterStateT[Option, String, Int, String, ?]] + implicit val eqEitherTFA: Eq[EitherT[ReaderWriterStateT[Option, String, Int, String, ?], Unit, Int]] = + EitherT.catsDataEqForEitherT[ReaderWriterStateT[Option, String, Int, String, ?], Unit, Int] + + checkAll("ReaderWriterStateT[Option, String, Int, String, Int]", + MonadErrorTests[ReaderWriterStateT[Option, String, Int, String, ?], Unit].monadError[Int, Int, Int]) + } + + { + implicit def F = ListWrapper.monad + implicit def S = ListWrapper.semigroupK + + checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]", + SemigroupKTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?]].semigroupK[Int]) + } +} + +object ReaderWriterStateTTests { + def addAndLog(i: Int): ReaderWriterState[String, Int, Vector[String], Int] = { + import cats.instances.vector._ + + ReaderWriterState { (context, state) => + (Vector(s"${context}: Added ${i}"), state + i, state + i) + } + } + + implicit def RWSTArbitrary[F[_]: Applicative, E, S, L, A]( + implicit F: Arbitrary[(E, S) => F[(L, S, A)]]): Arbitrary[ReaderWriterStateT[F, E, S, L, A]] = + Arbitrary(F.arbitrary.map(ReaderWriterStateT(_))) + + implicit def RWSTEq[F[_], E, S, L, A](implicit S: Arbitrary[S], E: Arbitrary[E], FLSA: Eq[F[(L, S, A)]], + F: Monad[F]): Eq[ReaderWriterStateT[F, E, S, L, A]] = + Eq.by[ReaderWriterStateT[F, E, S, L, A], (E, S) => F[(L, S, A)]] { state => + (e, s) => state.run(e, s) + } +}