Skip to content

Commit

Permalink
Implement a ReaderWriterStateT data type
Browse files Browse the repository at this point in the history
Resolves #1597.
  • Loading branch information
Itamar Ravid committed Apr 8, 2017
1 parent 175bdfa commit 40ae965
Show file tree
Hide file tree
Showing 2 changed files with 363 additions and 0 deletions.
360 changes: 360 additions & 0 deletions core/src/main/scala/cats/data/RWST.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,360 @@
package cats
package data

import cats.functor.{ Contravariant, Bifunctor, Profunctor }
import cats.syntax.either._

final class RWST[F[_], E, S, L, A](val runF: F[(E, S) => F[(L, S, A)]]) extends Serializable {
def contramap[E0](f: E0 => E)(implicit F: Functor[F]): RWST[F, E0, S, L, A] =
RWST.applyF {
F.map(runF) { rwsa =>
(e0: E0, s: S) => rwsa(f(e0), s)
}
}

def map[B](f: A => B)(implicit F: Functor[F]): RWST[F, E, S, L, B] =
transform { (s, a) => (s, f(a)) }

def mapWritten[L0](f: L => L0)(implicit F: Functor[F]): RWST[F, E, S, L0, A] =
RWST.applyF {
F.map(runF) { rwsa =>
(e: E, s: S) => F.map(rwsa(e, s)) { case (l, s, a) =>
(f(l), s, a)
}
}
}

def map2[B, Z](rwsb: RWST[F, E, S, L, B])(fn: (A, B) => Z)(implicit F: FlatMap[F], L: Semigroup[L]): RWST[F, E, S, L, Z] =
RWST.applyF {
F.map2(runF, rwsb.runF) { (rwsfa, rwsfb) =>
(e: E, s0: S) =>
F.flatMap(rwsfa(e, s0)) { case (la, sa, a) =>
F.map(rwsfb(e, sa)) { case (lb, sb, b) =>
(L.combine(la, lb), sb, fn(a, b))
}
}
}
}

def map2Eval[B, Z](fb: Eval[RWST[F, E, S, L, B]])(fn: (A, B) => Z)(implicit F: FlatMap[F], L: Semigroup[L]): Eval[RWST[F, E, S, L, Z]] =
F.map2Eval(runF, fb.map(_.runF)) { (rwsfa, rwsfb) =>
(e: E, s0: S) =>
F.flatMap(rwsfa(e, s0)) { case (la, sa, a) =>
F.map(rwsfb(e, sa)) { case (lb, sb, b) =>
(L.combine(la, lb), sb, fn(a, b))
}
}
}.map(RWST.applyF(_))

def flatMap[B](f: A => RWST[F, E, S, L, B])(implicit F: FlatMap[F], L: Semigroup[L]): RWST[F, E, S, L, B] =
RWST.applyF {
F.map(runF) { rwsfa =>
(e: E, s0: S) =>
F.flatMap(rwsfa(e, s0)) { case (la, sa, a) =>
F.flatMap(f(a).runF) { rwsfb =>
F.map(rwsfb(e, sa)) { case (lb, sb, b) =>
(L.combine(la, lb), sb, b)
}
}
}
}
}

def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): RWST[F, E, S, L, B] =
RWST.applyF {
F.map(runF) { rwsfa =>
(e: E, s: S) =>
F.flatMap(rwsfa(e, s)) { case (l, s, a) =>
F.map(faf(a))((l, s, _))
}
}
}

def transform[B](f: (S, A) => (S, B))(implicit F: Functor[F]): RWST[F, E, S, L, B] =
RWST.applyF {
F.map(runF) { rwsfa =>
(e: E, s: S) => F.map(rwsfa(e, s)) { case (l, s, a) =>
val (sb, b) = f(s, a)
(l, sb, b)
}
}
}

def modify(f: S => S)(implicit F: Functor[F]): RWST[F, E, S, L, A] =
transform { (s, a) => (f(s), a) }

def inspect[B](f: S => B)(implicit F: Functor[F]): RWST[F, E, S, L, B] =
transform { (s, a) => (s, f(s)) }

def get(implicit F: Functor[F]): RWST[F, E, S, L, S] =
inspect(identity)

def tell(l: L)(implicit F: Functor[F], L: Semigroup[L]): RWST[F, E, S, L, A] =
mapWritten(L.combine(_, l))

def reset(implicit F: Functor[F], L: Monoid[L]): RWST[F, E, S, L, A] =
mapWritten(_ => L.empty)

def run(env: E, initial: S)(implicit F: Monad[F]): F[(L, S, A)] =
F.flatMap(runF)(_.apply(env, initial))

