diff --git a/core/src/main/scala/cats/data/IndexedStateT.scala b/core/src/main/scala/cats/data/IndexedStateT.scala new file mode 100644 index 0000000000..f2d855b07e --- /dev/null +++ b/core/src/main/scala/cats/data/IndexedStateT.scala @@ -0,0 +1,375 @@ +package cats +package data + +import cats.functor.{ Contravariant, Bifunctor, Profunctor, Strong } +import cats.syntax.either._ + +/** + * + * `IndexedStateT[F, SA, SB, A]` is a stateful computation in a context `F` yielding + * a value of type `A`. The state transitions from a value of type `SA` to a value + * of type `SB`. + * + * Note that for the `SA != SB` case, this is an indexed monad. Indexed monads + * are monadic type constructors annotated by an additional type for effect + * tracking purposes. In this case, the annotation tracks the initial state and + * the resulting state. + * + * Given `IndexedStateT[F, S, S, A]`, this yields the `StateT[F, S, A]` monad. + */ +final class IndexedStateT[F[_], SA, SB, A](val runF: F[SA => F[(SB, A)]]) extends Serializable { + + def flatMap[B, SC](fas: A => IndexedStateT[F, SB, SC, B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SC, B] = + IndexedStateT.applyF(F.map(runF) { safsba => + safsba.andThen { fsba => + F.flatMap(fsba) { case (sb, a) => + fas(a).run(sb) + } + } + }) + + def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SB, B] = + IndexedStateT.applyF(F.map(runF) { sfsa => + sfsa.andThen { fsa => + F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) } + } + }) + + def map[B](f: A => B)(implicit F: Functor[F]): IndexedStateT[F, SA, SB, B] = + transform { case (s, a) => (s, f(a)) } + + def contramap[S0](f: S0 => SA)(implicit F: Functor[F]): IndexedStateT[F, S0, SB, A] = + IndexedStateT.applyF { + F.map(runF) { safsba => + (s0: S0) => safsba(f(s0)) + } + } + + def bimap[SC, B](f: SB => SC, g: A => B)(implicit F: Functor[F]): IndexedStateT[F, SA, SC, B] = + transform { (s, a) => (f(s), g(a)) } + + def dimap[S0, S1](f: S0 => SA)(g: SB => S1)(implicit F: Functor[F]): IndexedStateT[F, S0, S1, A] = + contramap(f).modify(g) + + /** + * Run with the provided initial state value + */ + def run(initial: SA)(implicit F: FlatMap[F]): F[(SB, A)] = + F.flatMap(runF)(f => f(initial)) + + /** + * Run with the provided initial state value and return the final state + * (discarding the final value). + */ + def runS(s: SA)(implicit F: FlatMap[F]): F[SB] = F.map(run(s))(_._1) + + /** + * Run with the provided initial state value and return the final value + * (discarding the final state). + */ + def runA(s: SA)(implicit F: FlatMap[F]): F[A] = F.map(run(s))(_._2) + + /** + * Run with `S`'s empty monoid value as the initial state. + */ + def runEmpty(implicit S: Monoid[SA], F: FlatMap[F]): F[(SB, A)] = run(S.empty) + + /** + * Run with `S`'s empty monoid value as the initial state and return the final + * state (discarding the final value). + */ + def runEmptyS(implicit S: Monoid[SA], F: FlatMap[F]): F[SB] = runS(S.empty) + + /** + * Run with `S`'s empty monoid value as the initial state and return the final + * value (discarding the final state). + */ + def runEmptyA(implicit S: Monoid[SA], F: FlatMap[F]): F[A] = runA(S.empty) + + /** + * Like [[map]], but also allows the state (`S`) value to be modified. + */ + def transform[B, SC](f: (SB, A) => (SC, B))(implicit F: Functor[F]): IndexedStateT[F, SA, SC, B] = + IndexedStateT.applyF( + F.map(runF) { sfsa => + sfsa.andThen { fsa => + F.map(fsa) { case (s, a) => f(s, a) } + } + }) + + /** + * Like [[transform]], but allows the context to change from `F` to `G`. + * + * {{{ + * scala> import cats.implicits._ + * scala> type ErrorOr[A] = Either[String, A] + * scala> val xError: IndexedStateT[ErrorOr, Int, Int, Int] = IndexedStateT.get + * scala> val xOpt: IndexedStateT[Option, Int, Int, Int] = xError.transformF(_.toOption) + * scala> val input = 5 + * scala> xError.run(input) + * res0: ErrorOr[(Int, Int)] = Right((5,5)) + * scala> xOpt.run(5) + * res1: Option[(Int, Int)] = Some((5,5)) + * }}} + */ + def transformF[G[_], B, SC](f: F[(SB, A)] => G[(SC, B)])(implicit F: FlatMap[F], G: Applicative[G]): IndexedStateT[G, SA, SC, B] = + IndexedStateT(s => f(run(s))) + + /** + * Transform the state used. + * + * This is useful when you are working with many focused `StateT`s and want to pass in a + * global state containing the various states needed for each individual `StateT`. + * + * {{{ + * scala> import cats.implicits._ // needed for StateT.apply + * scala> type GlobalEnv = (Int, String) + * scala> val x: StateT[Option, Int, Double] = StateT((x: Int) => Option((x + 1, x.toDouble))) + * scala> val xt: StateT[Option, GlobalEnv, Double] = x.transformS[GlobalEnv](_._1, (t, i) => (i, t._2)) + * scala> val input = 5 + * scala> x.run(input) + * res0: Option[(Int, Double)] = Some((6,5.0)) + * scala> xt.run((input, "hello")) + * res1: Option[(GlobalEnv, Double)] = Some(((6,hello),5.0)) + * }}} + */ + def transformS[R](f: R => SA, g: (R, SB) => R)(implicit F: Functor[F]): IndexedStateT[F, R, R, A] = + StateT.applyF(F.map(runF) { sfsa => + { r: R => + val sa = f(r) + val fsba = sfsa(sa) + F.map(fsba) { case (sb, a) => (g(r, sb), a) } + } + }) + + /** + * Modify the state (`S`) component. + */ + def modify[SC](f: SB => SC)(implicit F: Functor[F]): IndexedStateT[F, SA, SC, A] = + transform((s, a) => (f(s), a)) + + /** + * Inspect a value from the input state, without modifying the state. + */ + def inspect[B](f: SB => B)(implicit F: Functor[F]): IndexedStateT[F, SA, SB, B] = + transform((s, _) => (s, f(s))) + + /** + * Get the input state, without modifying the state. + */ + def get(implicit F: Functor[F]): IndexedStateT[F, SA, SB, SB] = + inspect(identity) +} + +private[data] trait CommonStateTConstructors { + def pure[F[_], S, A](a: A)(implicit F: Applicative[F]): IndexedStateT[F, S, S, A] = + IndexedStateT(s => F.pure((s, a))) + + def lift[F[_], S, A](fa: F[A])(implicit F: Applicative[F]): IndexedStateT[F, S, S, A] = + IndexedStateT(s => F.map(fa)(a => (s, a))) + + def inspect[F[_], S, A](f: S => A)(implicit F: Applicative[F]): IndexedStateT[F, S, S, A] = + IndexedStateT(s => F.pure((s, f(s)))) + + def inspectF[F[_], S, A](f: S => F[A])(implicit F: Applicative[F]): IndexedStateT[F, S, S, A] = + IndexedStateT(s => F.map(f(s))(a => (s, a))) + + def get[F[_], S](implicit F: Applicative[F]): IndexedStateT[F, S, S, S] = + IndexedStateT(s => F.pure((s, s))) +} + +object IndexedStateT extends IndexedStateTInstances with CommonStateTConstructors { + def apply[F[_], SA, SB, A](f: SA => F[(SB, A)])(implicit F: Applicative[F]): IndexedStateT[F, SA, SB, A] = + new IndexedStateT(F.pure(f)) + + def applyF[F[_], SA, SB, A](runF: F[SA => F[(SB, A)]]): IndexedStateT[F, SA, SB, A] = + new IndexedStateT(runF) + + def modify[F[_], SA, SB](f: SA => SB)(implicit F: Applicative[F]): IndexedStateT[F, SA, SB, Unit] = + IndexedStateT(sa => F.pure((f(sa), ()))) + + def modifyF[F[_], SA, SB](f: SA => F[SB])(implicit F: Applicative[F]): IndexedStateT[F, SA, SB, Unit] = + IndexedStateT(s => F.map(f(s))(s => (s, ()))) + + def set[F[_], SA, SB](sb: SB)(implicit F: Applicative[F]): IndexedStateT[F, SA, SB, Unit] = + IndexedStateT(_ => F.pure((sb, ()))) + + def setF[F[_], SA, SB](fsb: F[SB])(implicit F: Applicative[F]): IndexedStateT[F, SA, SB, Unit] = + IndexedStateT(_ => F.map(fsb)(s => (s, ()))) +} + +private[data] abstract class StateTFunctions extends CommonStateTConstructors { + def apply[F[_], S, A](f: S => F[(S, A)])(implicit F: Applicative[F]): StateT[F, S, A] = + IndexedStateT(f) + + def applyF[F[_], S, A](runF: F[S => F[(S, A)]]): StateT[F, S, A] = + IndexedStateT.applyF(runF) + + def modify[F[_], S](f: S => S)(implicit F: Applicative[F]): StateT[F, S, Unit] = + apply(sa => F.pure((f(sa), ()))) + + def modifyF[F[_], S](f: S => F[S])(implicit F: Applicative[F]): StateT[F, S, Unit] = + apply(s => F.map(f(s))(s => (s, ()))) + + def set[F[_], S](s: S)(implicit F: Applicative[F]): StateT[F, S, Unit] = + apply(_ => F.pure((s, ()))) + + def setF[F[_], S](fs: F[S])(implicit F: Applicative[F]): StateT[F, S, Unit] = + apply(_ => F.map(fs)(s => (s, ()))) +} + +private[data] sealed abstract class IndexedStateTInstances extends IndexedStateTInstances1 { + implicit def catsDataAlternativeForIndexedStateT[F[_], S](implicit FM: Monad[F], + FA: Alternative[F]): Alternative[IndexedStateT[F, S, S, ?]] with Monad[IndexedStateT[F, S, S, ?]] = + new IndexedStateTAlternative[F, S] { implicit def F = FM; implicit def G = FA } +} + +private[data] sealed abstract class IndexedStateTInstances1 extends IndexedStateTInstances2 { + implicit def catsDataMonadErrorForIndexedStateT[F[_], S, E](implicit F0: MonadError[F, E]): MonadError[IndexedStateT[F, S, S, ?], E] = + new IndexedStateTMonadError[F, S, E] { implicit def F = F0 } + + implicit def catsDataSemigroupKForIndexedStateT[F[_], SA, SB](implicit F0: Monad[F], G0: SemigroupK[F]): SemigroupK[IndexedStateT[F, SA, SB, ?]] = + new IndexedStateTSemigroupK[F, SA, SB] { implicit def F = F0; implicit def G = G0 } +} + +private[data] sealed abstract class IndexedStateTInstances2 extends IndexedStateTInstances3 { + implicit def catsDataMonadForIndexedStateT[F[_], S](implicit F0: Monad[F]): Monad[IndexedStateT[F, S, S, ?]] = + new IndexedStateTMonad[F, S] { implicit def F = F0 } +} + +private[data] sealed abstract class IndexedStateTInstances3 extends IndexedStateTInstances4 { + implicit def catsDataFunctorForIndexedStateT[F[_], SA, SB](implicit F0: Functor[F]): Functor[IndexedStateT[F, SA, SB, ?]] = + new IndexedStateTFunctor[F, SA, SB] { implicit def F = F0 } + + implicit def catsDataContravariantForIndexedStateT[F[_], SB, V](implicit F0: Functor[F]): Contravariant[IndexedStateT[F, ?, SB, V]] = + new IndexedStateTContravariant[F, SB, V] { implicit def F = F0 } + + implicit def catsDataProfunctorForIndexedStateT[F[_], V](implicit F0: Functor[F]): Profunctor[IndexedStateT[F, ?, ?, V]] = + new IndexedStateTProfunctor[F, V] { implicit def F = F0 } + + implicit def catsDataBifunctorForIndexedStateT[F[_], SA](implicit F0: Functor[F]): Bifunctor[IndexedStateT[F, SA, ?, ?]] = + new IndexedStateTBifunctor[F, SA] { implicit def F = F0 } +} + +private[data] sealed abstract class IndexedStateTInstances4 { + implicit def catsDataStrongForIndexedStateT[F[_], V](implicit F0: Monad[F]): Strong[IndexedStateT[F, ?, ?, V]] = + new IndexedStateTStrong[F, V] { implicit def F = F0 } +} + +// To workaround SI-7139 `object State` needs to be defined inside the package object +// together with the type alias. +private[data] abstract class StateFunctions { + + def apply[S, A](f: S => (S, A)): State[S, A] = + IndexedStateT.applyF(Now((s: S) => Now(f(s)))) + + /** + * Return `a` and maintain the input state. + */ + def pure[S, A](a: A): State[S, A] = State(s => (s, a)) + + /** + * Modify the input state and return Unit. + */ + def modify[S](f: S => S): State[S, Unit] = State(s => (f(s), ())) + + /** + * Inspect a value from the input state, without modifying the state. + */ + def inspect[S, T](f: S => T): State[S, T] = State(s => (s, f(s))) + + /** + * Return the input state without modifying it. + */ + def get[S]: State[S, S] = inspect(identity) + + /** + * Set the state to `s` and return Unit. + */ + def set[S](s: S): State[S, Unit] = State(_ => (s, ())) +} + +private[data] sealed abstract class IndexedStateTFunctor[F[_], SA, SB] extends Functor[IndexedStateT[F, SA, SB, ?]] { + implicit def F: Functor[F] + + override def map[A, B](fa: IndexedStateT[F, SA, SB, A])(f: A => B): IndexedStateT[F, SA, SB, B] = + fa.map(f) +} + +private[data] sealed abstract class IndexedStateTContravariant[F[_], SB, V] extends Contravariant[IndexedStateT[F, ?, SB, V]] { + implicit def F: Functor[F] + + override def contramap[A, B](fa: IndexedStateT[F, A, SB, V])(f: B => A): IndexedStateT[F, B, SB, V] = + fa.contramap(f) +} + +private[data] sealed abstract class IndexedStateTBifunctor[F[_], SA] extends Bifunctor[IndexedStateT[F, SA, ?, ?]] { + implicit def F: Functor[F] + + def bimap[A, B, C, D](fab: IndexedStateT[F, SA, A, B])(f: A => C, g: B => D): IndexedStateT[F, SA, C, D] = + fab.bimap(f, g) +} + +private[data] sealed abstract class IndexedStateTProfunctor[F[_], V] extends Profunctor[IndexedStateT[F, ?, ?, V]] { + implicit def F: Functor[F] + + def dimap[A, B, C, D](fab: IndexedStateT[F, A, B, V])(f: C => A)(g: B => D): IndexedStateT[F, C, D, V] = + fab.dimap(f)(g) +} + +private[data] sealed abstract class IndexedStateTStrong[F[_], V] extends IndexedStateTProfunctor[F, V] with Strong[IndexedStateT[F, ?, ?, V]] { + implicit def F: Monad[F] + + def first[A, B, C](fa: IndexedStateT[F, A, B, V]): IndexedStateT[F, (A, C), (B, C), V] = + IndexedStateT { case (a, c) => + F.map(fa.run(a)) { case (b, v) => + ((b, c), v) + } + } + + def second[A, B, C](fa: IndexedStateT[F, A, B, V]): IndexedStateT[F, (C, A), (C, B), V] = + first(fa).dimap((_: (C, A)).swap)(_.swap) +} + +private[data] sealed abstract class IndexedStateTMonad[F[_], S] extends IndexedStateTFunctor[F, S, S] with Monad[IndexedStateT[F, S, S, ?]] { + implicit def F: Monad[F] + + def pure[A](a: A): IndexedStateT[F, S, S, A] = + IndexedStateT.pure(a) + + def flatMap[A, B](fa: IndexedStateT[F, S, S, A])(f: A => IndexedStateT[F, S, S, B]): IndexedStateT[F, S, S, B] = + fa.flatMap(f) + + def tailRecM[A, B](a: A)(f: A => IndexedStateT[F, S, S, Either[A, B]]): IndexedStateT[F, S, S, B] = + IndexedStateT[F, S, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) { + case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) } + }) +} + +private[data] sealed abstract class IndexedStateTSemigroupK[F[_], SA, SB] extends SemigroupK[IndexedStateT[F, SA, SB, ?]] { + implicit def F: Monad[F] + implicit def G: SemigroupK[F] + + def combineK[A](x: IndexedStateT[F, SA, SB, A], y: IndexedStateT[F, SA, SB, A]): IndexedStateT[F, SA, SB, A] = + IndexedStateT(s => G.combineK(x.run(s), y.run(s))) +} + +private[data] sealed abstract class IndexedStateTAlternative[F[_], S] extends IndexedStateTMonad[F, S] with Alternative[IndexedStateT[F, S, S, ?]] { + def G: Alternative[F] + + def combineK[A](x: IndexedStateT[F, S, S, A], y: IndexedStateT[F, S, S, A]): IndexedStateT[F, S, S, A] = + IndexedStateT[F, S, S, A](s => G.combineK(x.run(s), y.run(s)))(G) + + def empty[A]: IndexedStateT[F, S, S, A] = + IndexedStateT.lift[F, S, A](G.empty[A])(G) +} + +private[data] sealed abstract class IndexedStateTMonadError[F[_], S, E] extends IndexedStateTMonad[F, S] + with MonadError[IndexedStateT[F, S, S, ?], E] { + implicit def F: MonadError[F, E] + + def raiseError[A](e: E): IndexedStateT[F, S, S, A] = IndexedStateT.lift(F.raiseError(e)) + + def handleErrorWith[A](fa: IndexedStateT[F, S, S, A])(f: E => IndexedStateT[F, S, S, A]): IndexedStateT[F, S, S, A] = + IndexedStateT(s => F.handleErrorWith(fa.run(s))(e => f(e).run(s))) +} diff --git a/core/src/main/scala/cats/data/StateT.scala b/core/src/main/scala/cats/data/StateT.scala deleted file mode 100644 index 82c5d21f3e..0000000000 --- a/core/src/main/scala/cats/data/StateT.scala +++ /dev/null @@ -1,280 +0,0 @@ -package cats -package data - -import cats.syntax.either._ - -/** - * `StateT[F, S, A]` is similar to `Kleisli[F, S, A]` in that it takes an `S` - * argument and produces an `A` value wrapped in `F`. However, it also produces - * an `S` value representing the updated state (which is wrapped in the `F` - * context along with the `A` value. - */ -final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable { - - def flatMap[B](fas: A => StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, B] = - StateT.applyF(F.map(runF) { sfsa => - sfsa.andThen { fsa => - F.flatMap(fsa) { case (s, a) => - fas(a).run(s) - } - } - }) - - def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): StateT[F, S, B] = - StateT.applyF(F.map(runF) { sfsa => - sfsa.andThen { fsa => - F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) } - } - }) - - def map[B](f: A => B)(implicit F: Functor[F]): StateT[F, S, B] = - transform { case (s, a) => (s, f(a)) } - - /** - * Run with the provided initial state value - */ - def run(initial: S)(implicit F: FlatMap[F]): F[(S, A)] = - F.flatMap(runF)(f => f(initial)) - - /** - * Run with the provided initial state value and return the final state - * (discarding the final value). - */ - def runS(s: S)(implicit F: FlatMap[F]): F[S] = F.map(run(s))(_._1) - - /** - * Run with the provided initial state value and return the final value - * (discarding the final state). - */ - def runA(s: S)(implicit F: FlatMap[F]): F[A] = F.map(run(s))(_._2) - - /** - * Run with `S`'s empty monoid value as the initial state. - */ - def runEmpty(implicit S: Monoid[S], F: FlatMap[F]): F[(S, A)] = run(S.empty) - - /** - * Run with `S`'s empty monoid value as the initial state and return the final - * state (discarding the final value). - */ - def runEmptyS(implicit S: Monoid[S], F: FlatMap[F]): F[S] = runS(S.empty) - - /** - * Run with `S`'s empty monoid value as the initial state and return the final - * value (discarding the final state). - */ - def runEmptyA(implicit S: Monoid[S], F: FlatMap[F]): F[A] = runA(S.empty) - - /** - * Like [[map]], but also allows the state (`S`) value to be modified. - */ - def transform[B](f: (S, A) => (S, B))(implicit F: Functor[F]): StateT[F, S, B] = - StateT.applyF( - F.map(runF) { sfsa => - sfsa.andThen { fsa => - F.map(fsa) { case (s, a) => f(s, a) } - } - }) - - /** - * Like [[transform]], but allows the context to change from `F` to `G`. - */ - def transformF[G[_], B](f: F[(S, A)] => G[(S, B)])(implicit F: FlatMap[F], G: Applicative[G]): StateT[G, S, B] = - StateT(s => f(run(s))) - - /** - * Transform the state used. - * - * This is useful when you are working with many focused `StateT`s and want to pass in a - * global state containing the various states needed for each individual `StateT`. - * - * {{{ - * scala> import cats.implicits._ // needed for StateT.apply - * scala> type GlobalEnv = (Int, String) - * scala> val x: StateT[Option, Int, Double] = StateT((x: Int) => Option((x + 1, x.toDouble))) - * scala> val xt: StateT[Option, GlobalEnv, Double] = x.transformS[GlobalEnv](_._1, (t, i) => (i, t._2)) - * scala> val input = 5 - * scala> x.run(input) - * res0: Option[(Int, Double)] = Some((6,5.0)) - * scala> xt.run((input, "hello")) - * res1: Option[(GlobalEnv, Double)] = Some(((6,hello),5.0)) - * }}} - */ - def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Functor[F]): StateT[F, R, A] = - StateT.applyF(F.map(runF) { sfsa => - { r: R => - val s = f(r) - val fsa = sfsa(s) - F.map(fsa) { case (s, a) => (g(r, s), a) } - } - }) - - /** - * Modify the state (`S`) component. - */ - def modify(f: S => S)(implicit F: Functor[F]): StateT[F, S, A] = - transform((s, a) => (f(s), a)) - - /** - * Inspect a value from the input state, without modifying the state. - */ - def inspect[B](f: S => B)(implicit F: Functor[F]): StateT[F, S, B] = - transform((s, _) => (s, f(s))) - - /** - * Get the input state, without modifying the state. - */ - def get(implicit F: Functor[F]): StateT[F, S, S] = - inspect(identity) -} - -object StateT extends StateTInstances { - def apply[F[_], S, A](f: S => F[(S, A)])(implicit F: Applicative[F]): StateT[F, S, A] = - new StateT(F.pure(f)) - - def applyF[F[_], S, A](runF: F[S => F[(S, A)]]): StateT[F, S, A] = - new StateT(runF) - - def pure[F[_], S, A](a: A)(implicit F: Applicative[F]): StateT[F, S, A] = - StateT(s => F.pure((s, a))) - - def lift[F[_], S, A](fa: F[A])(implicit F: Applicative[F]): StateT[F, S, A] = - StateT(s => F.map(fa)(a => (s, a))) - - def inspect[F[_], S, A](f: S => A)(implicit F: Applicative[F]): StateT[F, S, A] = - StateT(s => F.pure((s, f(s)))) - - def inspectF[F[_], S, A](f: S => F[A])(implicit F: Applicative[F]): StateT[F, S, A] = - StateT(s => F.map(f(s))(a => (s, a))) - - def modify[F[_], S](f: S => S)(implicit F: Applicative[F]): StateT[F, S, Unit] = - StateT(s => F.pure((f(s), ()))) - - def modifyF[F[_], S](f: S => F[S])(implicit F: Applicative[F]): StateT[F, S, Unit] = - StateT(s => F.map(f(s))(s => (s, ()))) - - def get[F[_], S](implicit F: Applicative[F]): StateT[F, S, S] = - StateT(s => F.pure((s, s))) - - def set[F[_], S](s: S)(implicit F: Applicative[F]): StateT[F, S, Unit] = - StateT(_ => F.pure((s, ()))) - - def setF[F[_], S](fs: F[S])(implicit F: Applicative[F]): StateT[F, S, Unit] = - StateT(_ => F.map(fs)(s => (s, ()))) -} - -private[data] sealed trait StateTInstances extends StateTInstances1 { - implicit def catsDataAlternativeForStateT[F[_], S](implicit FM: Monad[F], FA: Alternative[F]): Alternative[StateT[F, S, ?]] = - new StateTAlternative[F, S] { implicit def F = FM; implicit def G = FA } -} - -private[data] sealed trait StateTInstances1 extends StateTInstances2 { - implicit def catsDataMonadErrorForStateT[F[_], S, E](implicit F0: MonadError[F, E]): MonadError[StateT[F, S, ?], E] = - new StateTMonadError[F, S, E] { implicit def F = F0 } - - implicit def catsDataSemigroupKForStateT[F[_], S](implicit F0: Monad[F], G0: SemigroupK[F]): SemigroupK[StateT[F, S, ?]] = - new StateTSemigroupK[F, S] { implicit def F = F0; implicit def G = G0 } -} - -private[data] sealed trait StateTInstances2 extends StateTInstances3 { - implicit def catsDataMonadForStateT[F[_], S](implicit F0: Monad[F]): Monad[StateT[F, S, ?]] = - new StateTMonad[F, S] { implicit def F = F0 } -} - -private[data] sealed trait StateTInstances3 { - implicit def catsDataFunctorForStateT[F[_], S](implicit F0: Functor[F]): Functor[StateT[F, S, ?]] = - new StateTFunctor[F, S] { implicit def F = F0 } -} - -// To workaround SI-7139 `object State` needs to be defined inside the package object -// together with the type alias. -private[data] abstract class StateFunctions { - - def apply[S, A](f: S => (S, A)): State[S, A] = - StateT.applyF(Now((s: S) => Now(f(s)))) - - /** - * Return `a` and maintain the input state. - */ - def pure[S, A](a: A): State[S, A] = State(s => (s, a)) - - /** - * Modify the input state and return Unit. - */ - def modify[S](f: S => S): State[S, Unit] = State(s => (f(s), ())) - - /** - * Inspect a value from the input state, without modifying the state. - */ - def inspect[S, T](f: S => T): State[S, T] = State(s => (s, f(s))) - - /** - * Return the input state without modifying it. - */ - def get[S]: State[S, S] = inspect(identity) - - /** - * Set the state to `s` and return Unit. - */ - def set[S](s: S): State[S, Unit] = State(_ => (s, ())) -} - -private[data] sealed trait StateTFunctor[F[_], S] extends Functor[StateT[F, S, ?]] { - implicit def F: Functor[F] - - override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f) -} - -private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] with StateTFunctor[F, S] { - implicit def F: Monad[F] - - def pure[A](a: A): StateT[F, S, A] = - StateT.pure(a) - - def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] = - fa.flatMap(f) - - def tailRecM[A, B](a: A)(f: A => StateT[F, S, Either[A, B]]): StateT[F, S, B] = - StateT[F, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) { - case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) } - }) -} - -private[data] sealed trait StateTSemigroupK[F[_], S] extends SemigroupK[StateT[F, S, ?]] { - implicit def F: Monad[F] - implicit def G: SemigroupK[F] - - def combineK[A](x: StateT[F, S, A], y: StateT[F, S, A]): StateT[F, S, A] = - StateT(s => G.combineK(x.run(s), y.run(s))) -} - -private[data] sealed trait StateTAlternative[F[_], S] extends Alternative[StateT[F, S, ?]] with StateTFunctor[F, S] { - implicit def F: Monad[F] - def G: Alternative[F] - - def combineK[A](x: StateT[F, S, A], y: StateT[F, S, A]): StateT[F, S, A] = - StateT[F, S, A](s => G.combineK(x.run(s), y.run(s)))(G) - - def pure[A](a: A): StateT[F, S, A] = - StateT.pure[F, S, A](a)(G) - - def empty[A]: StateT[F, S, A] = - StateT.lift[F, S, A](G.empty[A])(G) - - override def ap[A, B](ff: StateT[F, S, A => B])(fa: StateT[F, S, A]): StateT[F, S, B] = - StateT[F, S, B]((s: S) => - F.flatMap(ff.run(s)) { sab => - val (sn, f) = sab - F.map(fa.run(sn)) { case (snn, a) => (snn, f(a)) } - } - ) -} - -private[data] sealed trait StateTMonadError[F[_], S, E] extends StateTMonad[F, S] with MonadError[StateT[F, S, ?], E] { - implicit def F: MonadError[F, E] - - def raiseError[A](e: E): StateT[F, S, A] = StateT.lift(F.raiseError(e)) - - def handleErrorWith[A](fa: StateT[F, S, A])(f: E => StateT[F, S, A]): StateT[F, S, A] = - StateT(s => F.handleErrorWith(fa.run(s))(e => f(e).run(s))) -} diff --git a/core/src/main/scala/cats/data/package.scala b/core/src/main/scala/cats/data/package.scala index e367b1599c..0d82370fb3 100644 --- a/core/src/main/scala/cats/data/package.scala +++ b/core/src/main/scala/cats/data/package.scala @@ -29,6 +29,15 @@ package object data { def tell[L](l: L): Writer[L, Unit] = WriterT.tell(l) } + /** + * `StateT[F, S, A]` is similar to `Kleisli[F, S, A]` in that it takes an `S` + * argument and produces an `A` value wrapped in `F`. However, it also produces + * an `S` value representing the updated state (which is wrapped in the `F` + * context along with the `A` value. + */ + type StateT[F[_], S, A] = IndexedStateT[F, S, S, A] + object StateT extends StateTFunctions + type State[S, A] = StateT[Eval, S, A] object State extends StateFunctions diff --git a/free/src/test/scala/cats/free/FreeTTests.scala b/free/src/test/scala/cats/free/FreeTTests.scala index a5f2f30f1e..51329cb8f8 100644 --- a/free/src/test/scala/cats/free/FreeTTests.scala +++ b/free/src/test/scala/cats/free/FreeTTests.scala @@ -170,9 +170,9 @@ object FreeTTests extends FreeTTestsInstances { trait FreeTTestsInstances { import FreeT._ - import StateT._ + import IndexedStateT._ import cats.kernel.instances.option._ - import cats.tests.StateTTests._ + import cats.tests.IndexedStateTTests._ import CartesianTests._ type IntState[A] = State[Int, A] diff --git a/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala b/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala index 963d71faa2..4e719bd641 100644 --- a/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala +++ b/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala @@ -153,7 +153,7 @@ object arbitrary extends ArbitraryInstances0 { Arbitrary(FG.arbitrary.map(Nested(_))) implicit def catsLawArbitraryForState[S: Arbitrary: Cogen, A: Arbitrary]: Arbitrary[State[S, A]] = - catsLawArbitraryForStateT[Eval, S, A] + catsLawArbitraryForIndexedStateT[Eval, S, S, A] implicit def catsLawArbitraryForReader[A: Arbitrary: Cogen, B: Arbitrary]: Arbitrary[Reader[A, B]] = catsLawsArbitraryForKleisli[Id, A, B] @@ -168,8 +168,8 @@ object arbitrary extends ArbitraryInstances0 { private[discipline] sealed trait ArbitraryInstances0 { - implicit def catsLawArbitraryForStateT[F[_], S, A](implicit F: Arbitrary[F[S => F[(S, A)]]]): Arbitrary[StateT[F, S, A]] = - Arbitrary(F.arbitrary.map(StateT.applyF)) + implicit def catsLawArbitraryForIndexedStateT[F[_], SA, SB, A](implicit F: Arbitrary[F[SA => F[(SB, A)]]]): Arbitrary[IndexedStateT[F, SA, SB, A]] = + Arbitrary(F.arbitrary.map(IndexedStateT.applyF)) implicit def catsLawsArbitraryForWriterT[F[_], L, V](implicit F: Arbitrary[F[(L, V)]]): Arbitrary[WriterT[F, L, V]] = Arbitrary(F.arbitrary.map(WriterT(_))) diff --git a/tests/src/test/scala/cats/tests/IndexedStateTTests.scala b/tests/src/test/scala/cats/tests/IndexedStateTTests.scala new file mode 100644 index 0000000000..cc77122c0f --- /dev/null +++ b/tests/src/test/scala/cats/tests/IndexedStateTTests.scala @@ -0,0 +1,396 @@ +package cats +package tests + +import cats.data.{State, StateT, IndexedStateT, EitherT} +import cats.functor.{Contravariant, Bifunctor, Profunctor, Strong} +import cats.kernel.instances.tuple._ +import cats.laws.discipline._ +import cats.laws.discipline.eq._ +import cats.laws.discipline.arbitrary._ +import org.scalacheck.Arbitrary + +class IndexedStateTTests extends CatsSuite { + + implicit override val generatorDrivenConfig: PropertyCheckConfiguration = + checkConfiguration.copy(sizeRange = 5) + + import IndexedStateTTests._ + + test("basic state usage"){ + add1.run(1).value should === (2 -> 1) + } + + test("basic IndexedStateT usage") { + val listHead: IndexedStateT[Id, List[Int], Option[Int], Unit] = IndexedStateT.modify(_.headOption) + val getOrElse: IndexedStateT[Id, Option[Int], Int, Unit] = IndexedStateT.modify(_.getOrElse(0)) + val toString: IndexedStateT[Id, Int, String, Unit] = IndexedStateT.modify(_.toString) + + val composite = for { + _ <- listHead + _ <- getOrElse + _ <- toString + r <- IndexedStateT.get[Id, String] + } yield r + + composite.run(List(1, 2, 3)) should === (("1", "1")) + composite.run(Nil) should === (("0", "0")) + } + + test("traversing state is stack-safe"){ + val ns = (0 to 70000).toList + val x = ns.traverse(_ => add1) + x.runS(0).value should === (70001) + } + + test("State.pure, StateT.pure and IndexedStateT.pure are consistent"){ + forAll { (s: String, i: Int) => + val state: State[String, Int] = State.pure(i) + val stateT: State[String, Int] = StateT.pure(i) + val indexedStateT: State[String, Int] = IndexedStateT.pure(i) + + state.run(s) should === (stateT.run(s)) + state.run(s) should === (indexedStateT.run(s)) + } + } + + test("State.get, StateT.get and IndexedStateT.get are consistent") { + forAll{ (s: String) => + val state: State[String, String] = State.get + val stateT: State[String, String] = StateT.get + val indexedStateT: State[String, String] = IndexedStateT.get + + state.run(s) should === (stateT.run(s)) + state.run(s) should === (indexedStateT.run(s)) + } + } + + test("State.inspect, StateT.inspect and IndexedStateT.inspect are consistent") { + forAll { (s: String, f: String => Int) => + val state: State[String, Int] = State.inspect(f) + val stateT: State[String, Int] = StateT.inspect(f) + val indexedStateT: State[String, Int] = IndexedStateT.inspect(f) + + state.run(s) should === (stateT.run(s)) + state.run(s) should === (indexedStateT.run(s)) + } + } + + test("State.inspect, StateT.inspectF and IndexedStateT.inspectF are consistent") { + forAll { (s: String, f: String => Int) => + val state: State[String, Int] = State.inspect(f) + val stateT: State[String, Int] = StateT.inspectF(f.andThen(Eval.now)) + val indexedStateT: State[String, Int] = IndexedStateT.inspectF(f.andThen(Eval.now)) + + state.run(s) should === (stateT.run(s)) + state.run(s) should === (indexedStateT.run(s)) + } + } + + test("State.modify, StateT.modify and IndexedStateT.modify are consistent") { + forAll { (s: String, f: String => String) => + val state: State[String, Unit] = State.modify(f) + val stateT: State[String, Unit] = StateT.modify(f) + val indexedStateT: State[String, Unit] = IndexedStateT.modify(f) + + state.run(s) should === (stateT.run(s)) + state.run(s) should === (indexedStateT.run(s)) + } + } + + test("State.modify, StateT.modifyF and IndexedStateT.modifyF are consistent") { + forAll { (s: String, f: String => String) => + val state: State[String, Unit] = State.modify(f) + val stateT: State[String, Unit] = StateT.modifyF(f.andThen(Eval.now)) + val indexedStateT: State[String, Unit] = IndexedStateT.modifyF(f.andThen(Eval.now)) + + state.run(s) should === (stateT.run(s)) + state.run(s) should === (indexedStateT.run(s)) + } + } + + test("State.pure, StateT.lift and IndexedStateT.lift are consistent") { + forAll { (s: String, i: Int) => + val state: State[String, Int] = State.pure(i) + val stateT: State[String, Int] = StateT.lift(Eval.now(i)) + val indexedStateT: State[String, Int] = IndexedStateT.lift(Eval.now(i)) + + state.run(s) should === (stateT.run(s)) + state.run(s) should === (indexedStateT.run(s)) + } + } + + test("State.set, StateT.set and IndexedStateT.set are consistent") { + forAll { (init: String, s: String) => + val state: State[String, Unit] = State.set(s) + val stateT: StateT[Eval, String, Unit] = StateT.set(s) + val indexedStateT: StateT[Eval, String, Unit] = IndexedStateT.set(s) + + state.run(init) should === (stateT.run(init)) + state.run(init) should === (indexedStateT.run(init)) + } + } + + test("State.set, StateT.setF and IndexedStateT.setF are consistent") { + forAll { (init: String, s: String) => + val state: State[String, Unit] = State.set(s) + val stateT: StateT[Eval, String, Unit] = StateT.setF(Eval.now(s)) + val indexedStateT: StateT[Eval, String, Unit] = IndexedStateT.setF(Eval.now(s)) + + state.run(init) should === (stateT.run(init)) + state.run(init) should === (indexedStateT.run(init)) + } + } + + test("Cartesian syntax is usable on State") { + val x = add1 *> add1 + x.runS(0).value should === (2) + } + + test("Singleton and instance inspect are consistent"){ + forAll { (s: String, i: Int) => + State.inspect[Int, String](_.toString).run(i) should === ( + State.pure[Int, Unit](()).inspect(_.toString).run(i)) + } + } + + test("flatMap and flatMapF consistent") { + forAll { (stateT: StateT[Option, Long, Int], f: Int => Option[Int]) => + stateT.flatMap(a => StateT(s => f(a).map(b => (s, b)))) should === (stateT.flatMapF(f)) + } + } + + test("runEmpty, runEmptyS, and runEmptyA consistent"){ + forAll { (f: StateT[List, Long, Int]) => + (f.runEmptyS zip f.runEmptyA) should === (f.runEmpty) + } + } + + test("modify identity is a noop"){ + forAll { (f: StateT[List, Long, Int]) => + f.modify(identity) should === (f) + } + } + + test("modify modifies state"){ + forAll { (f: StateT[List, Long, Int], g: Long => Long, initial: Long) => + f.modify(g).runS(initial) should === (f.runS(initial).map(g)) + } + } + + test("modify doesn't affect A value"){ + forAll { (f: StateT[List, Long, Int], g: Long => Long, initial: Long) => + f.modify(g).runA(initial) should === (f.runA(initial)) + } + } + + test("State.modify equivalent to get then set"){ + forAll { (f: Long => Long) => + val s1 = for { + l <- State.get[Long] + _ <- State.set(f(l)) + } yield () + + val s2 = State.modify(f) + + s1 should === (s2) + } + } + + test("StateT.set equivalent to modify ignoring first param") { + forAll { (init: String, update: String) => + val s1 = StateT.modify[Eval, String](_ => update) + val s2 = StateT.set[Eval, String](update) + s1.run(init) should === (s2.run(init)) + } + } + + test("StateT.setF equivalent to modifyF ignoring first param") { + forAll { (init: String, update: String) => + val s1 = StateT.modifyF[Eval, String](_ => Eval.now(update)) + val s2 = StateT.setF(Eval.now(update)) + s1.run(init) should === (s2.run(init)) + } + } + + test(".get and then .run produces same state as value"){ + forAll { (s: State[Long, Int], initial: Long) => + val (finalS, finalA) = s.get.run(initial).value + finalS should === (finalA) + } + } + + test(".get equivalent to flatMap with State.get"){ + forAll { (s: State[Long, Int]) => + s.get should === (s.flatMap(_ => State.get)) + } + } + + test("StateT#transformS with identity is identity") { + forAll { (s: StateT[List, Long, Int]) => + s.transformS[Long](identity, (s, i) => i) should === (s) + } + } + + test("StateT#transformS modifies state") { + final case class Env(int: Int, str: String) + val x = StateT((x: Int) => Option((x + 1, x))) + val xx = x.transformS[Env](_.int, (e, i) => e.copy(int = i)) + val input = 5 + + val got = x.run(input) + val expected = xx.run(Env(input, "hello")).map { case (e, i) => (e.int, i) } + got should === (expected) + } + + + implicit val iso = CartesianTests.Isomorphisms.invariant[IndexedStateT[ListWrapper, String, Int, ?]](IndexedStateT.catsDataFunctorForIndexedStateT(ListWrapper.monad)) + + { + // F has a Functor + implicit val F: Functor[ListWrapper] = ListWrapper.functor + // We only need a Functor on F to find a Functor on StateT + Functor[IndexedStateT[ListWrapper, String, Int, ?]] + } + + { + // We only need a Functor to derive a Contravariant for IndexedStateT + implicit val F: Functor[ListWrapper] = ListWrapper.monad + Contravariant[IndexedStateT[ListWrapper, ?, Int, String]] + } + + { + // We only need a Functor to derive a Bifunctor for IndexedStateT + implicit val F: Functor[ListWrapper] = ListWrapper.monad + Bifunctor[IndexedStateT[ListWrapper, Int, ?, ?]] + } + + { + // We only need a Functor to derive a Profunctor for IndexedStateT + implicit val F: Functor[ListWrapper] = ListWrapper.monad + Profunctor[IndexedStateT[ListWrapper, ?, ?, String]] + } + + { + // F needs a Monad to do Eq on StateT + implicit val F: Monad[ListWrapper] = ListWrapper.monad + implicit val FS: Functor[IndexedStateT[ListWrapper, String, Int, ?]] = IndexedStateT.catsDataFunctorForIndexedStateT + + checkAll("IndexedStateT[ListWrapper, String, Int, Int]", FunctorTests[IndexedStateT[ListWrapper, String, Int, ?]].functor[Int, Int, Int]) + checkAll("Functor[IndexedStateT[ListWrapper, Int, ?]]", SerializableTests.serializable(Functor[IndexedStateT[ListWrapper, String, Int, ?]])) + + Functor[IndexedStateT[ListWrapper, String, Int, ?]] + } + + { + implicit val F: Monad[ListWrapper] = ListWrapper.monad + implicit val FS: Contravariant[IndexedStateT[ListWrapper, ?, Int, Int]] = IndexedStateT.catsDataContravariantForIndexedStateT + + checkAll("IndexedStateT[ListWrapper, Int, Int, Int]", ContravariantTests[IndexedStateT[ListWrapper, ?, Int, Int]].contravariant[Int, Int, Int]) + checkAll("Contravariant[IndexedStateT[ListWrapper, ?, Int, Int]]", SerializableTests.serializable(Contravariant[IndexedStateT[ListWrapper, ?, Int, Int]])) + + Contravariant[IndexedStateT[ListWrapper, ?, Int, Int]] + } + + { + implicit val F: Monad[ListWrapper] = ListWrapper.monad + implicit val FS: Bifunctor[IndexedStateT[ListWrapper, Int, ?, ?]] = IndexedStateT.catsDataBifunctorForIndexedStateT + + checkAll("IndexedStateT[ListWrapper, Int, String, Int]", BifunctorTests[IndexedStateT[ListWrapper, Int, ?, ?]].bifunctor[String, String, String, Int, Int, Int]) + checkAll("Bifunctor[IndexedStateT[ListWrapper, Int, ?, ?]]", SerializableTests.serializable(Bifunctor[IndexedStateT[ListWrapper, Int, ?, ?]])) + + Bifunctor[IndexedStateT[ListWrapper, Int, ?, ?]] + } + + { + implicit val F: Monad[ListWrapper] = ListWrapper.monad + implicit val FS: Profunctor[IndexedStateT[ListWrapper, ?, ?, Int]] = IndexedStateT.catsDataProfunctorForIndexedStateT + + checkAll("IndexedStateT[ListWrapper, String, Int, Int]", ProfunctorTests[IndexedStateT[ListWrapper, ?, ?, Int]].profunctor[String, String, String, Int, Int, Int]) + checkAll("Profunctor[IndexedStateT[ListWrapper, ?, ?, Int]]", SerializableTests.serializable(Profunctor[IndexedStateT[ListWrapper, ?, ?, Int]])) + + Profunctor[IndexedStateT[ListWrapper, ?, ?, Int]] + } + + { + implicit val F: Monad[ListWrapper] = ListWrapper.monad + implicit val FS: Strong[IndexedStateT[ListWrapper, ?, ?, Int]] = IndexedStateT.catsDataStrongForIndexedStateT + + checkAll("IndexedStateT[ListWrapper, String, Int, Int]", StrongTests[IndexedStateT[ListWrapper, ?, ?, Int]].strong[String, String, String, Int, Int, Int]) + checkAll("Strong[IndexedStateT[ListWrapper, ?, ?, Int]]", SerializableTests.serializable(Strong[IndexedStateT[ListWrapper, ?, ?, Int]])) + + Strong[IndexedStateT[ListWrapper, ?, ?, Int]] + } + + { + // F has a Monad + implicit val F = ListWrapper.monad + + checkAll("IndexedStateT[ListWrapper, Int, Int]", MonadTests[IndexedStateT[ListWrapper, Int, Int, ?]].monad[Int, Int, Int]) + checkAll("Monad[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(Monad[IndexedStateT[ListWrapper, Int, Int, ?]])) + + Monad[IndexedStateT[ListWrapper, Int, Int, ?]] + FlatMap[IndexedStateT[ListWrapper, Int, Int, ?]] + Applicative[IndexedStateT[ListWrapper, Int, Int, ?]] + Apply[IndexedStateT[ListWrapper, Int, Int, ?]] + Functor[IndexedStateT[ListWrapper, Int, Int, ?]] + } + + { + // F has a Monad and a SemigroupK + implicit val F = ListWrapper.monad + implicit val S = ListWrapper.semigroupK + + checkAll("IndexedStateT[ListWrapper, Int, Int]", SemigroupKTests[IndexedStateT[ListWrapper, Int, Int, ?]].semigroupK[Int]) + checkAll("SemigroupK[IndexedStateT[ListWrapper, Int, ?]]", SerializableTests.serializable(SemigroupK[IndexedStateT[ListWrapper, String, Int, ?]])) + } + + { + // F has an Alternative + implicit val G = ListWrapper.monad + implicit val F = ListWrapper.alternative + val SA = IndexedStateT.catsDataAlternativeForIndexedStateT[ListWrapper, Int](ListWrapper.monad, ListWrapper.alternative) + + checkAll("IndexedStateT[ListWrapper, Int, Int, Int]", AlternativeTests[IndexedStateT[ListWrapper, Int, Int, ?]](SA).monoidK[Int]) + checkAll("Alternative[IndexedStateT[ListWrapper, Int, Int, ?]]", SerializableTests.serializable(SA)) + + Monad[IndexedStateT[ListWrapper, Int, Int, ?]] + FlatMap[IndexedStateT[ListWrapper, Int, Int, ?]] + Alternative[IndexedStateT[ListWrapper, Int, Int, ?]] + Applicative[IndexedStateT[ListWrapper, Int, Int, ?]] + Apply[IndexedStateT[ListWrapper, Int, Int, ?]] + Functor[IndexedStateT[ListWrapper, Int, Int, ?]] + MonoidK[IndexedStateT[ListWrapper, Int, Int, ?]] + SemigroupK[IndexedStateT[ListWrapper, Int, Int, ?]] + } + + { + implicit val iso = CartesianTests.Isomorphisms.invariant[State[Long, ?]] + + checkAll("State[Long, ?]", MonadTests[State[Long, ?]].monad[Int, Int, Int]) + checkAll("Monad[State[Long, ?]]", SerializableTests.serializable(Monad[State[Long, ?]])) + } + + { + // F has a MonadError + implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[Option, Int, ?]] + implicit val eqEitherTFA: Eq[EitherT[StateT[Option, Int , ?], Unit, Int]] = EitherT.catsDataEqForEitherT[StateT[Option, Int , ?], Unit, Int] + + checkAll("StateT[Option, Int, Int]", MonadErrorTests[StateT[Option, Int, ?], Unit].monadError[Int, Int, Int]) + checkAll("MonadError[StateT[Option, Int, ?], Unit]", SerializableTests.serializable(MonadError[StateT[Option, Int , ?], Unit])) + } + +} + +object IndexedStateTTests extends IndexedStateTTestsInstances { + implicit def stateEq[S:Eq:Arbitrary, A:Eq]: Eq[State[S, A]] = + indexedStateTEq[Eval, S, S, A] + + val add1: State[Int, Int] = State(n => (n + 1, n)) +} + +sealed trait IndexedStateTTestsInstances { + + implicit def indexedStateTEq[F[_], SA, SB, A](implicit SA: Arbitrary[SA], FSB: Eq[F[(SB, A)]], F: FlatMap[F]): Eq[IndexedStateT[F, SA, SB, A]] = + Eq.by[IndexedStateT[F, SA, SB, A], SA => F[(SB, A)]](state => + s => state.run(s)) +} diff --git a/tests/src/test/scala/cats/tests/MonadTest.scala b/tests/src/test/scala/cats/tests/MonadTest.scala index 063f389f9a..1a1c0eae4e 100644 --- a/tests/src/test/scala/cats/tests/MonadTest.scala +++ b/tests/src/test/scala/cats/tests/MonadTest.scala @@ -1,11 +1,11 @@ package cats package tests -import cats.data.{StateT} +import cats.data.{IndexedStateT, StateT} import org.scalacheck.Gen class MonadTest extends CatsSuite { - implicit val testInstance: Monad[StateT[Id, Int, ?]] = StateT.catsDataMonadForStateT[Id, Int] + implicit val testInstance: Monad[StateT[Id, Int, ?]] = IndexedStateT.catsDataMonadForIndexedStateT[Id, Int] val increment: StateT[Id, Int, Unit] = StateT.modify(_ + 1) val incrementAndGet: StateT[Id, Int, Int] = increment >> StateT.get diff --git a/tests/src/test/scala/cats/tests/StateTTests.scala b/tests/src/test/scala/cats/tests/StateTTests.scala deleted file mode 100644 index 81c1ee99ed..0000000000 --- a/tests/src/test/scala/cats/tests/StateTTests.scala +++ /dev/null @@ -1,293 +0,0 @@ -package cats -package tests - -import cats.data.{State, StateT, EitherT} -import cats.kernel.instances.tuple._ -import cats.laws.discipline._ -import cats.laws.discipline.eq._ -import cats.laws.discipline.arbitrary._ -import org.scalacheck.Arbitrary - -class StateTTests extends CatsSuite { - - implicit override val generatorDrivenConfig: PropertyCheckConfiguration = - checkConfiguration.copy(sizeRange = 5) - - import StateTTests._ - - test("basic state usage"){ - add1.run(1).value should === (2 -> 1) - } - - test("traversing state is stack-safe"){ - val ns = (0 to 70000).toList - val x = ns.traverse(_ => add1) - x.runS(0).value should === (70001) - } - - test("State.pure and StateT.pure are consistent"){ - forAll { (s: String, i: Int) => - val state: State[String, Int] = State.pure(i) - val stateT: State[String, Int] = StateT.pure(i) - state.run(s) should === (stateT.run(s)) - } - } - - test("State.get and StateT.get are consistent") { - forAll{ (s: String) => - val state: State[String, String] = State.get - val stateT: State[String, String] = StateT.get - state.run(s) should === (stateT.run(s)) - } - } - - test("State.inspect and StateT.inspect are consistent") { - forAll { (s: String, f: String => Int) => - val state: State[String, Int] = State.inspect(f) - val stateT: State[String, Int] = StateT.inspect(f) - state.run(s) should === (stateT.run(s)) - } - } - - test("State.inspect and StateT.inspectF are consistent") { - forAll { (s: String, f: String => Int) => - val state: State[String, Int] = State.inspect(f) - val stateT: State[String, Int] = StateT.inspectF(f.andThen(Eval.now)) - state.run(s) should === (stateT.run(s)) - } - } - - test("State.modify and StateT.modify are consistent") { - forAll { (s: String, f: String => String) => - val state: State[String, Unit] = State.modify(f) - val stateT: State[String, Unit] = StateT.modify(f) - state.run(s) should === (stateT.run(s)) - } - } - - test("State.modify and StateT.modifyF are consistent") { - forAll { (s: String, f: String => String) => - val state: State[String, Unit] = State.modify(f) - val stateT: State[String, Unit] = StateT.modifyF(f.andThen(Eval.now)) - state.run(s) should === (stateT.run(s)) - } - } - - test("State.pure and StateT.lift are consistent") { - forAll { (s: String, i: Int) => - val state: State[String, Int] = State.pure(i) - val stateT: State[String, Int] = StateT.lift(Eval.now(i)) - state.run(s) should === (stateT.run(s)) - } - } - - test("State.set and StateT.set are consistent") { - forAll { (init: String, s: String) => - val state: State[String, Unit] = State.set(s) - val stateT: StateT[Eval, String, Unit] = StateT.set(s) - state.run(init) should === (stateT.run(init)) - } - } - - test("State.set and StateT.setF are consistent") { - forAll { (init: String, s: String) => - val state: State[String, Unit] = State.set(s) - val stateT: StateT[Eval, String, Unit] = StateT.setF(Eval.now(s)) - state.run(init) should === (stateT.run(init)) - } - } - - test("Cartesian syntax is usable on State") { - val x = add1 *> add1 - x.runS(0).value should === (2) - } - - test("Singleton and instance inspect are consistent"){ - forAll { (s: String, i: Int) => - State.inspect[Int, String](_.toString).run(i) should === ( - State.pure[Int, Unit](()).inspect(_.toString).run(i)) - } - } - - test("flatMap and flatMapF consistent") { - forAll { (stateT: StateT[Option, Long, Int], f: Int => Option[Int]) => - stateT.flatMap(a => StateT(s => f(a).map(b => (s, b)))) should === (stateT.flatMapF(f)) - } - } - - test("runEmpty, runEmptyS, and runEmptyA consistent"){ - forAll { (f: StateT[List, Long, Int]) => - (f.runEmptyS zip f.runEmptyA) should === (f.runEmpty) - } - } - - test("modify identity is a noop"){ - forAll { (f: StateT[List, Long, Int]) => - f.modify(identity) should === (f) - } - } - - test("modify modifies state"){ - forAll { (f: StateT[List, Long, Int], g: Long => Long, initial: Long) => - f.modify(g).runS(initial) should === (f.runS(initial).map(g)) - } - } - - test("modify doesn't affect A value"){ - forAll { (f: StateT[List, Long, Int], g: Long => Long, initial: Long) => - f.modify(g).runA(initial) should === (f.runA(initial)) - } - } - - test("State.modify equivalent to get then set"){ - forAll { (f: Long => Long) => - val s1 = for { - l <- State.get[Long] - _ <- State.set(f(l)) - } yield () - - val s2 = State.modify(f) - - s1 should === (s2) - } - } - - test("StateT.set equivalent to modify ignoring first param") { - forAll { (init: String, update: String) => - val s1 = StateT.modify[Eval, String](_ => update) - val s2 = StateT.set[Eval, String](update) - s1.run(init) should === (s2.run(init)) - } - } - - test("StateT.setF equivalent to modifyF ignoring first param") { - forAll { (init: String, update: String) => - val s1 = StateT.modifyF[Eval, String](_ => Eval.now(update)) - val s2 = StateT.setF(Eval.now(update)) - s1.run(init) should === (s2.run(init)) - } - } - - test(".get and then .run produces same state as value"){ - forAll { (s: State[Long, Int], initial: Long) => - val (finalS, finalA) = s.get.run(initial).value - finalS should === (finalA) - } - } - - test(".get equivalent to flatMap with State.get"){ - forAll { (s: State[Long, Int]) => - s.get should === (s.flatMap(_ => State.get)) - } - } - - test("StateT#transformS with identity is identity") { - forAll { (s: StateT[List, Long, Int]) => - s.transformS[Long](identity, (s, i) => i) should === (s) - } - } - - test("StateT#transformS modifies state") { - final case class Env(int: Int, str: String) - val x = StateT((x: Int) => Option((x + 1, x))) - val xx = x.transformS[Env](_.int, (e, i) => e.copy(int = i)) - val input = 5 - - val got = x.run(input) - val expected = xx.run(Env(input, "hello")).map { case (e, i) => (e.int, i) } - got should === (expected) - } - - - implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[ListWrapper, Int, ?]](StateT.catsDataFunctorForStateT(ListWrapper.monad)) - - { - // F has a Functor - implicit val F: Functor[ListWrapper] = ListWrapper.functor - // We only need a Functor on F to find a Functor on StateT - Functor[StateT[ListWrapper, Int, ?]] - } - - { - // F needs a Monad to do Eq on StateT - implicit val F: Monad[ListWrapper] = ListWrapper.monad - implicit val FS: Functor[StateT[ListWrapper, Int, ?]] = StateT.catsDataFunctorForStateT - - checkAll("StateT[ListWrapper, Int, Int]", FunctorTests[StateT[ListWrapper, Int, ?]].functor[Int, Int, Int]) - checkAll("Functor[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(Functor[StateT[ListWrapper, Int, ?]])) - - Functor[StateT[ListWrapper, Int, ?]] - } - - { - // F has a Monad - implicit val F = ListWrapper.monad - - checkAll("StateT[ListWrapper, Int, Int]", MonadTests[StateT[ListWrapper, Int, ?]].monad[Int, Int, Int]) - checkAll("Monad[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(Monad[StateT[ListWrapper, Int, ?]])) - - Monad[StateT[ListWrapper, Int, ?]] - FlatMap[StateT[ListWrapper, Int, ?]] - Applicative[StateT[ListWrapper, Int, ?]] - Apply[StateT[ListWrapper, Int, ?]] - Functor[StateT[ListWrapper, Int, ?]] - } - - { - // F has a Monad and a SemigroupK - implicit val F = ListWrapper.monad - implicit val S = ListWrapper.semigroupK - - checkAll("StateT[ListWrapper, Int, Int]", SemigroupKTests[StateT[ListWrapper, Int, ?]].semigroupK[Int]) - checkAll("SemigroupK[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(SemigroupK[StateT[ListWrapper, Int, ?]])) - } - - { - // F has an Alternative - implicit val G = ListWrapper.monad - implicit val F = ListWrapper.alternative - val SA = StateT.catsDataAlternativeForStateT[ListWrapper, Int](ListWrapper.monad, ListWrapper.alternative) - checkAll("StateT[ListWrapper, Int, Int]", AlternativeTests[StateT[ListWrapper, Int, ?]](SA).monoidK[Int]) - checkAll("Alternative[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(SA)) - - Monad[StateT[ListWrapper, Int, ?]] - FlatMap[StateT[ListWrapper, Int, ?]] - Alternative[StateT[ListWrapper, Int, ?]] - Applicative[StateT[ListWrapper, Int, ?]] - Apply[StateT[ListWrapper, Int, ?]] - Functor[StateT[ListWrapper, Int, ?]] - MonoidK[StateT[ListWrapper, Int, ?]] - SemigroupK[StateT[ListWrapper, Int, ?]] - } - - { - implicit val iso = CartesianTests.Isomorphisms.invariant[State[Long, ?]] - - checkAll("State[Long, ?]", MonadTests[State[Long, ?]].monad[Int, Int, Int]) - checkAll("Monad[State[Long, ?]]", SerializableTests.serializable(Monad[State[Long, ?]])) - } - - { - // F has a MonadError - implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[Option, Int, ?]] - implicit val eqEitherTFA: Eq[EitherT[StateT[Option, Int , ?], Unit, Int]] = EitherT.catsDataEqForEitherT[StateT[Option, Int , ?], Unit, Int] - - checkAll("StateT[Option, Int, Int]", MonadErrorTests[StateT[Option, Int, ?], Unit].monadError[Int, Int, Int]) - checkAll("MonadError[StateT[Option, Int, ?], Unit]", SerializableTests.serializable(MonadError[StateT[Option, Int , ?], Unit])) - } - -} - -object StateTTests extends StateTTestsInstances { - implicit def stateEq[S:Eq:Arbitrary, A:Eq]: Eq[State[S, A]] = - stateTEq[Eval, S, A] - - val add1: State[Int, Int] = State(n => (n + 1, n)) -} - -sealed trait StateTTestsInstances { - - implicit def stateTEq[F[_], S, A](implicit S: Arbitrary[S], FSA: Eq[F[(S, A)]], F: FlatMap[F]): Eq[StateT[F, S, A]] = - Eq.by[StateT[F, S, A], S => F[(S, A)]](state => - s => state.run(s)) -}