Skip to content

Commit

Permalink
implement trampolines for flatmap, map, filter, merge. (kotest#2900)
Browse files Browse the repository at this point in the history
* implement trampolines for flatmap, map, filter, and merge. Remove suspension point allocation in single shot builder.

* make sure flatmap preserves immutability
  • Loading branch information
myuwono authored Mar 26, 2022
1 parent fe90190 commit 8e9c423
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.RestrictsSuspension
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
import kotlin.coroutines.intrinsics.startCoroutineUninterceptedOrReturn
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
import kotlin.coroutines.resume
Expand Down Expand Up @@ -213,8 +212,8 @@ fun <A> arbitraryBuilder(
edgecaseFn: EdgecaseFn<A>? = null,
builderFn: suspend ArbitraryBuilderContext.(RandomSource) -> A
): Arb<A> = object : Arb<A>() {
override fun edgecase(rs: RandomSource): A? = singleShotArb().edgecase(rs)
override fun sample(rs: RandomSource): Sample<A> = singleShotArb().sample(rs)
override fun edgecase(rs: RandomSource): A? = singleShotArb(SingleShotGenerationMode.Edgecase, rs).edgecase(rs)
override fun sample(rs: RandomSource): Sample<A> = singleShotArb(SingleShotGenerationMode.Sample, rs).sample(rs)
override val classifier: Classifier<out A>? = classifier

/**
Expand All @@ -228,13 +227,13 @@ fun <A> arbitraryBuilder(
* will provide another single shot Arb. Hence the reason why this function is invoked
* on every call to [sample] / [edgecase].
*/
private fun singleShotArb(): Arb<A> {
val restrictedContinuation = SingleShotArbContinuation.Restricted {
private fun singleShotArb(mode: SingleShotGenerationMode, rs: RandomSource): Arb<A> {
val restrictedContinuation = SingleShotArbContinuation.Restricted(mode, rs) {
/**
* At the end of the suspension we got a generated value [A] as a comprehension result.
* This value can either be a sample, or an edgecase.
*/
val value: A = builderFn(randomSource.bind())
val value: A = builderFn(rs)

/**
* Here we point A into an Arb<A> with the appropriate enrichments including
Expand Down Expand Up @@ -263,8 +262,8 @@ suspend fun <A> suspendArbitraryBuilder(
fn: suspend GenerateArbitraryBuilderContext.(RandomSource) -> A
): Arb<A> = suspendCoroutineUninterceptedOrReturn { cont ->
val arb = object : Arb<A>() {
override fun edgecase(rs: RandomSource): A? = singleShotArb().edgecase(rs)
override fun sample(rs: RandomSource): Sample<A> = singleShotArb().sample(rs)
override fun edgecase(rs: RandomSource): A? = singleShotArb(SingleShotGenerationMode.Edgecase, rs).edgecase(rs)
override fun sample(rs: RandomSource): Sample<A> = singleShotArb(SingleShotGenerationMode.Sample, rs).sample(rs)
override val classifier: Classifier<out A>? = classifier

/**
Expand All @@ -278,13 +277,13 @@ suspend fun <A> suspendArbitraryBuilder(
* will provide another single shot Arb. Hence the reason why this function is invoked
* on every call to [sample] / [edgecase].
*/
private fun singleShotArb(): Arb<A> {
val suspendableContinuation = SingleShotArbContinuation.Suspendedable(cont.context) {
private fun singleShotArb(genMode: SingleShotGenerationMode, rs: RandomSource): Arb<A> {
val suspendableContinuation = SingleShotArbContinuation.Suspendedable(genMode, rs, cont.context) {
/**
* At the end of the suspension we got a generated value [A] as a comprehension result.
* This value can either be a sample, or an edgecase.
*/
val value: A = fn(randomSource.bind())
val value: A = fn(rs)

/**
* Here we point A into an Arb<A> with the appropriate enrichments including
Expand All @@ -303,13 +302,6 @@ suspend fun <A> suspendArbitraryBuilder(
cont.resume(arb)
}

/**
* passthrough arb to extract the propagated RandomSource. It's important to pass rs through both the
* sample and the edgecases to ensure that flatMap can evaluate on both [sample] and [edgecase]
* regardless of any absence of edgecases in the firstly bound arb.
*/
private val randomSource: Arb<RandomSource> = ArbitraryBuilder.create { it }.withEdgecaseFn { it }.build()

typealias SampleFn<A> = (RandomSource) -> A
typealias EdgecaseFn<A> = (RandomSource) -> A?

Expand Down Expand Up @@ -352,18 +344,29 @@ interface ArbitraryBuilderContext : BaseArbitraryBuilderSyntax

interface GenerateArbitraryBuilderContext : BaseArbitraryBuilderSyntax

enum class SingleShotGenerationMode { Edgecase, Sample }

sealed class SingleShotArbContinuation<F : BaseArbitraryBuilderSyntax, A>(
override val context: CoroutineContext,
private val generationMode: SingleShotGenerationMode,
private val randomSource: RandomSource,
private val fn: suspend F.() -> Arb<A>
) : Continuation<Arb<A>>, BaseArbitraryBuilderSyntax {

class Restricted<A>(
genMode: SingleShotGenerationMode,
rs: RandomSource,
fn: suspend ArbitraryBuilderContext.() -> Arb<A>
) : SingleShotArbContinuation<ArbitraryBuilderContext, A>(EmptyCoroutineContext, fn), ArbitraryBuilderContext
) : SingleShotArbContinuation<ArbitraryBuilderContext, A>(EmptyCoroutineContext, genMode, rs, fn),
ArbitraryBuilderContext

class Suspendedable<A>(
genMode: SingleShotGenerationMode,
rs: RandomSource,
override val context: CoroutineContext,
fn: suspend GenerateArbitraryBuilderContext.() -> Arb<A>
) : SingleShotArbContinuation<GenerateArbitraryBuilderContext, A>(context, fn), GenerateArbitraryBuilderContext
) : SingleShotArbContinuation<GenerateArbitraryBuilderContext, A>(context, genMode, rs, fn),
GenerateArbitraryBuilderContext

private lateinit var returnedArb: Arb<A>
private var hasExecuted: Boolean = false
Expand All @@ -373,24 +376,9 @@ sealed class SingleShotArbContinuation<F : BaseArbitraryBuilderSyntax, A>(
result.map { resultArb -> returnedArb = resultArb }.getOrThrow()
}

override suspend fun <T> Arb<T>.bind(): T = suspendCoroutineUninterceptedOrReturn { c ->
// we call flatMap on the bound arb, and then returning the `returnedArb`, without modification
returnedArb = this.flatMap { value: T ->
/**
* we resume the suspension with the value passed inside the flatMap function.
* this [value] can be either sample or edgecases. This is important
* because from the point of view of a user of kotest, when we talk about transformation,
* we care about transforming the generated value of this arb for both sample and edgecases.
*/
c.resume(value)
returnedArb
}
/**
* Notice this block returns the special COROUTINE_SUSPENDED value
* this means the Continuation provided to the block shall be resumed by invoking [resumeWith]
* at some moment in the future when the result becomes available to resume the computation.
*/
COROUTINE_SUSPENDED
override suspend fun <T> Arb<T>.bind(): T = when (generationMode) {
SingleShotGenerationMode.Edgecase -> this.edgecase(randomSource) ?: this.sample(randomSource).value
SingleShotGenerationMode.Sample -> this.sample(randomSource).value
}

/**
Expand All @@ -404,7 +392,10 @@ sealed class SingleShotArbContinuation<F : BaseArbitraryBuilderSyntax, A>(
*/
fun F.createSingleShotArb(): Arb<A> {
require(!hasExecuted) { "continuation has already been executed, if you see this error please raise a bug report" }
fn.startCoroutineUninterceptedOrReturn(this@createSingleShotArb, this@SingleShotArbContinuation)
val result = fn.startCoroutineUninterceptedOrReturn(this@createSingleShotArb, this@SingleShotArbContinuation)

@Suppress("UNCHECKED_CAST")
returnedArb = result as Arb<A>
return returnedArb
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@ import io.kotest.property.filter
* predicate. This gen will continue to request elements from the underlying gen until one satisfies
* the predicate.
*/
fun <A> Arb<A>.filter(predicate: (A) -> Boolean): Arb<A> = object : Arb<A>() {
fun <A> Arb<A>.filter(predicate: (A) -> Boolean): Arb<A> = trampoline { sampleA ->
object : Arb<A>() {
override fun edgecase(rs: RandomSource): A? =
sequenceOf(sampleA.value)
.plus(generateSequence { this@filter.edgecase(rs) })
.take(PropertyTesting.maxFilterAttempts)
.filter(predicate)
.firstOrNull()

override fun edgecase(rs: RandomSource): A? =
generateSequence { this@filter.edgecase(rs) }
.take(PropertyTesting.maxFilterAttempts)
.filter(predicate)
.firstOrNull()

override fun sample(rs: RandomSource): Sample<A> {
val sample = this@filter.samples(rs).filter { predicate(it.value) }.first()
return Sample(sample.value, sample.shrinks.filter(predicate) ?: RTree({ sample.value }))
override fun sample(rs: RandomSource): Sample<A> {
val sample = sequenceOf(sampleA).plus(this@filter.samples(rs)).filter { predicate(it.value) }.first()
return Sample(sample.value, sample.shrinks.filter(predicate) ?: RTree({ sample.value }))
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,69 @@ import io.kotest.property.map
/**
* Returns a new [Arb] which takes its elements from the receiver and maps them using the supplied function.
*/
fun <A, B> Arb<A>.map(f: (A) -> B): Arb<B> = object : Arb<B>() {

override fun edgecase(rs: RandomSource): B? = this@map.edgecase(rs)?.let(f)

override fun sample(rs: RandomSource): Sample<B> =
this@map.sample(rs).let {
Sample(f(it.value), it.shrinks.map(f))
fun <A, B> Arb<A>.map(fn: (A) -> B): Arb<B> = trampoline { sampleA ->
object : Arb<B>() {
override fun edgecase(rs: RandomSource): B? = fn(sampleA.value)
override fun sample(rs: RandomSource): Sample<B> {
val value = fn(sampleA.value)
val shrinks = sampleA.shrinks.map(fn)
return Sample(value, shrinks)
}
}
}

/**
* Returns a new [Arb] which takes its elements from the receiver and maps them using the supplied function.
*/
fun <A, B> Arb<A>.flatMap(f: (A) -> Arb<B>): Arb<B> = object : Arb<B>() {
fun <A, B> Arb<A>.flatMap(fn: (A) -> Arb<B>): Arb<B> = trampoline { fn(it.value) }

override fun edgecase(rs: RandomSource): B? {
// generate an edge case, map it to another arb, and generate an edge case again
val a = this@flatMap.edgecase(rs) ?: this@flatMap.next(rs)
return f(a).edgecase(rs)
}
/**
* Returns a new [TrampolineArb] from the receiver [Arb] which composes the operations of [next] lambda
* using a trampoline method. This allows [next] function to be executed without exhausting call stack.
*/
internal fun <A, B> Arb<A>.trampoline(next: (Sample<A>) -> Arb<B>): Arb<B> = when (this) {
is TrampolineArb -> this.thunk(next)
else -> TrampolineArb(this).thunk(next)
}

/**
* The [TrampolineArb] is a special Arb that exchanges call stack with heap.
* In a nutshell, this arb stores command chains to be applied to the original arb inside a list.
* This technique is an imperative reduction of Free Monads. This eliminates the need of creating intermediate
* Trampoline Monad and tail-recursive function on those which can be expensive.
* This minimizes the amount of code and unnecessary object allocation during sample generation in the expense of typesafety.
*
* This is an internal implementation. Do not use this TrampolineArb as is and please do not expose this
* to users outside of the library. For library maintainers, please use the [Arb.trampoline] extension function.
* The extension function will provide some type-guardrails to workaround the loss of types within this Arb.
*/
@Suppress("UNCHECKED_CAST")
internal class TrampolineArb<A> private constructor(
private val first: Arb<A>,
commands: List<(Sample<Any>) -> Arb<Any>>
) : Arb<A>() {
constructor(first: Arb<A>) : this(first, emptyList())

private val commandList: MutableList<(Sample<Any>) -> Arb<Any>> = commands.toMutableList()

fun <A, B> thunk(fn: (Sample<A>) -> Arb<B>): TrampolineArb<B> =
TrampolineArb(
first,
commandList.toList() + (fn as (Sample<Any>) -> Arb<Any>)
) as TrampolineArb<B>

override fun edgecase(rs: RandomSource): A? =
commandList
.fold(first as Arb<Any>) { currentArb, next ->
val currentEdge = currentArb.edgecase(rs) ?: currentArb.sample(rs).value
next(Sample(currentEdge))
}
.edgecase(rs) as A?

override fun sample(rs: RandomSource): Sample<B> = f(this@flatMap.sample(rs).value).sample(rs)
override fun sample(rs: RandomSource): Sample<A> =
commandList
.fold(first as Arb<Any>) { currentArb, next ->
next(currentArb.sample(rs))
}
.sample(rs) as Sample<A>
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,21 @@ import io.kotest.property.Sample
* @param other the arg to merge with this one
* @return the merged arg.
*/
fun <A, B : A> Arb<A>.merge(other: Gen<B>): Arb<A> = object : Arb<A>() {

override fun edgecase(rs: RandomSource): A? = when (other) {
is Arb -> listOf(this@merge, other).random(rs.random).edgecase(rs)
is Exhaustive -> this@merge.edgecase(rs)
}
fun <A, B : A> Arb<A>.merge(other: Gen<B>): Arb<A> = trampoline { sampleA ->
object : Arb<A>() {
override fun edgecase(rs: RandomSource): A? = when (other) {
is Arb -> if (rs.random.nextBoolean()) sampleA.value else other.edgecase(rs)
is Exhaustive -> sampleA.value
}

override fun sample(rs: RandomSource): Sample<A> =
if (rs.random.nextBoolean()) {
this@merge.sample(rs)
} else {
when (other) {
is Arb -> other.sample(rs)
is Exhaustive -> other.toArb().sample(rs)
override fun sample(rs: RandomSource): Sample<A> =
if (rs.random.nextBoolean()) {
sampleA
} else {
when (other) {
is Arb -> other.sample(rs)
is Exhaustive -> other.toArb().sample(rs)
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.sksamuel.kotest.property.arbitrary

import io.kotest.assertions.throwables.shouldNotThrowAny
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.collections.shouldContainExactly
Expand Down Expand Up @@ -55,6 +56,17 @@ class BuilderTest : FunSpec() {
}

context("arbitrary builder using restricted continuation") {
test("should be stack safe") {
val arb: Arb<Int> = arbitrary {
(1..100000).map {
Arb.int().bind()
}.last()
}

val result = shouldNotThrowAny { arb.single(RandomSource.seeded(1234)) }
result shouldBe -1486934023
}

test("should be equivalent to chaining flatMaps") {
val arbFlatMaps: Arb<String> =
Arb.string(5, Codepoint.alphanumeric()).withEdgecases("edge1", "edge2").flatMap { first ->
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
package com.sksamuel.kotest.property.arbitrary

import io.kotest.assertions.throwables.shouldNotThrow
import io.kotest.assertions.throwables.shouldNotThrowAny
import io.kotest.core.spec.style.FunSpec
import io.kotest.inspectors.forAll
import io.kotest.matchers.collections.shouldContainExactly
import io.kotest.matchers.collections.shouldNotBeIn
import io.kotest.matchers.shouldBe
import io.kotest.property.Arb
import io.kotest.property.EdgeConfig
import io.kotest.property.RandomSource
import io.kotest.property.Sample
import io.kotest.property.arbitrary.filter
import io.kotest.property.arbitrary.int
import io.kotest.property.arbitrary.map
import io.kotest.property.arbitrary.of
import io.kotest.property.arbitrary.single
import io.kotest.property.arbitrary.take
import io.kotest.property.arbitrary.withEdgecases

Expand Down Expand Up @@ -54,4 +59,13 @@ class FilterTest : FunSpec({
}
}
}

test("Arb.filter composition should not exhaust call stack") {
var arb: Arb<Int> = Arb.of(0, 1)
repeat(10000) {
arb = arb.filter { it == 0 }
}
val result = shouldNotThrowAny { arb.single(RandomSource.seeded(1234L)) }
result shouldBe 0
}
})
Loading

0 comments on commit 8e9c423

Please sign in to comment.