Skip to content

Commit

Permalink
Convert StateT to IndexedStateT
Browse files Browse the repository at this point in the history
Resolves #1773.
  • Loading branch information
Itamar Ravid committed Jul 31, 2017
1 parent 1424053 commit 289c538
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 128 deletions.
229 changes: 151 additions & 78 deletions core/src/main/scala/cats/data/StateT.scala
Original file line number Diff line number Diff line change
@@ -1,75 +1,90 @@
package cats
package data

import cats.functor.{ Contravariant, Profunctor }
import cats.syntax.either._

/**
* `StateT[F, S, A]` is similar to `Kleisli[F, S, A]` in that it takes an `S`
* argument and produces an `A` value wrapped in `F`. However, it also produces
* an `S` value representing the updated state (which is wrapped in the `F`
* context along with the `A` value.
*
* `IndexedStateT[F, SA, SB, A]` is a stateful computation in a context `F` yielding
* a value of type `A`. Its state transitions from a value of type `SA` to a value
* of type `SB`.
*
* Given `IndexedStateT[F, S, S, A]`, this yields the plain `StateT[F, S, A]`.
*
* Note that `IndexedStateT[F, SA, SB, A]` is not a monad, but an indexed monad.
*/
final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable {
final class IndexedStateT[F[_], SA, SB, A](val runF: F[SA => F[(SB, A)]]) extends Serializable {

def flatMap[B](fas: A => StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, B] =
StateT.applyF(F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
F.flatMap(fsa) { case (s, a) =>
fas(a).run(s)
def flatMap[B, SC](fas: A => IndexedStateT[F, SB, SC, B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SC, B] =
IndexedStateT.applyF(F.map(runF) { safsba =>
safsba.andThen { fsba =>
F.flatMap(fsba) { case (sb, a) =>
fas(a).run(sb)
}
}
})

def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): StateT[F, S, B] =
StateT.applyF(F.map(runF) { sfsa =>
def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SB, B] =
IndexedStateT.applyF(F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) }
}
})

def map[B](f: A => B)(implicit F: Functor[F]): StateT[F, S, B] =
def map[B](f: A => B)(implicit F: Functor[F]): IndexedStateT[F, SA, SB, B] =
transform { case (s, a) => (s, f(a)) }

def contramap[S0](f: S0 => SA)(implicit F: Monad[F]): IndexedStateT[F, S0, SB, A] =
IndexedStateT.apply { s0 =>
F.flatMap(runF) { safsba =>
safsba(f(s0))
}
}

def dimap[S0, S1](f: S0 => SA)(g: SB => S1)(implicit F: Monad[F]): IndexedStateT[F, S0, S1, A] =
contramap(f).modify(g)

/**
* Run with the provided initial state value
*/
def run(initial: S)(implicit F: FlatMap[F]): F[(S, A)] =
def run(initial: SA)(implicit F: FlatMap[F]): F[(SB, A)] =
F.flatMap(runF)(f => f(initial))

/**
* Run with the provided initial state value and return the final state
* (discarding the final value).
*/
def runS(s: S)(implicit F: FlatMap[F]): F[S] = F.map(run(s))(_._1)
def runS(s: SA)(implicit F: FlatMap[F]): F[SB] = F.map(run(s))(_._1)

/**
* Run with the provided initial state value and return the final value
* (discarding the final state).
*/
def runA(s: S)(implicit F: FlatMap[F]): F[A] = F.map(run(s))(_._2)
def runA(s: SA)(implicit F: FlatMap[F]): F[A] = F.map(run(s))(_._2)

/**
* Run with `S`'s empty monoid value as the initial state.
*/
def runEmpty(implicit S: Monoid[S], F: FlatMap[F]): F[(S, A)] = run(S.empty)
def runEmpty(implicit S: Monoid[SA], F: FlatMap[F]): F[(SB, A)] = run(S.empty)

/**
* Run with `S`'s empty monoid value as the initial state and return the final
* state (discarding the final value).
*/
def runEmptyS(implicit S: Monoid[S], F: FlatMap[F]): F[S] = runS(S.empty)
def runEmptyS(implicit S: Monoid[SA], F: FlatMap[F]): F[SB] = runS(S.empty)

/**
* Run with `S`'s empty monoid value as the initial state and return the final
* value (discarding the final state).
*/
def runEmptyA(implicit S: Monoid[S], F: FlatMap[F]): F[A] = runA(S.empty)
def runEmptyA(implicit S: Monoid[SA], F: FlatMap[F]): F[A] = runA(S.empty)

/**
* Like [[map]], but also allows the state (`S`) value to be modified.
*/
def transform[B](f: (S, A) => (S, B))(implicit F: Functor[F]): StateT[F, S, B] =
StateT.applyF(
def transform[B, SC](f: (SB, A) => (SC, B))(implicit F: Functor[F]): IndexedStateT[F, SA, SC, B] =
IndexedStateT.applyF(
F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
F.map(fsa) { case (s, a) => f(s, a) }
Expand All @@ -79,8 +94,8 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable
/**
* Like [[transform]], but allows the context to change from `F` to `G`.
*/
def transformF[G[_], B](f: F[(S, A)] => G[(S, B)])(implicit F: FlatMap[F], G: Applicative[G]): StateT[G, S, B] =
StateT(s => f(run(s)))
def transformF[G[_], B, SC](f: F[(SB, A)] => G[(SC, B)])(implicit F: FlatMap[F], G: Applicative[G]): IndexedStateT[G, SA, SC, B] =
IndexedStateT(s => f(run(s)))

/**
* Transform the state used.
Expand All @@ -100,98 +115,139 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable
* res1: Option[(GlobalEnv, Double)] = Some(((6,hello),5.0))
* }}}
*/
def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Functor[F]): StateT[F, R, A] =
def transformS[R](f: R => SA, g: (R, SB) => R)(implicit F: Functor[F]): IndexedStateT[F, R, R, A] =
StateT.applyF(F.map(runF) { sfsa =>
{ r: R =>
val s = f(r)
val fsa = sfsa(s)
F.map(fsa) { case (s, a) => (g(r, s), a) }
val sa = f(r)
val fsba = sfsa(sa)
F.map(fsba) { case (sb, a) => (g(r, sb), a) }
}
})

/**
* Modify the state (`S`) component.
*/
def modify(f: S => S)(implicit F: Functor[F]): StateT[F, S, A] =
def modify[SC](f: SB => SC)(implicit F: Functor[F]): IndexedStateT[F, SA, SC, A] =
transform((s, a) => (f(s), a))

/**
* Inspect a value from the input state, without modifying the state.
*/
def inspect[B](f: S => B)(implicit F: Functor[F]): StateT[F, S, B] =
def inspect[B](f: SB => B)(implicit F: Functor[F]): IndexedStateT[F, SA, SB, B] =
transform((s, _) => (s, f(s)))

/**
* Get the input state, without modifying the state.
*/
def get(implicit F: Functor[F]): StateT[F, S, S] =
def get(implicit F: Functor[F]): IndexedStateT[F, SA, SB, SB] =
inspect(identity)
}

object StateT extends StateTInstances {
object IndexedStateT extends IndexedStateTInstances {
def apply[F[_], SA, SB, A](f: SA => F[(SB, A)])(implicit F: Applicative[F]): IndexedStateT[F, SA, SB, A] =
new IndexedStateT(F.pure(f))

def applyF[F[_], SA, SB, A](runF: F[SA => F[(SB, A)]]): IndexedStateT[F, SA, SB, A] =
new IndexedStateT(runF)

def pure[F[_], S, A](a: A)(implicit F: Applicative[F]): IndexedStateT[F, S, S, A] =
IndexedStateT(s => F.pure((s, a)))

def lift[F[_], S, A](fa: F[A])(implicit F: Applicative[F]): IndexedStateT[F, S, S, A] =
IndexedStateT(s => F.map(fa)(a => (s, a)))

def inspect[F[_], S, A](f: S => A)(implicit F: Applicative[F]): IndexedStateT[F, S, S, A] =
IndexedStateT(s => F.pure((s, f(s))))

def inspectF[F[_], S, A](f: S => F[A])(implicit F: Applicative[F]): IndexedStateT[F, S, S, A] =
IndexedStateT(s => F.map(f(s))(a => (s, a)))

def modify[F[_], SA, SB](f: SA => SB)(implicit F: Applicative[F]): IndexedStateT[F, SA, SB, Unit] =
IndexedStateT(sa => F.pure((f(sa), ())))

def modifyF[F[_], SA, SB](f: SA => F[SB])(implicit F: Applicative[F]): IndexedStateT[F, SA, SB, Unit] =
IndexedStateT(s => F.map(f(s))(s => (s, ())))

def get[F[_], S](implicit F: Applicative[F]): IndexedStateT[F, S, S, S] =
IndexedStateT(s => F.pure((s, s)))

def set[F[_], SA, SB](sb: SB)(implicit F: Applicative[F]): IndexedStateT[F, SA, SB, Unit] =
IndexedStateT(_ => F.pure((sb, ())))

def setF[F[_], SA, SB](fsb: F[SB])(implicit F: Applicative[F]): IndexedStateT[F, SA, SB, Unit] =
IndexedStateT(_ => F.map(fsb)(s => (s, ())))
}

private[data] abstract class StateTFunctions {
def apply[F[_], S, A](f: S => F[(S, A)])(implicit F: Applicative[F]): StateT[F, S, A] =
new StateT(F.pure(f))
IndexedStateT(f)

def applyF[F[_], S, A](runF: F[S => F[(S, A)]]): StateT[F, S, A] =
new StateT(runF)
IndexedStateT.applyF(runF)

def pure[F[_], S, A](a: A)(implicit F: Applicative[F]): StateT[F, S, A] =
StateT(s => F.pure((s, a)))
apply(s => F.pure((s, a)))

def lift[F[_], S, A](fa: F[A])(implicit F: Applicative[F]): StateT[F, S, A] =
StateT(s => F.map(fa)(a => (s, a)))
apply(s => F.map(fa)(a => (s, a)))

def inspect[F[_], S, A](f: S => A)(implicit F: Applicative[F]): StateT[F, S, A] =
StateT(s => F.pure((s, f(s))))
apply(s => F.pure((s, f(s))))

def inspectF[F[_], S, A](f: S => F[A])(implicit F: Applicative[F]): StateT[F, S, A] =
StateT(s => F.map(f(s))(a => (s, a)))
apply(s => F.map(f(s))(a => (s, a)))

def modify[F[_], S](f: S => S)(implicit F: Applicative[F]): StateT[F, S, Unit] =
StateT(s => F.pure((f(s), ())))
apply(sa => F.pure((f(sa), ())))

def modifyF[F[_], S](f: S => F[S])(implicit F: Applicative[F]): StateT[F, S, Unit] =
StateT(s => F.map(f(s))(s => (s, ())))
apply(s => F.map(f(s))(s => (s, ())))

def get[F[_], S](implicit F: Applicative[F]): StateT[F, S, S] =
StateT(s => F.pure((s, s)))
apply(s => F.pure((s, s)))

def set[F[_], S](s: S)(implicit F: Applicative[F]): StateT[F, S, Unit] =
StateT(_ => F.pure((s, ())))
apply(_ => F.pure((s, ())))

def setF[F[_], S](fs: F[S])(implicit F: Applicative[F]): StateT[F, S, Unit] =
StateT(_ => F.map(fs)(s => (s, ())))
apply(_ => F.map(fs)(s => (s, ())))
}

private[data] sealed trait StateTInstances extends StateTInstances1 {
implicit def catsDataAlternativeForStateT[F[_], S](implicit FM: Monad[F], FA: Alternative[F]): Alternative[StateT[F, S, ?]] =
new StateTAlternative[F, S] { implicit def F = FM; implicit def G = FA }
private[data] sealed trait IndexedStateTInstances extends IndexedStateTInstances1 {
implicit def catsDataAlternativeForIndexedStateT[F[_], S](implicit FM: Monad[F], FA: Alternative[F]): Alternative[StateT[F, S, ?]] =
new IndexedStateTAlternative[F, S] { implicit def F = FM; implicit def G = FA }
}

private[data] sealed trait StateTInstances1 extends StateTInstances2 {
implicit def catsDataMonadErrorForStateT[F[_], S, E](implicit F0: MonadError[F, E]): MonadError[StateT[F, S, ?], E] =
new StateTMonadError[F, S, E] { implicit def F = F0 }
private[data] sealed trait IndexedStateTInstances1 extends IndexedStateTInstances2 {
implicit def catsDataMonadErrorForIndexedStateT[F[_], S, E](implicit F0: MonadError[F, E]): MonadError[IndexedStateT[F, S, S, ?], E] =
new IndexedStateTMonadError[F, S, E] { implicit def F = F0 }

implicit def catsDataSemigroupKForStateT[F[_], S](implicit F0: Monad[F], G0: SemigroupK[F]): SemigroupK[StateT[F, S, ?]] =
new StateTSemigroupK[F, S] { implicit def F = F0; implicit def G = G0 }
implicit def catsDataSemigroupKForIndexedStateT[F[_], SA, SB](implicit F0: Monad[F], G0: SemigroupK[F]): SemigroupK[IndexedStateT[F, SA, SB, ?]] =
new IndexedStateTSemigroupK[F, SA, SB] { implicit def F = F0; implicit def G = G0 }
}

private[data] sealed trait StateTInstances2 extends StateTInstances3 {
implicit def catsDataMonadForStateT[F[_], S](implicit F0: Monad[F]): Monad[StateT[F, S, ?]] =
new StateTMonad[F, S] { implicit def F = F0 }
private[data] sealed trait IndexedStateTInstances2 extends IndexedStateTInstances3 {
implicit def catsDataMonadForIndexedStateT[F[_], S](implicit F0: Monad[F]): Monad[IndexedStateT[F, S, S, ?]] =
new IndexedStateTMonad[F, S] { implicit def F = F0 }
}

private[data] sealed trait StateTInstances3 {
implicit def catsDataFunctorForStateT[F[_], S](implicit F0: Functor[F]): Functor[StateT[F, S, ?]] =
new StateTFunctor[F, S] { implicit def F = F0 }
private[data] sealed trait IndexedStateTInstances3 {
implicit def catsDataFunctorForIndexedStateT[F[_], SA, SB](implicit F0: Functor[F]): Functor[IndexedStateT[F, SA, SB, ?]] =
new IndexedStateTFunctor[F, SA, SB] { implicit def F = F0 }

implicit def catsDataContravariantForIndexedStateT[F[_], SB, V](implicit F0: Monad[F]): Contravariant[IndexedStateT[F, ?, SB, V]] =
new IndexedStateTContravariant[F, SB, V] { implicit def F = F0 }

implicit def catsDataProfunctorForIndexedStateT[F[_], V](implicit F0: Monad[F]): Profunctor[IndexedStateT[F, ?, ?, V]] =
new IndexedStateTProfunctor[F, V] { implicit def F = F0 }
}

// To workaround SI-7139 `object State` needs to be defined inside the package object
// together with the type alias.
private[data] abstract class StateFunctions {

def apply[S, A](f: S => (S, A)): State[S, A] =
StateT.applyF(Now((s: S) => Now(f(s))))
IndexedStateT.applyF(Now((s: S) => Now(f(s))))

/**
* Return `a` and maintain the input state.
Expand Down Expand Up @@ -219,62 +275,79 @@ private[data] abstract class StateFunctions {
def set[S](s: S): State[S, Unit] = State(_ => (s, ()))
}

private[data] sealed trait StateTFunctor[F[_], S] extends Functor[StateT[F, S, ?]] {
private[data] sealed trait IndexedStateTFunctor[F[_], SA, SB] extends Functor[IndexedStateT[F, SA, SB, ?]] {
implicit def F: Functor[F]

override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f)
override def map[A, B](fa: IndexedStateT[F, SA, SB, A])(f: A => B): IndexedStateT[F, SA, SB, B] =
fa.map(f)
}

private[data] sealed trait IndexedStateTContravariant[F[_], SB, V] extends Contravariant[IndexedStateT[F, ?, SB, V]] {
implicit def F: Monad[F]

override def contramap[A, B](fa: IndexedStateT[F, A, SB, V])(f: B => A): IndexedStateT[F, B, SB, V] =
fa.contramap(f)
}

private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] with StateTFunctor[F, S] {
private[data] sealed trait IndexedStateTProfunctor[F[_], V] extends Profunctor[IndexedStateT[F, ?, ?, V]] {
implicit def F: Monad[F]

def dimap[A, B, C, D](fab: IndexedStateT[F, A, B, V])(f: C => A)(g: B => D): IndexedStateT[F, C, D, V] =
fab.dimap(f)(g)
}

private[data] sealed trait IndexedStateTMonad[F[_], S] extends Monad[IndexedStateT[F, S, S, ?]]
with IndexedStateTFunctor[F, S, S] {
implicit def F: Monad[F]

def pure[A](a: A): StateT[F, S, A] =
StateT.pure(a)
IndexedStateT.pure(a)

def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] =
fa.flatMap(f)

def tailRecM[A, B](a: A)(f: A => StateT[F, S, Either[A, B]]): StateT[F, S, B] =
StateT[F, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) {
IndexedStateT[F, S, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) {
case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) }
})
}

private[data] sealed trait StateTSemigroupK[F[_], S] extends SemigroupK[StateT[F, S, ?]] {
private[data] sealed trait IndexedStateTSemigroupK[F[_], SA, SB] extends SemigroupK[IndexedStateT[F, SA, SB, ?]] {
implicit def F: Monad[F]
implicit def G: SemigroupK[F]

def combineK[A](x: StateT[F, S, A], y: StateT[F, S, A]): StateT[F, S, A] =
StateT(s => G.combineK(x.run(s), y.run(s)))
def combineK[A](x: IndexedStateT[F, SA, SB, A], y: IndexedStateT[F, SA, SB, A]): IndexedStateT[F, SA, SB, A] =
IndexedStateT(s => G.combineK(x.run(s), y.run(s)))
}

private[data] sealed trait StateTAlternative[F[_], S] extends Alternative[StateT[F, S, ?]] with StateTFunctor[F, S] {
private[data] sealed trait IndexedStateTAlternative[F[_], S] extends Alternative[IndexedStateT[F, S, S, ?]] with IndexedStateTFunctor[F, S, S] {
implicit def F: Monad[F]
def G: Alternative[F]

def combineK[A](x: StateT[F, S, A], y: StateT[F, S, A]): StateT[F, S, A] =
StateT[F, S, A](s => G.combineK(x.run(s), y.run(s)))(G)
def combineK[A](x: IndexedStateT[F, S, S, A], y: IndexedStateT[F, S, S, A]): IndexedStateT[F, S, S, A] =
IndexedStateT[F, S, S, A](s => G.combineK(x.run(s), y.run(s)))(G)

def pure[A](a: A): StateT[F, S, A] =
StateT.pure[F, S, A](a)(G)
def pure[A](a: A): IndexedStateT[F, S, S, A] =
IndexedStateT.pure[F, S, A](a)(G)

def empty[A]: StateT[F, S, A] =
StateT.lift[F, S, A](G.empty[A])(G)
def empty[A]: IndexedStateT[F, S, S, A] =
IndexedStateT.lift[F, S, A](G.empty[A])(G)

override def ap[A, B](ff: StateT[F, S, A => B])(fa: StateT[F, S, A]): StateT[F, S, B] =
StateT[F, S, B]((s: S) =>
override def ap[A, B](ff: IndexedStateT[F, S, S, A => B])(fa: IndexedStateT[F, S, S, A]): IndexedStateT[F, S, S, B] =
IndexedStateT[F, S, S, B]((s: S) =>
F.flatMap(ff.run(s)) { sab =>
val (sn, f) = sab
F.map(fa.run(sn)) { case (snn, a) => (snn, f(a)) }
}
)
}

private[data] sealed trait StateTMonadError[F[_], S, E] extends StateTMonad[F, S] with MonadError[StateT[F, S, ?], E] {
private[data] sealed trait IndexedStateTMonadError[F[_], S, E] extends IndexedStateTMonad[F, S]
with MonadError[IndexedStateT[F, S, S, ?], E] {
implicit def F: MonadError[F, E]

def raiseError[A](e: E): StateT[F, S, A] = StateT.lift(F.raiseError(e))
def raiseError[A](e: E): IndexedStateT[F, S, S, A] = IndexedStateT.lift(F.raiseError(e))

def handleErrorWith[A](fa: StateT[F, S, A])(f: E => StateT[F, S, A]): StateT[F, S, A] =
StateT(s => F.handleErrorWith(fa.run(s))(e => f(e).run(s)))
def handleErrorWith[A](fa: IndexedStateT[F, S, S, A])(f: E => IndexedStateT[F, S, S, A]): IndexedStateT[F, S, S, A] =
IndexedStateT(s => F.handleErrorWith(fa.run(s))(e => f(e).run(s)))
}
Loading

0 comments on commit 289c538

Please sign in to comment.