Skip to content

Commit

Permalink
Add default resolution info detection heuristics
Browse files Browse the repository at this point in the history
  • Loading branch information
ileasile committed Aug 2, 2020
1 parent ea1dc5a commit 5d984d9
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import java.net.URL
import java.nio.file.Paths

class LibraryFactory(
var defaultResolutionInfo: LibraryResolutionInfo,
val resolutionInfoProvider: ResolutionInfoProvider,
private val parsers: Map<String, LibraryResolutionInfoParser> = defaultParsers,
) {
fun parseReferenceWithArgs(str: String): Pair<LibraryReference, List<Variable>> {
Expand All @@ -25,16 +25,16 @@ class LibraryFactory(

private fun parseResolutionInfo(string: String): LibraryResolutionInfo {
// In case of empty string after `@`: %use lib@
if(string.isBlank()) return defaultResolutionInfo
if(string.isBlank()) return resolutionInfoProvider.get()

val (type, vars) = parseCall(string, Brackets.SQUARE)
val parser = parsers[type] ?: return LibraryResolutionInfo.getInfoByRef(type)
val parser = parsers[type] ?: return resolutionInfoProvider.get(type)
return parser.getInfo(vars)
}

private fun parseReference(string: String): LibraryReference {
val sepIndex = string.indexOf('@')
if (sepIndex == -1) return LibraryReference(defaultResolutionInfo, string)
if (sepIndex == -1) return LibraryReference(resolutionInfoProvider.get(), string)

val nameString = string.substring(0, sepIndex)
val infoString = string.substring(sepIndex + 1)
Expand All @@ -43,8 +43,6 @@ class LibraryFactory(
}

companion object {
fun withDefaultDirectoryResolution(dir: File) = LibraryFactory(LibraryResolutionInfo.ByDir(dir))

private val defaultParsers = listOf(
LibraryResolutionInfoParser.make("ref", listOf(Parameter.Required("ref"))) { args ->
LibraryResolutionInfo.getInfoByRef(args["ref"] ?: error("Argument 'ref' should be specified"))
Expand All @@ -59,5 +57,9 @@ class LibraryFactory(
LibraryResolutionInfo.ByURL(URL(args["url"] ?: error("Argument 'url' should be specified")))
},
).map { it.name to it }.toMap()

val EMPTY = LibraryFactory(EmptyResolutionInfoProvider)

fun withDefaultDirectoryResolution(dir: File) = LibraryFactory(StandardResolutionInfoProvider(LibraryResolutionInfo.ByDir(dir)))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package org.jetbrains.kotlin.jupyter.libraries

import org.jetbrains.kotlin.jupyter.GitHubApiPrefix
import org.jetbrains.kotlin.jupyter.LibrariesDir
import java.io.File
import java.net.URL

interface ResolutionInfoProvider {
var fallback: LibraryResolutionInfo

fun get(): LibraryResolutionInfo = fallback
fun get(string: String): LibraryResolutionInfo
}

object EmptyResolutionInfoProvider : ResolutionInfoProvider {
private val fallbackInfo = LibraryResolutionInfo.ByNothing()

override var fallback: LibraryResolutionInfo
get() = fallbackInfo
set(_) {}

override fun get(string: String) = LibraryResolutionInfo.getInfoByRef(string)
}

class StandardResolutionInfoProvider(override var fallback: LibraryResolutionInfo) : ResolutionInfoProvider {
override fun get(string: String): LibraryResolutionInfo {
return tryGetAsRef(string) ?: tryGetAsDir(string) ?: tryGetAsFile(string) ?: tryGetAsURL(string) ?: fallback
}

private fun tryGetAsRef(ref: String): LibraryResolutionInfo? {
val response = khttp.get("$GitHubApiPrefix/contents/$LibrariesDir?ref=$ref")
return if (response.statusCode == 200) LibraryResolutionInfo.getInfoByRef(ref) else null
}

private fun tryGetAsDir(dirName: String): LibraryResolutionInfo? {
val file = File(dirName)
return if (file.isDirectory) LibraryResolutionInfo.ByDir(file) else null
}

private fun tryGetAsFile(fileName: String): LibraryResolutionInfo? {
val file = File(fileName)
return if (file.isFile) LibraryResolutionInfo.ByFile(file) else null
}

private fun tryGetAsURL(url: String): LibraryResolutionInfo? {
val response = khttp.get(url)
return if (response.statusCode == 200) LibraryResolutionInfo.ByURL(URL(url)) else null
}
}
10 changes: 5 additions & 5 deletions src/main/kotlin/org/jetbrains/kotlin/jupyter/libraries/util.kt
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@ enum class DefaultInfoSwitch {
GIT_REFERENCE, DIRECTORY
}

class LibraryFactoryDefaultInfoSwitcher<T>(private val libraryFactory: LibraryFactory, initialSwitchVal: T, private val switcher: (T) -> LibraryResolutionInfo) {
class LibraryFactoryDefaultInfoSwitcher<T>(private val infoProvider: ResolutionInfoProvider, initialSwitchVal: T, private val switcher: (T) -> LibraryResolutionInfo) {
private val defaultInfoCache = hashMapOf<T, LibraryResolutionInfo>()

var switch: T = initialSwitchVal
set(value) {
libraryFactory.defaultResolutionInfo = defaultInfoCache.getOrPut(value) { switcher(value) }
infoProvider.fallback = defaultInfoCache.getOrPut(value) { switcher(value) }
field = value
}

companion object {
fun default(factory: LibraryFactory, defaultDir: File, defaultRef: String): LibraryFactoryDefaultInfoSwitcher<DefaultInfoSwitch> {
val initialInfo = factory.defaultResolutionInfo
fun default(provider: ResolutionInfoProvider, defaultDir: File, defaultRef: String): LibraryFactoryDefaultInfoSwitcher<DefaultInfoSwitch> {
val initialInfo = provider.fallback
val dirInfo = if (initialInfo is LibraryResolutionInfo.ByDir) initialInfo else LibraryResolutionInfo.ByDir(defaultDir)
val refInfo = if (initialInfo is LibraryResolutionInfo.ByGitRef) initialInfo else LibraryResolutionInfo.getInfoByRef(defaultRef)
return LibraryFactoryDefaultInfoSwitcher(factory, DefaultInfoSwitch.DIRECTORY) { switch ->
return LibraryFactoryDefaultInfoSwitcher(provider, DefaultInfoSwitch.DIRECTORY) { switch ->
when(switch) {
DefaultInfoSwitch.DIRECTORY -> dirInfo
DefaultInfoSwitch.GIT_REFERENCE -> refInfo
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/org/jetbrains/kotlin/jupyter/magics.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ data class MagicProcessingResult(val code: String, val libraries: List<LibraryDe

class MagicsProcessor(val repl: ReplOptions, private val libraries: LibrariesProcessor) {

private val libraryResolutionInfoSwitcher = LibraryFactoryDefaultInfoSwitcher.default(libraries.libraryFactory, repl.librariesDir, repl.currentBranch)
private val libraryResolutionInfoSwitcher = LibraryFactoryDefaultInfoSwitcher.default(libraries.libraryFactory.resolutionInfoProvider, repl.librariesDir, repl.currentBranch)

private fun updateOutputConfig(conf: OutputConfig, argv: List<String>): OutputConfig {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.kotlin.jupyter.Message
import org.jetbrains.kotlin.jupyter.defaultRuntimeProperties
import org.jetbrains.kotlin.jupyter.iKotlinClass
import org.jetbrains.kotlin.jupyter.kernelServer
import org.jetbrains.kotlin.jupyter.libraries.EmptyResolutionInfoProvider
import org.jetbrains.kotlin.jupyter.libraries.LibraryFactory
import org.jetbrains.kotlin.jupyter.libraries.LibraryResolutionInfo
import org.jetbrains.kotlin.jupyter.makeHeader
Expand Down Expand Up @@ -37,7 +38,7 @@ open class KernelServerTestsBase {
scriptClasspath = classpath,
resolverConfig = null,
homeDir = File(""),
libraryFactory = LibraryFactory(LibraryResolutionInfo.ByNothing())
libraryFactory = LibraryFactory.EMPTY
)

private val sessionId = UUID.randomUUID().toString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.kotlin.jupyter.MagicsProcessor
import org.jetbrains.kotlin.jupyter.OutputConfig
import org.jetbrains.kotlin.jupyter.ReplOptions
import org.jetbrains.kotlin.jupyter.defaultRuntimeProperties
import org.jetbrains.kotlin.jupyter.libraries.EmptyResolutionInfoProvider
import org.jetbrains.kotlin.jupyter.libraries.LibraryFactory
import org.jetbrains.kotlin.jupyter.libraries.LibraryResolutionInfo
import org.jetbrains.kotlin.jupyter.repl.SourceCodeImpl
Expand All @@ -19,7 +20,7 @@ import java.io.File
import kotlin.test.assertTrue

class ParseArgumentsTests {
private val libraryFactory = LibraryFactory(LibraryResolutionInfo.ByNothing())
private val libraryFactory = LibraryFactory.EMPTY

@Test
fun test1() {
Expand Down Expand Up @@ -101,7 +102,7 @@ class ParseMagicsTests {
private val options = TestReplOptions()

private fun test(code: String, expectedProcessedCode: String, librariesChecker: (List<LibraryDefinition>) -> Unit = {}) {
val libraryFactory = LibraryFactory(LibraryResolutionInfo.ByNothing())
val libraryFactory = LibraryFactory.EMPTY
val processor = MagicsProcessor(options, LibrariesProcessor(libraryFactory.testResolverConfig.libraries, defaultRuntimeProperties, libraryFactory))
with(processor.processMagics(code, tryIgnoreErrors = true)) {
assertEquals(expectedProcessedCode, this.code)
Expand Down
15 changes: 11 additions & 4 deletions src/test/kotlin/org/jetbrains/kotlin/jupyter/test/replTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ abstract class AbstractReplTest {

companion object {
@JvmStatic
protected val libraryFactory = LibraryFactory(LibraryResolutionInfo.ByNothing())
protected val libraryFactory = LibraryFactory.EMPTY

@JvmStatic
protected val homeDir = File("")
Expand Down Expand Up @@ -569,15 +569,16 @@ class ReplWithResolverTest : AbstractReplTest() {
@Test
fun testDefaultInfoSwitcher() {
val repl = getReplWithStandardResolver()
val infoProvider = repl.libraryFactory.resolutionInfoProvider

val initialDefaultResolutionInfo = repl.libraryFactory.defaultResolutionInfo
val initialDefaultResolutionInfo = infoProvider.fallback
assertTrue(initialDefaultResolutionInfo is LibraryResolutionInfo.ByDir)

repl.eval("%useLatestDescriptors")
assertTrue(repl.libraryFactory.defaultResolutionInfo is LibraryResolutionInfo.ByGitRef)
assertTrue(infoProvider.fallback is LibraryResolutionInfo.ByGitRef)

repl.eval("%useLatestDescriptors -off")
assertTrue(repl.libraryFactory.defaultResolutionInfo === initialDefaultResolutionInfo)
assertTrue(infoProvider.fallback === initialDefaultResolutionInfo)
}

@Test
Expand All @@ -604,6 +605,12 @@ class ReplWithResolverTest : AbstractReplTest() {
assertEquals(1, displays.count())
assertNull(res3.resultValue)
displays.clear()

val res4 = repl.eval("""
%use @$libraryPath(name=z, value=44)
z
""".trimIndent())
assertEquals(44, res4.resultValue)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fun Collection<Pair<String, String>>.toLibraries(libraryFactory: LibraryFactory)
}

fun LibraryFactory.getResolverFromNamesMap(map: Map<String, LibraryDescriptor>): LibraryResolver {
return InMemoryLibraryResolver(null, map.mapKeys { entry -> LibraryReference(defaultResolutionInfo, entry.key) })
return InMemoryLibraryResolver(null, map.mapKeys { entry -> LibraryReference(resolutionInfoProvider.get(), entry.key) })
}

fun readLibraries(basePath: String? = null): Map<String, JsonObject> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import org.jetbrains.kotlin.jupyter.ReplCompilerException
import org.jetbrains.kotlin.jupyter.ReplForJupyterImpl
import org.jetbrains.kotlin.jupyter.ResolverConfig
import org.jetbrains.kotlin.jupyter.defaultRepositories
import org.jetbrains.kotlin.jupyter.libraries.EmptyResolutionInfoProvider
import org.jetbrains.kotlin.jupyter.libraries.LibraryFactory
import org.jetbrains.kotlin.jupyter.libraries.LibraryResolutionInfo
import org.jetbrains.kotlin.jupyter.libraries.parseLibraryDescriptors
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
Expand All @@ -30,7 +30,7 @@ class TypeProviderTests {
""".trimIndent()
val cp = classpath + File(TypeProviderReceiver::class.java.protectionDomain.codeSource.location.toURI().path)
val libJsons = mapOf("mylib" to parser.parse(StringBuilder(descriptor)) as JsonObject)
val libraryFactory = LibraryFactory(LibraryResolutionInfo.ByNothing())
val libraryFactory = LibraryFactory.EMPTY
val config = ResolverConfig(defaultRepositories, libraryFactory.getResolverFromNamesMap(parseLibraryDescriptors(libJsons)))
val repl = ReplForJupyterImpl(libraryFactory, cp, null, config, scriptReceivers = listOf(TypeProviderReceiver()))

Expand Down

0 comments on commit 5d984d9

Please sign in to comment.