Skip to content

Commit

Permalink
Add context to dataloader creation
Browse files Browse the repository at this point in the history
To allow setting the GraphQLContext as the
batch context in a dataloader, pass it through
the KotlinDataLoaderRegistryFactory.generate()
call.
  • Loading branch information
josephlbarnett committed Aug 1, 2022
1 parent 7711942 commit 1a6460d
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 10 deletions.
1 change: 1 addition & 0 deletions executions/graphql-kotlin-dataloader/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ val reactorExtensionsVersion: String by project

dependencies {
api("com.graphql-java:java-dataloader:$graphQLJavaDataLoaderVersion")
api(project(path = ":graphql-kotlin-schema-generator"))
testImplementation("io.projectreactor.kotlin:reactor-kotlin-extensions:$reactorExtensionsVersion")
testImplementation("io.projectreactor:reactor-core:$reactorVersion")
testImplementation("org.junit.jupiter:junit-jupiter-api:$junitVersion")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.expediagroup.graphql.dataloader

import com.expediagroup.graphql.generator.execution.GraphQLContext
import org.dataloader.DataLoader

/**
Expand All @@ -24,5 +25,7 @@ import org.dataloader.DataLoader
*/
interface KotlinDataLoader<K, V> {
val dataLoaderName: String
fun getDataLoader(): DataLoader<K, V>
fun getDataLoader(graphQLContext: GraphQLContext?, graphQLContextMap: Map<*, Any>?): DataLoader<K, V> = getDataLoader()
@Deprecated("Should use getDataLoader(context/contextMap) instead", replaceWith = ReplaceWith("getDataLoader(null, null)"))
fun getDataLoader(): DataLoader<K, V> = TODO("${this::class} needs to implement getDataLoader")
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.expediagroup.graphql.dataloader

import com.expediagroup.graphql.generator.execution.GraphQLContext
import org.dataloader.DataLoaderRegistry

/**
Expand All @@ -30,12 +31,12 @@ class KotlinDataLoaderRegistryFactory(
/**
* Generate [KotlinDataLoaderRegistry] to be used for GraphQL request execution.
*/
fun generate(): KotlinDataLoaderRegistry {
fun generate(graphQLContext: GraphQLContext? = null, graphQLContextMap: Map<*, Any>? = null): KotlinDataLoaderRegistry {
val registry = DataLoaderRegistry()
dataLoaders.forEach { dataLoader ->
registry.register(
dataLoader.dataLoaderName,
dataLoader.getDataLoader()
dataLoader.getDataLoader(graphQLContext, graphQLContextMap)
)
}
return KotlinDataLoaderRegistry(registry)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,18 @@

package com.expediagroup.graphql.dataloader

import com.expediagroup.graphql.generator.execution.GraphQLContext
import com.expediagroup.graphql.generator.extensions.get
import io.mockk.mockk
import kotlinx.coroutines.future.await
import kotlinx.coroutines.runBlocking
import org.dataloader.DataLoader
import org.dataloader.DataLoaderFactory
import org.dataloader.DataLoaderOptions
import org.junit.jupiter.api.Test
import reactor.kotlin.core.publisher.toMono
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue

class KotlinDataLoaderRegistryFactoryTest {
Expand All @@ -39,10 +47,63 @@ class KotlinDataLoaderRegistryFactoryTest {
fun `generate registry with basic loader`() {
val mockLoader: KotlinDataLoader<String, String> = object : KotlinDataLoader<String, String> {
override val dataLoaderName: String = "MockDataLoader"
override fun getDataLoader(): DataLoader<String, String> = mockk()
override fun getDataLoader(graphQLContext: GraphQLContext?, graphQLContextMap: Map<*, Any>?): DataLoader<String, String> = mockk()
}

val registry = KotlinDataLoaderRegistryFactory(listOf(mockLoader)).generate()
assertEquals(1, registry.dataLoaders.size)
}

@Test
fun `generate registry with minimal compilable loader throws TODO`() {
val mockLoader: KotlinDataLoader<String, String> = object : KotlinDataLoader<String, String> {
override val dataLoaderName = "Unimplemented"
}
assertFailsWith(NotImplementedError::class) {
KotlinDataLoaderRegistryFactory(listOf(mockLoader)).generate()
}
}

@Test
fun `generate registry with GraphQLContext`() = runBlocking {
val mockLoader = object : KotlinDataLoader<String, String> {
override val dataLoaderName = "withGraphQLContext"
override fun getDataLoader(graphQLContext: GraphQLContext?, graphQLContextMap: Map<*, Any>?): DataLoader<String, String> {
val options = DataLoaderOptions.newOptions().setBatchLoaderContextProvider {
graphQLContext
}
return DataLoaderFactory.newDataLoader({ keys, environment ->
keys.map { (environment.getContext() as GraphQLContext).toString() }.toMono().toFuture()
}, options)
}
}
val context = object : GraphQLContext {
override fun toString(): String {
return "blah"
}
}
val registry = KotlinDataLoaderRegistryFactory(mockLoader).generate(context)
val result = registry.getDataLoader<String, String>(mockLoader.dataLoaderName).load("123")
registry.dispatchAll()
assertEquals(result.await(), "blah")
}

@Test
fun `generate registry with GraphQLContextMap`() = runBlocking {
val mockLoader = object : KotlinDataLoader<String, String> {
override val dataLoaderName = "withGraphQLContext"
override fun getDataLoader(graphQLContext: GraphQLContext?, graphQLContextMap: Map<*, Any>?): DataLoader<String, String> {
val options = DataLoaderOptions.newOptions().setBatchLoaderContextProvider {
graphql.GraphQLContext.of(graphQLContextMap)
}
return DataLoaderFactory.newDataLoader({ keys, environment ->
keys.map { (environment.getContext() as graphql.GraphQLContext).get<String>() }.toMono().toFuture()
}, options)
}
}
val registry = KotlinDataLoaderRegistryFactory(mockLoader).generate(null, mapOf(String::class to "abc"))
val result = registry.getDataLoader<String, String>(mockLoader.dataLoaderName).load("123")
registry.dispatchAll()
assertEquals(result.await(), "abc")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.expediagroup.graphql.dataloader

import com.expediagroup.graphql.generator.execution.GraphQLContext
import org.dataloader.DataLoader
import org.dataloader.DataLoaderFactory
import org.junit.jupiter.api.Test
Expand All @@ -30,14 +31,14 @@ class KotlinDataLoaderRegistryTest {
fun `Decorator will keep track of DataLoaders futures`() {
val stringToUpperCaseDataLoader: KotlinDataLoader<String, String> = object : KotlinDataLoader<String, String> {
override val dataLoaderName: String = "ToUppercaseDataLoader"
override fun getDataLoader(): DataLoader<String, String> = DataLoaderFactory.newDataLoader { keys ->
override fun getDataLoader(graphQLContext: GraphQLContext?, graphQLContextMap: Map<*, Any>?): DataLoader<String, String> = DataLoaderFactory.newDataLoader { keys ->
keys.toFlux().map(String::uppercase).collectList().delayElement(Duration.ofMillis(300)).toFuture()
}
}

val stringToLowerCaseDataLoader: KotlinDataLoader<String, String> = object : KotlinDataLoader<String, String> {
override val dataLoaderName: String = "ToLowercaseDataLoader"
override fun getDataLoader(): DataLoader<String, String> = DataLoaderFactory.newDataLoader { keys ->
override fun getDataLoader(graphQLContext: GraphQLContext?, graphQLContextMap: Map<*, Any>?): DataLoader<String, String> = DataLoaderFactory.newDataLoader { keys ->
keys.toFlux().map(String::lowercase).collectList().delayElement(Duration.ofMillis(300)).toFuture()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ open class GraphQLRequestHandler(
context: GraphQLContext? = null,
graphQLContext: Map<*, Any> = emptyMap<Any, Any>()
): GraphQLServerResponse {
val dataLoaderRegistry = dataLoaderRegistryFactory?.generate()
val dataLoaderRegistry = dataLoaderRegistryFactory?.generate(context, graphQLContext)
return when (graphQLRequest) {
is GraphQLRequest -> {
val batchGraphQLContext = graphQLContext + getBatchContext(1, dataLoaderRegistry)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.expediagroup.graphql.server.extensions

import com.expediagroup.graphql.dataloader.KotlinDataLoader
import com.expediagroup.graphql.dataloader.KotlinDataLoaderRegistryFactory
import com.expediagroup.graphql.generator.execution.GraphQLContext
import com.expediagroup.graphql.server.types.GraphQLRequest
import io.mockk.mockk
import org.dataloader.DataLoader
Expand Down Expand Up @@ -64,7 +65,7 @@ class RequestExtensionsKtTest {
val dataLoaderRegistry = KotlinDataLoaderRegistryFactory(
object : KotlinDataLoader<String, String> {
override val dataLoaderName: String = "abc"
override fun getDataLoader(): DataLoader<String, String> = mockk()
override fun getDataLoader(graphQLContext: GraphQLContext?, graphQLContextMap: Map<*, Any>?): DataLoader<String, String> = mockk()
}
).generate()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ open class SpringGraphQLSubscriptionHandler(
context: GraphQLContext?,
graphQLContext: Map<*, Any> = emptyMap<Any, Any>()
): Flow<GraphQLResponse<*>> {
val dataLoaderRegistry = dataLoaderRegistryFactory?.generate()
val dataLoaderRegistry = dataLoaderRegistryFactory?.generate(context, graphQLContext)
val input = graphQLRequest.toExecutionInput(dataLoaderRegistry, context, graphQLContext)

return graphQL.execute(input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.expediagroup.graphql.server.execution.GraphQLContextFactory
import com.expediagroup.graphql.server.execution.GraphQLRequestHandler
import com.expediagroup.graphql.dataloader.KotlinDataLoader
import com.expediagroup.graphql.dataloader.KotlinDataLoaderRegistryFactory
import com.expediagroup.graphql.generator.execution.GraphQLContext
import com.expediagroup.graphql.server.extensions.getValueFromDataLoader
import com.expediagroup.graphql.server.operations.Query
import com.expediagroup.graphql.server.spring.execution.SpringGraphQLContextFactory
Expand Down Expand Up @@ -193,7 +194,7 @@ class SchemaConfigurationTest {
}

override val dataLoaderName = name
override fun getDataLoader(): DataLoader<String, Foo> = DataLoaderFactory.newDataLoader { keys ->
override fun getDataLoader(graphQLContext: GraphQLContext?, graphQLContextMap: Map<*, Any>?): DataLoader<String, Foo> = DataLoaderFactory.newDataLoader { keys ->
CompletableFuture.supplyAsync {
keys.mapNotNull { Foo(it) }
}
Expand Down

0 comments on commit 1a6460d

Please sign in to comment.