diff --git a/bench/src/main/scala/cats/bench/TrampolineBench.scala b/bench/src/main/scala/cats/bench/TrampolineBench.scala index e7f547f252..9bc3e991d5 100644 --- a/bench/src/main/scala/cats/bench/TrampolineBench.scala +++ b/bench/src/main/scala/cats/bench/TrampolineBench.scala @@ -5,6 +5,7 @@ import org.openjdk.jmh.annotations.{Benchmark, Scope, State} import cats._ import cats.implicits._ import cats.free.Trampoline +import scala.util.control.TailCalls @State(Scope.Benchmark) class TrampolineBench { @@ -30,14 +31,12 @@ class TrampolineBench { y <- Trampoline.defer(trampolineFib(n - 2)) } yield x + y - // TailRec[A] only has .flatMap in 2.11. + @Benchmark + def stdlib(): Int = stdlibFib(N).result - // @Benchmark - // def stdlib(): Int = stdlibFib(N).result - // - // def stdlibFib(n: Int): TailCalls.TailRec[Int] = - // if (n < 2) TailCalls.done(n) else for { - // x <- TailCalls.tailcall(stdlibFib(n - 1)) - // y <- TailCalls.tailcall(stdlibFib(n - 2)) - // } yield x + y + def stdlibFib(n: Int): TailCalls.TailRec[Int] = + if (n < 2) TailCalls.done(n) else for { + x <- TailCalls.tailcall(stdlibFib(n - 1)) + y <- TailCalls.tailcall(stdlibFib(n - 2)) + } yield x + y } diff --git a/core/src/main/scala-2.12/cats/instances/all.scala b/core/src/main/scala-2.12/cats/instances/all.scala index 001e20ccf0..bb949276a7 100644 --- a/core/src/main/scala-2.12/cats/instances/all.scala +++ b/core/src/main/scala-2.12/cats/instances/all.scala @@ -38,6 +38,7 @@ trait AllInstances with StreamInstances with StringInstances with SymbolInstances + with TailRecInstances with TryInstances with TupleInstances with UUIDInstances diff --git a/core/src/main/scala-2.12/cats/instances/package.scala b/core/src/main/scala-2.12/cats/instances/package.scala index d905f919e7..c35f425d9c 100644 --- a/core/src/main/scala-2.12/cats/instances/package.scala +++ b/core/src/main/scala-2.12/cats/instances/package.scala @@ -39,6 +39,7 @@ package object instances { object sortedSet extends SortedSetInstances with SortedSetInstancesBinCompat0 with SortedSetInstancesBinCompat1 object stream extends StreamInstances with StreamInstancesBinCompat0 object string extends StringInstances + object tailRec extends TailRecInstances object try_ extends TryInstances object tuple extends TupleInstances with Tuple2InstancesBinCompat0 object unit extends UnitInstances diff --git a/core/src/main/scala-2.13+/cats/instances/all.scala b/core/src/main/scala-2.13+/cats/instances/all.scala index 315b99eefd..b6c7744315 100644 --- a/core/src/main/scala-2.13+/cats/instances/all.scala +++ b/core/src/main/scala-2.13+/cats/instances/all.scala @@ -39,6 +39,7 @@ trait AllInstances with StreamInstances with StringInstances with SymbolInstances + with TailRecInstances with TryInstances with TupleInstances with UUIDInstances diff --git a/core/src/main/scala-2.13+/cats/instances/package.scala b/core/src/main/scala-2.13+/cats/instances/package.scala index 6081280fb5..4dfa9bbadc 100644 --- a/core/src/main/scala-2.13+/cats/instances/package.scala +++ b/core/src/main/scala-2.13+/cats/instances/package.scala @@ -42,6 +42,7 @@ package object instances { object stream extends StreamInstances with StreamInstancesBinCompat0 object lazyList extends LazyListInstances object string extends StringInstances + object tailRec extends TailRecInstances object try_ extends TryInstances object tuple extends TupleInstances with Tuple2InstancesBinCompat0 object unit extends UnitInstances diff --git a/core/src/main/scala/cats/Eval.scala b/core/src/main/scala/cats/Eval.scala index 53676f0e04..f477492785 100644 --- a/core/src/main/scala/cats/Eval.scala +++ b/core/src/main/scala/cats/Eval.scala @@ -378,6 +378,7 @@ sealed abstract private[cats] class EvalInstances extends EvalInstances0 { def flatMap[A, B](fa: Eval[A])(f: A => Eval[B]): Eval[B] = fa.flatMap(f) def extract[A](la: Eval[A]): A = la.value def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa)) + override def unit: Eval[Unit] = Eval.Unit } implicit val catsDeferForEval: Defer[Eval] = diff --git a/core/src/main/scala/cats/instances/tailrec.scala b/core/src/main/scala/cats/instances/tailrec.scala new file mode 100644 index 0000000000..db2dd1c7a8 --- /dev/null +++ b/core/src/main/scala/cats/instances/tailrec.scala @@ -0,0 +1,26 @@ +package cats +package instances + +import scala.util.control.TailCalls.{done, tailcall, TailRec} + +trait TailRecInstances { + implicit def catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] = + TailRecInstances.catsInstancesForTailRec +} + +private object TailRecInstances { + val catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] = + new StackSafeMonad[TailRec] with Defer[TailRec] { + def defer[A](fa: => TailRec[A]): TailRec[A] = tailcall(fa) + + def pure[A](a: A): TailRec[A] = done(a) + + override def map[A, B](fa: TailRec[A])(f: A => B): TailRec[B] = + fa.map(f) + + def flatMap[A, B](fa: TailRec[A])(f: A => TailRec[B]): TailRec[B] = + fa.flatMap(f) + + override val unit: TailRec[Unit] = done(()) + } +} diff --git a/tests/src/test/scala/cats/tests/TailRecSuite.scala b/tests/src/test/scala/cats/tests/TailRecSuite.scala new file mode 100644 index 0000000000..f6c9cafc13 --- /dev/null +++ b/tests/src/test/scala/cats/tests/TailRecSuite.scala @@ -0,0 +1,30 @@ +package cats +package tests + +import scala.util.control.TailCalls.{done, tailcall, TailRec} +import org.scalacheck.{Arbitrary, Cogen, Gen} + +import Arbitrary.arbitrary + +import cats.laws.discipline.{DeferTests, MonadTests, SerializableTests} + +class TailRecSuite extends CatsSuite { + + implicit def tailRecArb[A: Arbitrary: Cogen]: Arbitrary[TailRec[A]] = + Arbitrary( + Gen.frequency( + (3, arbitrary[A].map(done(_))), + (1, Gen.lzy(arbitrary[(A, A => TailRec[A])].map { case (a, fn) => tailcall(fn(a)) })), + (1, Gen.lzy(arbitrary[(TailRec[A], A => TailRec[A])].map { case (a, fn) => a.flatMap(fn) })) + ) + ) + + implicit def eqTailRec[A: Eq]: Eq[TailRec[A]] = + Eq.by[TailRec[A], A](_.result) + + checkAll("TailRec[Int]", MonadTests[TailRec].monad[Int, Int, Int]) + checkAll("Monad[TailRec]", SerializableTests.serializable(Monad[TailRec])) + + checkAll("TailRec[Int]", DeferTests[TailRec].defer[Int]) + checkAll("Defer[TailRec]", SerializableTests.serializable(Defer[TailRec])) +}