def runEmpty(env: E)(implicit F: Monad[F], S: Monoid[S]): F[(L, S, A)] =
run(env, S.empty)

def runA(env: E, initial: S)(implicit F: Monad[F]): F[A] =
F.map(run(env, initial))(_._3)

def runS(env: E, initial: S)(implicit F: Monad[F]): F[S] =
F.map(run(env, initial))(_._2)

def runL(env: E, initial: S)(implicit F: Monad[F]): F[L] =
F.map(run(env, initial))(_._1)

def runEmptyA(env: E)(implicit F: Monad[F], S: Monoid[S]): F[A] =
runA(env, S.empty)

def runEmptyS(env: E)(implicit F: Monad[F], S: Monoid[S]): F[S] =
runS(env, S.empty)

def runEmptyL(env: E)(implicit F: Monad[F], S: Monoid[S]): F[L] =
runL(env, S.empty)
}

object RWST extends RWSTInstances {
def apply[F[_], E, S, L, A](runF: (E, S) => F[(L, S, A)])(implicit F: Applicative[F]): RWST[F, E, S, L, A] =
new RWST(F.pure(runF))

def applyF[F[_], E, S, L, A](runF: F[(E, S) => F[(L, S, A)]]): RWST[F, E, S, L, A] =
new RWST(runF)

def pure[F[_], E, S, L, A](a: A)(implicit F: Applicative[F], L: Monoid[L]): RWST[F, E, S, L, A] =
RWST((_, s) => F.pure((L.empty, s, a)))

def lift[F[_], E, S, L, A](fa: F[A])(implicit F: Applicative[F], L: Monoid[L]): RWST[F, E, S, L, A] =
RWST((_, s) => F.map(fa)((L.empty, s, _)))

def inspect[F[_], E, S, L, A](f: S => A)(implicit F: Applicative[F], L: Monoid[L]): RWST[F, E, S, L, A] =
RWST((_, s) => F.pure((L.empty, s, f(s))))

def inspectF[F[_], E, S, L, A](f: S => F[A])(implicit F: Applicative[F], L: Monoid[L]): RWST[F, E, S, L, A] =
RWST((_, s) => F.map(f(s))((L.empty, s, _)))

def modify[F[_], E, S, L, A](f: S => S)(implicit F: Applicative[F], L: Monoid[L]): RWST[F, E, S, L, Unit] =
RWST((_, s) => F.pure((L.empty, f(s), ())))

def modifyF[F[_], E, S, L, A](f: S => F[S])(implicit F: Applicative[F], L: Monoid[L]): RWST[F, E, S, L, Unit] =
RWST((_, s) => F.map(f(s))((L.empty, _, ())))

def get[F[_], E, S, L, A](implicit F: Applicative[F], L: Monoid[L]): RWST[F, E, S, L, S] =
RWST((_, s) => F.pure((L.empty, s, s)))

def set[F[_], E, S, L, A](s: S)(implicit F: Applicative[F], L: Monoid[L]): RWST[F, E, S, L, Unit] =
RWST((_, _) => F.pure((L.empty, s, ())))

def setF[F[_], E, S, L, A](fs: F[S])(implicit F: Applicative[F], L: Monoid[L]): RWST[F, E, S, L, Unit] =
RWST((_, _) => F.map(fs)((L.empty, _, ())))

def tellF[F[_], E, S, L](fl: F[L])(implicit F: Applicative[F]): RWST[F, E, S, L, Unit] =
RWST((_, s) => F.map(fl)((_, s, ())))

def tell[F[_], E, S, L](l: L)(implicit F: Applicative[F]): RWST[F, E, S, L, Unit] =
RWST((_, s) => F.pure((l, s, ())))
}

/**
* Convenience functions for RWS.
*/
private[data] abstract class RWSFunctions {
def apply[E, S, L: Monoid, A](f: (E, S) => (L, S, A)): RWS[E, S, L, A] =
RWST.applyF(Now((e, s) => Now(f(e, s))))

def pure[E, S, L: Monoid, A](a: A): RWS[E, S, L, A] =
RWST.pure(a)

def modify[E, S, L: Monoid](f: S => S): RWS[E, S, L, Unit] =
RWST.modify(f)

def inspect[E, S, L: Monoid, T](f: S => T): RWS[E, S, L, T] =
RWST.inspect(f)

def get[E, S, L: Monoid]: RWS[E, S, L, S] =
RWST.get

def set[E, S, L: Monoid](s: S): RWS[E, S, L, Unit] =
RWST.set(s)

def tell[E, S, L](l: L): RWS[E, S, L, Unit] =
RWST.tell(l)
}

