Skip to content

Commit

Permalink
Add scala.util.control.TailCalls.TailRec instances (#3041)
Browse files Browse the repository at this point in the history
* Add scala.util.control.TailCalls.TailRec instances

* format

* review comments

* fix conflict

* Format

* Avoid val in trait for bincompat
  • Loading branch information
johnynek authored and LukaJCB committed Nov 6, 2019
1 parent 0aaa637 commit fc7b8b9
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 9 deletions.
17 changes: 8 additions & 9 deletions bench/src/main/scala/cats/bench/TrampolineBench.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import org.openjdk.jmh.annotations.{Benchmark, Scope, State}
import cats._
import cats.implicits._
import cats.free.Trampoline
import scala.util.control.TailCalls

@State(Scope.Benchmark)
class TrampolineBench {
Expand All @@ -30,14 +31,12 @@ class TrampolineBench {
y <- Trampoline.defer(trampolineFib(n - 2))
} yield x + y

// TailRec[A] only has .flatMap in 2.11.
@Benchmark
def stdlib(): Int = stdlibFib(N).result

// @Benchmark
// def stdlib(): Int = stdlibFib(N).result
//
// def stdlibFib(n: Int): TailCalls.TailRec[Int] =
// if (n < 2) TailCalls.done(n) else for {
// x <- TailCalls.tailcall(stdlibFib(n - 1))
// y <- TailCalls.tailcall(stdlibFib(n - 2))
// } yield x + y
def stdlibFib(n: Int): TailCalls.TailRec[Int] =
if (n < 2) TailCalls.done(n) else for {
x <- TailCalls.tailcall(stdlibFib(n - 1))
y <- TailCalls.tailcall(stdlibFib(n - 2))
} yield x + y
}
1 change: 1 addition & 0 deletions core/src/main/scala-2.12/cats/instances/all.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ trait AllInstances
with StreamInstances
with StringInstances
with SymbolInstances
with TailRecInstances
with TryInstances
with TupleInstances
with UUIDInstances
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala-2.12/cats/instances/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ package object instances {
object sortedSet extends SortedSetInstances with SortedSetInstancesBinCompat0 with SortedSetInstancesBinCompat1
object stream extends StreamInstances with StreamInstancesBinCompat0
object string extends StringInstances
object tailRec extends TailRecInstances
object try_ extends TryInstances
object tuple extends TupleInstances with Tuple2InstancesBinCompat0
object unit extends UnitInstances
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala-2.13+/cats/instances/all.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ trait AllInstances
with StreamInstances
with StringInstances
with SymbolInstances
with TailRecInstances
with TryInstances
with TupleInstances
with UUIDInstances
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala-2.13+/cats/instances/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ package object instances {
object stream extends StreamInstances with StreamInstancesBinCompat0
object lazyList extends LazyListInstances
object string extends StringInstances
object tailRec extends TailRecInstances
object try_ extends TryInstances
object tuple extends TupleInstances with Tuple2InstancesBinCompat0
object unit extends UnitInstances
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/cats/Eval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ sealed abstract private[cats] class EvalInstances extends EvalInstances0 {
def flatMap[A, B](fa: Eval[A])(f: A => Eval[B]): Eval[B] = fa.flatMap(f)
def extract[A](la: Eval[A]): A = la.value
def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa))
override def unit: Eval[Unit] = Eval.Unit
}

implicit val catsDeferForEval: Defer[Eval] =
Expand Down
26 changes: 26 additions & 0 deletions core/src/main/scala/cats/instances/tailrec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package cats
package instances

import scala.util.control.TailCalls.{done, tailcall, TailRec}

trait TailRecInstances {
implicit def catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] =
TailRecInstances.catsInstancesForTailRec
}

private object TailRecInstances {
val catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] =
new StackSafeMonad[TailRec] with Defer[TailRec] {
def defer[A](fa: => TailRec[A]): TailRec[A] = tailcall(fa)

def pure[A](a: A): TailRec[A] = done(a)

override def map[A, B](fa: TailRec[A])(f: A => B): TailRec[B] =
fa.map(f)

def flatMap[A, B](fa: TailRec[A])(f: A => TailRec[B]): TailRec[B] =
fa.flatMap(f)

override val unit: TailRec[Unit] = done(())
}
}
30 changes: 30 additions & 0 deletions tests/src/test/scala/cats/tests/TailRecSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package cats
package tests

import scala.util.control.TailCalls.{done, tailcall, TailRec}
import org.scalacheck.{Arbitrary, Cogen, Gen}

import Arbitrary.arbitrary

import cats.laws.discipline.{DeferTests, MonadTests, SerializableTests}

class TailRecSuite extends CatsSuite {

implicit def tailRecArb[A: Arbitrary: Cogen]: Arbitrary[TailRec[A]] =
Arbitrary(
Gen.frequency(
(3, arbitrary[A].map(done(_))),
(1, Gen.lzy(arbitrary[(A, A => TailRec[A])].map { case (a, fn) => tailcall(fn(a)) })),
(1, Gen.lzy(arbitrary[(TailRec[A], A => TailRec[A])].map { case (a, fn) => a.flatMap(fn) }))
)
)

implicit def eqTailRec[A: Eq]: Eq[TailRec[A]] =
Eq.by[TailRec[A], A](_.result)

checkAll("TailRec[Int]", MonadTests[TailRec].monad[Int, Int, Int])
checkAll("Monad[TailRec]", SerializableTests.serializable(Monad[TailRec]))

checkAll("TailRec[Int]", DeferTests[TailRec].defer[Int])
checkAll("Defer[TailRec]", SerializableTests.serializable(Defer[TailRec]))
}

0 comments on commit fc7b8b9

Please sign in to comment.