Skip to content

Commit

Permalink
Extract CommManager to a separate interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ileasile committed Jun 21, 2022
1 parent e64f0ac commit 4782275
Show file tree
Hide file tree
Showing 15 changed files with 290 additions and 231 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.jetbrains.kotlinx.jupyter.api

import org.jetbrains.kotlinx.jupyter.api.libraries.CommManager
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterConnection
import org.jetbrains.kotlinx.jupyter.api.libraries.LibraryResolutionRequest

Expand Down Expand Up @@ -102,4 +103,6 @@ interface Notebook {
val libraryRequests: Collection<LibraryResolutionRequest>

val connection: JupyterConnection

val commManager: CommManager
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ interface JupyterConnection {
* Simpler-to-use version of [send].
*/
fun sendReply(socketName: JupyterSocket, parentMessage: RawMessage, type: String, content: JsonObject, metadata: JsonObject? = null)
}

interface CommManager {
/**
* Creates a comm with a given target, generates unique ID for it. Sends comm_open request to frontend
*
Expand Down
4 changes: 3 additions & 1 deletion src/main/kotlin/org/jetbrains/kotlinx/jupyter/apiImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.jetbrains.kotlinx.jupyter.api.Notebook
import org.jetbrains.kotlinx.jupyter.api.RenderersProcessor
import org.jetbrains.kotlinx.jupyter.api.ResultsAccessor
import org.jetbrains.kotlinx.jupyter.api.VariableState
import org.jetbrains.kotlinx.jupyter.api.libraries.CommManager
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterConnection
import org.jetbrains.kotlinx.jupyter.api.libraries.LibraryResolutionRequest
import org.jetbrains.kotlinx.jupyter.repl.impl.SharedReplContext
Expand Down Expand Up @@ -136,7 +137,8 @@ class EvalData(

class NotebookImpl(
private val runtimeProperties: ReplRuntimeProperties,
override val connection: JupyterConnection
override val connection: JupyterConnection,
override val commManager: CommManager,
) : MutableNotebook {
private val cells = hashMapOf<Int, MutableCodeCell>()
override var sharedReplContext: SharedReplContext? = null
Expand Down
213 changes: 23 additions & 190 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,26 @@ import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.jsonObject
import org.jetbrains.kotlinx.jupyter.api.libraries.Comm
import org.jetbrains.kotlinx.jupyter.api.libraries.CommCloseCallback
import org.jetbrains.kotlinx.jupyter.api.libraries.CommMsgCallback
import org.jetbrains.kotlinx.jupyter.api.libraries.CommOpenCallback
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterConnection
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterSocket
import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessage
import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessageCallback
import org.jetbrains.kotlinx.jupyter.api.libraries.header
import org.jetbrains.kotlinx.jupyter.api.libraries.type
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
import org.jetbrains.kotlinx.jupyter.messaging.CommClose
import org.jetbrains.kotlinx.jupyter.messaging.CommMsg
import org.jetbrains.kotlinx.jupyter.messaging.CommOpen
import org.jetbrains.kotlinx.jupyter.messaging.InputReply
import org.jetbrains.kotlinx.jupyter.messaging.InputRequest
import org.jetbrains.kotlinx.jupyter.messaging.JupyterOutType
import org.jetbrains.kotlinx.jupyter.messaging.JupyterConnectionInternal
import org.jetbrains.kotlinx.jupyter.messaging.JupyterServerSocket
import org.jetbrains.kotlinx.jupyter.messaging.KernelStatus
import org.jetbrains.kotlinx.jupyter.messaging.Message
import org.jetbrains.kotlinx.jupyter.messaging.MessageContent
import org.jetbrains.kotlinx.jupyter.messaging.MessageData
import org.jetbrains.kotlinx.jupyter.messaging.MessageHeader
import org.jetbrains.kotlinx.jupyter.messaging.MessageType
import org.jetbrains.kotlinx.jupyter.messaging.RawMessageImpl
import org.jetbrains.kotlinx.jupyter.messaging.StatusReply
import org.jetbrains.kotlinx.jupyter.messaging.StreamResponse
import org.jetbrains.kotlinx.jupyter.messaging.emptyJsonObjectStringBytes
import org.jetbrains.kotlinx.jupyter.messaging.jsonObject
import org.jetbrains.kotlinx.jupyter.messaging.makeHeader
import org.jetbrains.kotlinx.jupyter.messaging.makeJsonHeader
import org.jetbrains.kotlinx.jupyter.messaging.makeReplyMessage
import org.jetbrains.kotlinx.jupyter.messaging.sendMessage
import org.jetbrains.kotlinx.jupyter.messaging.toMessage
import org.jetbrains.kotlinx.jupyter.messaging.toRawMessage
import org.jetbrains.kotlinx.jupyter.util.EMPTY
Expand All @@ -49,9 +38,6 @@ import org.zeromq.ZMQ
import java.io.Closeable
import java.io.IOException
import java.security.SignatureException
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArrayList
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import kotlin.concurrent.thread
Expand All @@ -60,27 +46,20 @@ import kotlin.math.min
typealias SocketMessageCallback = JupyterConnectionImpl.Socket.(Message) -> Unit
typealias SocketRawMessageCallback = JupyterConnectionImpl.Socket.(RawMessage) -> Unit

class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Closeable {
class JupyterConnectionImpl(
val config: KernelConfig
) : JupyterConnectionInternal, Closeable {

private var messageId: List<ByteArray> = listOf(byteArrayOf(1))
private var sessionId = ""
private var username = ""
private var _messageId: List<ByteArray> = listOf(byteArrayOf(1))
override val messageId: List<ByteArray> get() = _messageId

private fun makeDefaultHeader(msgType: MessageType): MessageHeader {
return makeHeader(msgType, sessionId = sessionId, username = username)
}
private var _sessionId = ""
override val sessionId: String get() = _sessionId

fun makeSimpleMessage(msgType: MessageType, content: MessageContent): Message {
return Message(
id = messageId,
data = MessageData(
header = makeDefaultHeader(msgType),
content = content
)
)
}
private var _username = ""
override val username: String get() = _username

inner class Socket(private val socket: JupyterSocketInfo, type: SocketType = socket.zmqKernelType) : ZMQ.Socket(context, type) {
inner class Socket(private val socket: JupyterSocketInfo, type: SocketType = socket.zmqKernelType) : ZMQ.Socket(context, type), JupyterServerSocket {
val name: String get() = socket.name
init {
val port = config.ports[socket.ordinal]
Expand Down Expand Up @@ -137,15 +116,7 @@ class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Close
sendStatus(KernelStatus.IDLE, incomingMessage)
}

fun sendOut(msg: Message, stream: JupyterOutType, text: String) {
sendMessage(makeReplyMessage(msg, header = makeHeader(MessageType.STREAM, msg), content = StreamResponse(stream.optionName(), text)))
}

fun sendMessage(msg: Message) {
sendRawMessage(msg.toRawMessage())
}

fun sendRawMessage(msg: RawMessage) {
override fun sendRawMessage(msg: RawMessage) {
log.debug("[$name] snd>: $msg")
sendRawMessage(msg, hmac)
}
Expand All @@ -166,7 +137,7 @@ class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Close
}
}

val connection: JupyterConnectionImpl = this@JupyterConnectionImpl
override val connection: JupyterConnectionImpl = this@JupyterConnectionImpl
}

inner class StdinInputStream : java.io.InputStream() {
Expand Down Expand Up @@ -227,85 +198,14 @@ class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Close
}
}

inner class CommImpl(
override val target: String,
override val id: String
) : Comm {

private val onMessageCallbacks = mutableListOf<CommMsgCallback>()
private val onCloseCallbacks = mutableListOf<CommCloseCallback>()
private var closed = false

private fun assertOpen() {
if (closed) {
throw AssertionError("Comm '$target' has been already closed")
}
}
override fun send(data: JsonObject) {
assertOpen()
iopub.sendMessage(
makeSimpleMessage(
MessageType.COMM_MSG,
CommMsg(id, data)
)
)
}

override fun onMessage(action: CommMsgCallback): CommMsgCallback {
assertOpen()
onMessageCallbacks.add(action)
return action
}

override fun removeMessageCallback(callback: CommMsgCallback) {
onMessageCallbacks.remove(callback)
}

override fun onClose(action: CommCloseCallback): CommCloseCallback {
assertOpen()
onCloseCallbacks.add(action)
return action
}

override fun removeCloseCallback(callback: CommCloseCallback) {
onCloseCallbacks.remove(callback)
}

override fun close(data: JsonObject, notifyClient: Boolean) {
assertOpen()
closed = true
onMessageCallbacks.clear()

removeComm(id)

onCloseCallbacks.forEach { it(data) }

if (notifyClient) {
iopub.sendMessage(
makeSimpleMessage(
MessageType.COMM_CLOSE,
CommClose(id, data)
)
)
}
}

fun messageReceived(data: JsonObject) {
if (closed) return
for (callback in onMessageCallbacks) {
callback(data)
}
}
}

private val hmac = HMAC(config.signatureScheme.replace("-", ""), config.signatureKey)
private val context = ZMQ.context(1)

val heartbeat = Socket(JupyterSocketInfo.HB)
val shell = Socket(JupyterSocketInfo.SHELL)
val control = Socket(JupyterSocketInfo.CONTROL)
val stdin = Socket(JupyterSocketInfo.STDIN)
val iopub = Socket(JupyterSocketInfo.IOPUB)
override val heartbeat = Socket(JupyterSocketInfo.HB)
override val shell = Socket(JupyterSocketInfo.SHELL)
override val control = Socket(JupyterSocketInfo.CONTROL)
override val stdin = Socket(JupyterSocketInfo.STDIN)
override val iopub = Socket(JupyterSocketInfo.IOPUB)

private fun fromSocketName(socket: JupyterSocket): Socket {
return when (socket) {
Expand Down Expand Up @@ -338,9 +238,9 @@ class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Close

fun updateSessionInfo(message: Message) {
val header = message.data.header ?: return
header.session?.let { sessionId = it }
header.username?.let { username = it }
messageId = message.id
header.session?.let { _sessionId = it }
header.username?.let { _username = it }
_messageId = message.id
}

override fun send(socketName: JupyterSocket, message: RawMessage) {
Expand All @@ -363,73 +263,6 @@ class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Close
send(socketName, message)
}

private val commOpenCallbacks = ConcurrentHashMap<String, CommOpenCallback>()
private val commTargetToIds = ConcurrentHashMap<String, CopyOnWriteArrayList<String>>()
private val commIdToComm = ConcurrentHashMap<String, CommImpl>()
override fun openComm(target: String, data: JsonObject): Comm {
val id = UUID.randomUUID().toString()
val newComm = processCommOpen(target, id, data)

// send comm_open
iopub.sendMessage(
makeSimpleMessage(
MessageType.COMM_OPEN,
CommOpen(newComm.id, newComm.target)
)
)

return newComm
}

fun processCommOpen(target: String, id: String, data: JsonObject): Comm {
val commIds = commTargetToIds.getOrPut(target) { CopyOnWriteArrayList() }
val newComm = CommImpl(target, id)
commIds.add(id)
commIdToComm[id] = newComm

val callback = commOpenCallbacks[target]
callback?.invoke(newComm, data)

return newComm
}

override fun closeComm(id: String, data: JsonObject) {
val comm = commIdToComm[id] ?: return
comm.close(data, notifyClient = true)
}

fun processCommClose(id: String, data: JsonObject) {
val comm = commIdToComm[id] ?: return
comm.close(data, notifyClient = false)
}

fun removeComm(id: String) {
val comm = commIdToComm[id] ?: return
val commIds = commTargetToIds[comm.target]!!
commIds.remove(id)
commIdToComm.remove(id)
}

override fun getComms(target: String?): Collection<Comm> {
return if (target == null) {
commIdToComm.values.toList()
} else {
commTargetToIds[target].orEmpty().mapNotNull { commIdToComm[it] }
}
}

fun processCommMessage(id: String, data: JsonObject) {
commIdToComm[id]?.messageReceived(data)
}

override fun registerCommTarget(target: String, callback: (Comm, JsonObject) -> Unit) {
commOpenCallbacks[target] = callback
}

override fun unregisterCommTarget(target: String) {
commOpenCallbacks.remove(target)
}

val stdinIn = StdinInputStream()

var contextMessage: Message? = null
Expand Down
6 changes: 4 additions & 2 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/ikotlin.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.jupyter.libraries.EmptyResolutionInfoProvider
import org.jetbrains.kotlinx.jupyter.libraries.KERNEL_LIBRARIES
import org.jetbrains.kotlinx.jupyter.libraries.ResolutionInfoProvider
import org.jetbrains.kotlinx.jupyter.libraries.getDefaultDirectoryResolutionInfoProvider
import org.jetbrains.kotlinx.jupyter.messaging.CommManagerImpl
import org.jetbrains.kotlinx.jupyter.messaging.controlMessagesHandler
import org.jetbrains.kotlinx.jupyter.messaging.shellMessagesHandler
import org.jetbrains.kotlinx.jupyter.repl.creating.DefaultReplFactory
Expand Down Expand Up @@ -122,7 +123,8 @@ fun kernelServer(config: KernelConfig, runtimeProperties: ReplRuntimeProperties

val executionCount = AtomicLong(1)

val repl = DefaultReplFactory(config, runtimeProperties, scriptReceivers, conn).createRepl()
val commManager = CommManagerImpl(conn)
val repl = DefaultReplFactory(config, runtimeProperties, scriptReceivers, conn, commManager).createRepl()

val mainThread = Thread.currentThread()

Expand All @@ -148,7 +150,7 @@ fun kernelServer(config: KernelConfig, runtimeProperties: ReplRuntimeProperties

conn.shell.onMessage { message ->
conn.updateSessionInfo(message)
shellMessagesHandler(message, repl, executionCount)
shellMessagesHandler(message, repl, commManager, executionCount)
}

val controlThread = thread {
Expand Down
Loading

0 comments on commit 4782275

Please sign in to comment.