private[data] sealed trait RWSTInstances extends RWSTInstances1 {
implicit def catsDataMonadStateForRWST[F[_], E, S, L](implicit F0: Monad[F], L0: Monoid[L]): MonadState[RWST[F, E, S, L, ?], S] =
new RWSTMonadState[F, E, S, L] {
implicit def F: Monad[F] = F0
implicit def L: Monoid[L] = L0
}

implicit def catsDataLiftForRWST[E, S, L](implicit L0: Monoid[L]): TransLift.Aux[RWST[?[_], E, S, L, ?], Applicative] =
new RWSTTransLift[E, S, L] {
implicit def L: Monoid[L] = L0
}
}

private[data] sealed trait RWSTInstances1 extends RWSTInstances2 {
implicit def catsDataMonadCombineForRWST[F[_], E, S, L](implicit F0: MonadCombine[F], L0: Monoid[L]): MonadCombine[RWST[F, E, S, L, ?]] =
new RWSTMonadCombine[F, E, S, L] {
implicit def F: MonadCombine[F] = F0
implicit def L: Monoid[L] = L0
}
}

private[data] sealed trait RWSTInstances2 extends RWSTInstances3 {
implicit def catsDataMonadErrorForRWST[F[_], E, S, L, R](implicit F0: MonadError[F, R], L0: Monoid[L]): MonadError[RWST[F, E, S, L, ?], R] =
new RWSTMonadError[F, E, S, L, R] {
implicit def F: MonadError[F, R] = F0
implicit def L: Monoid[L] = L0
}

implicit def catsDataSemigroupKForRWST[F[_], E, S, L](implicit F0: Monad[F], G0: SemigroupK[F]): SemigroupK[RWST[F, E, S, L, ?]] =
new RWSTSemigroupK[F, E, S, L] {
implicit def F: Monad[F] = F0
implicit def G: SemigroupK[F] = G0
}
}

private[data] sealed trait RWSTInstances3 extends RWSTInstances4 {
implicit def catsDataMonadForRWST[F[_], E, S, L](implicit F0: Monad[F], L0: Monoid[L]): Monad[RWST[F, E, S, L, ?]] =
new RWSTMonad[F, E, S, L] {
implicit def F: Monad[F] = F0
implicit def L: Monoid[L] = L0
}
}

private[data] sealed trait RWSTInstances4 extends RWSTInstances5 {
implicit def catsDataFunctorForRWST[F[_], E, S, L](implicit F0: Functor[F]): Functor[RWST[F, E, S, L, ?]] =
new RWSTFunctor[F, E, S, L] {
implicit def F: Functor[F] = F0
}
}

private[data] sealed trait RWSTInstances5 extends RWSTInstances6 {
implicit def catsDataContravariantForRWST[F[_], S, L, A](implicit F0: Functor[F]): Contravariant[RWST[F, ?, S, L, A]] =
new RWSTContravariant[F, S, L, A] {
implicit def F: Functor[F] = F0
}
}

private[data] sealed trait RWSTInstances6 extends RWSTInstances7 {
implicit def catsDataBifunctorForRWST[F[_], E, S](implicit F0: Functor[F]): Bifunctor[RWST[F, E, S, ?, ?]] =
new RWSTBifunctor[F, E, S] {
implicit def F: Functor[F] = F0
}
}

private[data] sealed trait RWSTInstances7 {
implicit def catsDataProfunctorForRWST[F[_], S, L](implicit F0: Functor[F]): Profunctor[RWST[F, ?, S, L, ?]] =
new RWSTProfunctor[F, S, L] {
implicit def F: Functor[F] = F0
}
}

private[data] sealed trait RWSTFunctor[F[_], E, S, L] extends Functor[RWST[F, E, S, L, ?]] {
implicit def F: Functor[F]

def map[A, B](fa: RWST[F, E, S, L, A])(f: A => B): RWST[F, E, S, L, B] =
fa.map(f)
}

private[data] sealed trait RWSTContravariant[F[_], S, L, T] extends Contravariant[RWST[F, ?, S, L, T]] {
implicit def F: Functor[F]

override def contramap[A, B](fa: RWST[F, A, S, L, T])(f: B => A): RWST[F, B, S, L, T] =
fa.contramap(f)
}

private[data] sealed trait RWSTBifunctor[F[_], E, S] extends Bifunctor[RWST[F, E, S, ?, ?]] {
implicit def F: Functor[F]

override def bimap[A, B, C, D](fab: RWST[F, E, S, A, B])(f: A => C, g: B => D): RWST[F, E, S, C, D] = fab.mapWritten(f).map(g)
}

