diff --git a/core/src/main/scala/cats/data/StateT.scala b/core/src/main/scala/cats/data/StateT.scala index 96ec613a3a..bf365c79c0 100644 --- a/core/src/main/scala/cats/data/StateT.scala +++ b/core/src/main/scala/cats/data/StateT.scala @@ -150,6 +150,12 @@ object StateT extends StateTInstances { def modifyF[F[_], S](f: S => F[S])(implicit F: Applicative[F]): StateT[F, S, Unit] = StateT(s => F.map(f(s))(s => (s, ()))) + + def set[F[_], S](s: S)(implicit F: Applicative[F]): StateT[F, S, Unit] = + StateT(_ => F.pure((s, ()))) + + def setF[F[_], S](fs: F[S])(implicit F: Applicative[F]): StateT[F, S, Unit] = + StateT(_ => F.map(fs)(s => (s, ()))) } private[data] sealed trait StateTInstances extends StateTInstances1 { diff --git a/tests/src/test/scala/cats/tests/StateTTests.scala b/tests/src/test/scala/cats/tests/StateTTests.scala index 34ebcddc6f..e83da7704a 100644 --- a/tests/src/test/scala/cats/tests/StateTTests.scala +++ b/tests/src/test/scala/cats/tests/StateTTests.scala @@ -69,6 +69,22 @@ class StateTTests extends CatsSuite { } } + test("State.set and StateT.set are consistent") { + forAll { (init: String, s: String) => + val state: State[String, Unit] = State.set(s) + val stateT: StateT[Eval, String, Unit] = StateT.set(s) + state.run(init) should === (stateT.run(init)) + } + } + + test("State.set and StateT.setF are consistent") { + forAll { (init: String, s: String) => + val state: State[String, Unit] = State.set(s) + val stateT: StateT[Eval, String, Unit] = StateT.setF(Eval.now(s)) + state.run(init) should === (stateT.run(init)) + } + } + test("Cartesian syntax is usable on State") { val x = add1 *> add1 x.runS(0).value should === (2) @@ -124,6 +140,22 @@ class StateTTests extends CatsSuite { } } + test("StateT.set equivalent to modify ignoring first param") { + forAll { (init: String, update: String) => + val s1 = StateT.modify[Eval, String](_ => update) + val s2 = StateT.set[Eval, String](update) + s1.run(init) should === (s2.run(init)) + } + } + + test("StateT.setF equivalent to modifyF ignoring first param") { + forAll { (init: String, update: String) => + val s1 = StateT.modifyF[Eval, String](_ => Eval.now(update)) + val s2 = StateT.setF(Eval.now(update)) + s1.run(init) should === (s2.run(init)) + } + } + test(".get and then .run produces same state as value"){ forAll { (s: State[Long, Int], initial: Long) => val (finalS, finalA) = s.get.run(initial).value