Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comm-message handlers and support for debugPort config retrieval #375

Merged
merged 1 commit into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.jetbrains.kotlinx.jupyter.messaging

import kotlinx.serialization.Serializable

object ProvidedCommMessages {
const val OPEN_DEBUG_PORT_TARGET: String = "open_debug_port_target"
}

@Serializable
class OpenDebugPortReply(
val port: Int?
) : OkReply()
3 changes: 3 additions & 0 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ class JupyterConnectionImpl(

override val stdinIn = StdinInputStream()

override val debugPort: Int?
get() = config.debugPort

private var _contextMessage: RawMessage? = null
override fun setContextMessage(message: RawMessage?) {
_contextMessage = message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ interface JupyterConnectionInternal : JupyterConnection {
val executor: JupyterExecutor
val stdinIn: InputStream

val debugPort: Int?

fun sendStatus(status: KernelStatus, incomingMessage: RawMessage? = null)
fun doWrappedInBusyIdle(incomingMessage: RawMessage? = null, action: () -> Unit)
fun updateSessionInfo(message: RawMessage)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package org.jetbrains.kotlinx.jupyter.messaging

import kotlinx.serialization.json.JsonObject
import org.jetbrains.kotlinx.jupyter.ReplForJupyterImpl
import org.jetbrains.kotlinx.jupyter.api.libraries.Comm
import org.jetbrains.kotlinx.jupyter.api.libraries.sendData
import org.jetbrains.kotlinx.jupyter.messaging.ProvidedCommMessages.OPEN_DEBUG_PORT_TARGET

interface CommHandler {
val targetId: String

fun onReceive(comm: Comm, messageContent: JsonObject, repl: ReplForJupyterImpl)
}

class DebugPortCommHandler : CommHandler {
override val targetId: String
get() = OPEN_DEBUG_PORT_TARGET

override fun onReceive(comm: Comm, messageContent: JsonObject, repl: ReplForJupyterImpl) {
comm.sendData(OpenDebugPortReply(repl.debugPort))
}
}
6 changes: 4 additions & 2 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ interface ReplOptions {
var executedCodeLogging: ExecutedCodeLogging
var writeCompiledClasses: Boolean
var outputConfig: OutputConfig
val debugPort: Int?
}

interface ReplForJupyter {
Expand Down Expand Up @@ -188,7 +189,8 @@ class ReplForJupyterImpl(
private val scriptReceivers: List<Any> = emptyList(),
override val isEmbedded: Boolean = false,
override val notebook: MutableNotebook,
override val librariesScanner: LibrariesScanner
override val librariesScanner: LibrariesScanner,
override val debugPort: Int? = null
) : ReplForJupyter, ReplOptions, BaseKernelHost, UserHandlesProvider {

override val currentBranch: String
Expand Down Expand Up @@ -418,7 +420,7 @@ class ReplForJupyterImpl(
@Suppress("unused")
private fun printUsagesInfo(cellId: Int, usedVariables: Set<String>?) {
log.debug(buildString {
if (usedVariables == null || usedVariables.isEmpty()) {
if (usedVariables.isNullOrEmpty()) {
append("No usages for cell $cellId")
return@buildString
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import org.jetbrains.kotlinx.jupyter.libraries.EmptyResolutionInfoProvider
import org.jetbrains.kotlinx.jupyter.libraries.LibrariesScanner
import org.jetbrains.kotlinx.jupyter.libraries.LibraryResolver
import org.jetbrains.kotlinx.jupyter.libraries.ResolutionInfoProvider
import org.jetbrains.kotlinx.jupyter.messaging.CommHandler
import org.jetbrains.kotlinx.jupyter.messaging.CommManagerImpl
import org.jetbrains.kotlinx.jupyter.messaging.DebugPortCommHandler
import org.jetbrains.kotlinx.jupyter.messaging.DisplayHandler
import org.jetbrains.kotlinx.jupyter.messaging.NoOpDisplayHandler
import java.io.File
Expand All @@ -28,4 +30,7 @@ abstract class BaseReplFactory : ReplFactory() {
override fun provideIsEmbedded() = false
override fun provideLibrariesScanner(): LibrariesScanner = LibrariesScanner(notebook)
override fun provideCommManager(): CommManager = CommManagerImpl(connection)
override fun provideCommHandlers(): List<CommHandler> = listOf(
DebugPortCommHandler()
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ object MockJupyterConnection : JupyterConnectionInternal {
get() = throw NotImplementedError()
override val stdinIn: InputStream
get() = throw NotImplementedError()
override val debugPort: Int?
get() = null

override fun sendStatus(status: KernelStatus, incomingMessage: RawMessage?) {
throw NotImplementedError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.jupyter.api.libraries.CommManager
import org.jetbrains.kotlinx.jupyter.libraries.LibrariesScanner
import org.jetbrains.kotlinx.jupyter.libraries.LibraryResolver
import org.jetbrains.kotlinx.jupyter.libraries.ResolutionInfoProvider
import org.jetbrains.kotlinx.jupyter.messaging.CommHandler
import org.jetbrains.kotlinx.jupyter.messaging.DisplayHandler
import org.jetbrains.kotlinx.jupyter.messaging.JupyterConnectionInternal
import java.io.File
Expand All @@ -26,8 +27,19 @@ abstract class ReplFactory {
scriptReceivers,
isEmbedded,
notebook,
librariesScanner
)
librariesScanner,
connection.debugPort
).also { repl ->
commHandlers.forEach { handler ->
repl.notebook.commManager.registerCommTarget(handler.targetId) { comm, data ->
// handler.onReceive(comm, data, repl) // maybe send right away?

comm.onMessage {
handler.onReceive(comm, it, repl)
}
}
}
}
}

protected val resolutionInfoProvider by lazy { provideResolutionInfoProvider() }
Expand Down Expand Up @@ -69,6 +81,17 @@ abstract class ReplFactory {
protected val commManager: CommManager by lazy { provideCommManager() }
protected abstract fun provideCommManager(): CommManager

protected val commHandlers: List<CommHandler> by lazy { provideCommHandlers() }
protected abstract fun provideCommHandlers(): List<CommHandler>

// TODO: add other methods incl. display handler and socket messages listener
// Inheritors should be constructed of connection (JupyterConnection)

init {
val uniqueTargets = commHandlers.map { it.targetId }.toSet().size
assert(uniqueTargets == commHandlers.size) {
val duplicates = commHandlers.groupingBy { it }.eachCount().filter { it.value > 1 }.keys
"Duplicate bundled comm targets found! $duplicates"
}
}
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
package org.jetbrains.kotlinx.jupyter.test

import io.kotest.matchers.shouldBe
import io.kotest.matchers.types.shouldBeTypeOf
import org.jetbrains.kotlinx.jupyter.ReplConfig
import org.jetbrains.kotlinx.jupyter.defaultRuntimeProperties
import org.jetbrains.kotlinx.jupyter.kernelServer
import org.jetbrains.kotlinx.jupyter.libraries.EmptyResolutionInfoProvider
import org.jetbrains.kotlinx.jupyter.messaging.KernelStatus
import org.jetbrains.kotlinx.jupyter.messaging.Message
import org.jetbrains.kotlinx.jupyter.messaging.MessageContent
import org.jetbrains.kotlinx.jupyter.messaging.MessageData
import org.jetbrains.kotlinx.jupyter.messaging.MessageType
import org.jetbrains.kotlinx.jupyter.messaging.StatusReply
import org.jetbrains.kotlinx.jupyter.messaging.makeHeader
import org.jetbrains.kotlinx.jupyter.messaging.toMessage
import org.jetbrains.kotlinx.jupyter.protocol.HMAC
import org.jetbrains.kotlinx.jupyter.protocol.JupyterSocket
import org.jetbrains.kotlinx.jupyter.protocol.JupyterSocketInfo
import org.jetbrains.kotlinx.jupyter.protocol.JupyterSocketSide
import org.jetbrains.kotlinx.jupyter.protocol.SocketWrapper
import org.jetbrains.kotlinx.jupyter.protocol.createSocket
import org.jetbrains.kotlinx.jupyter.protocol.receiveRawMessage
import org.jetbrains.kotlinx.jupyter.sendMessage
Expand All @@ -38,7 +43,7 @@ import kotlin.concurrent.thread
abstract class KernelServerTestsBase {
protected abstract val context: ZMQ.Context

private val kernelConfig = createKotlinKernelConfig(
protected val kernelConfig = createKotlinKernelConfig(
ports = createKernelPorts { randomPort() },
signatureKey = "",
scriptClasspath = classpath,
Expand Down Expand Up @@ -121,6 +126,19 @@ abstract class KernelServerTestsBase {

fun JupyterSocket.receiveMessage() = socket.receiveRawMessage(socket.recv(), hmac).toMessage()

fun JupyterSocket.receiveStatusReply(): StatusReply {
(this as? SocketWrapper)?.name shouldBe JupyterSocketInfo.IOPUB.name
receiveMessage().apply {
return content.shouldBeTypeOf()
}
}

inline fun JupyterSocket.wrapActionInBusyIdleStatusChange(action: () -> Unit) {
receiveStatusReply().status shouldBe KernelStatus.BUSY
action()
receiveStatusReply().status shouldBe KernelStatus.IDLE
}

companion object {
private val rng = Random()
private val usedPorts: MutableSet<Int> = ConcurrentHashMap.newKeySet()
Expand Down
42 changes: 35 additions & 7 deletions src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/executeTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ import org.jetbrains.kotlinx.jupyter.messaging.IsCompleteReply
import org.jetbrains.kotlinx.jupyter.messaging.IsCompleteRequest
import org.jetbrains.kotlinx.jupyter.messaging.KernelStatus
import org.jetbrains.kotlinx.jupyter.messaging.Message
import org.jetbrains.kotlinx.jupyter.messaging.MessageStatus
import org.jetbrains.kotlinx.jupyter.messaging.MessageType
import org.jetbrains.kotlinx.jupyter.messaging.OpenDebugPortReply
import org.jetbrains.kotlinx.jupyter.messaging.ProvidedCommMessages
import org.jetbrains.kotlinx.jupyter.messaging.StatusReply
import org.jetbrains.kotlinx.jupyter.messaging.StreamResponse
import org.jetbrains.kotlinx.jupyter.messaging.jsonObject
Expand Down Expand Up @@ -450,20 +453,45 @@ class ExecuteTests : KernelServerTestsBase() {
)
)

iopub.receiveMessage().apply {
val c = content.shouldBeTypeOf<StatusReply>()
c.status shouldBe KernelStatus.BUSY
}
iopub.receiveStatusReply().status shouldBe KernelStatus.BUSY

iopub.receiveMessage().apply {
val c = content.shouldBeTypeOf<CommMsg>()
c.commId shouldBe commId
c.data["y"]!!.jsonPrimitive.content shouldBe "received: 4321"
}
}

iopub.receiveMessage().apply {
val c = content.shouldBeTypeOf<StatusReply>()
c.status shouldBe KernelStatus.IDLE
@Test
fun testDebugPortCommHandler() {
val shell = shell!!
val iopub = ioPub!!

val targetName = ProvidedCommMessages.OPEN_DEBUG_PORT_TARGET
val commId = "some"
val actualDebugPort = kernelConfig.debugPort

shell.sendMessage(
MessageType.COMM_OPEN,
CommOpen(
commId,
targetName
)
)

shell.sendMessage(
MessageType.COMM_MSG,
CommMsg(commId)
)

iopub.wrapActionInBusyIdleStatusChange {
iopub.receiveMessage().apply {
val c = content.shouldBeTypeOf<CommMsg>()
val data = Json.decodeFromJsonElement<OpenDebugPortReply>(c.data).shouldBeTypeOf<OpenDebugPortReply>()
c.commId shouldBe commId
data.port shouldBe actualDebugPort
data.status shouldBe MessageStatus.OK
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class ParseMagicsTests {
override var executedCodeLogging = ExecutedCodeLogging.OFF
override var writeCompiledClasses = false
override var outputConfig = OutputConfig()
override val debugPort: Int? = null
}

private val options = TestReplOptions()
Expand Down