private[data] sealed trait RWSTProfunctor[F[_], S, L] extends Profunctor[RWST[F, ?, S, L, ?]] {
implicit def F: Functor[F]

override def dimap[A, B, C, D](fab: RWST[F, A, S, L, B])(f: C => A)(g: B => D): RWST[F, C, S, L, D] =
fab.contramap(f).map(g)
}

private[data] sealed trait RWSTMonad[F[_], E, S, L] extends Monad[RWST[F, E, S, L, ?]] {
implicit def F: Monad[F]
implicit def L: Monoid[L]

def pure[A](a: A): RWST[F, E, S, L, A] =
RWST.pure(a)

def flatMap[A, B](fa: RWST[F, E, S, L, A])(f: A => RWST[F, E, S, L, B]): RWST[F, E, S, L, B] =
fa.flatMap(f)

def tailRecM[A, B](a: A)(f: A => RWST[F, E, S, L, Either[A, B]]): RWST[F, E, S, L, B] =
RWST { (e, s) =>
F.tailRecM((s, a)) { case (s, a) =>
F.map(f(a).run(e, s)) { case (la, sa, a) =>
a.bimap((sa, _), (la, sa, _))
}
}
}

override def map[A, B](fa: RWST[F, E, S, L, A])(f: A => B): RWST[F, E, S, L, B] =
fa.map(f)

override def map2[A, B, Z](fa: RWST[F, E, S, L, A], fb: RWST[F, E, S, L, B])(f: (A, B) => Z): RWST[F, E, S, L, Z] =
fa.map2(fb)(f)

override def map2Eval[A, B, Z](fa: RWST[F, E, S, L, A], fb: Eval[RWST[F, E, S, L, B]])(f: (A, B) => Z): Eval[RWST[F, E, S, L, Z]] =
fa.map2Eval(fb)(f)
}

private[data] sealed trait RWSTMonadState[F[_], E, S, L] extends MonadState[RWST[F, E, S, L, ?], S] with RWSTMonad[F, E, S, L] {
lazy val get: RWST[F, E, S, L, S] = RWST.get

def set(s: S): RWST[F, E, S, L, Unit] = RWST.set(s)
}

private[data] sealed trait RWSTTransLift[E, S, L] extends TransLift[RWST[?[_], E, S, L, ?]] {
implicit def L: Monoid[L]
type TC[F[_]] = Applicative[F]

def liftT[F[_], A](fa: F[A])(implicit F: Applicative[F]): RWST[F, E, S, L, A] =
RWST { (e, s) =>
F.map(fa)((L.empty, s, _))
}
}

private[data] sealed trait RWSTSemigroupK[F[_], E, S, L] extends SemigroupK[RWST[F, E, S, L, ?]] {
implicit def F: Monad[F]
implicit def G: SemigroupK[F]

def combineK[A](x: RWST[F, E, S, L, A], y: RWST[F, E, S, L, A]): RWST[F, E, S, L, A] =
RWST { (e, s) =>
G.combineK(x.run(e, s), y.run(e, s))
}
}

private[data] sealed trait RWSTMonadCombine[F[_], E, S, L] extends MonadCombine[RWST[F, E, S, L, ?]] with RWSTMonad[F, E, S, L]
with RWSTSemigroupK[F, E, S, L] with RWSTTransLift[E, S, L] {
implicit def F: MonadCombine[F]
override def G: MonadCombine[F] = F

def empty[A]: RWST[F, E, S, L, A] = liftT[F, A](F.empty[A])
}

private[data] sealed trait RWSTMonadError[F[_], E, S, L, R] extends RWSTMonad[F, E, S, L] with MonadError[RWST[F, E, S, L, ?], R] {
implicit def F: MonadError[F, R]

def raiseError[A](r: R): RWST[F, E, S, L, A] = RWST.lift(F.raiseError(r))

def handleErrorWith[A](fa: RWST[F, E, S, L, A])(f: R => RWST[F, E, S, L, A]): RWST[F, E, S, L, A] =
RWST { (e, s) =>
F.handleErrorWith(fa.run(e, s))(r => f(r).run(e, s))
}
}
3 changes: 3 additions & 0 deletions core/src/main/scala/cats/data/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,7 @@ package object data {

type State[S, A] = StateT[Eval, S, A]
object State extends StateFunctions

type RWS[E, S, L, A] = RWST[Eval, E, S, L, A]
object RWS extends RWSFunctions
}

0 comments on commit 40ae965

Please sign in to comment.