Skip to content

Commit

Permalink
Refactor execution model
Browse files Browse the repository at this point in the history
  • Loading branch information
ileasile committed Jun 23, 2022
1 parent 4782275 commit b2028d6
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 155 deletions.
13 changes: 7 additions & 6 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/config.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.decodeFromJsonElement
import kotlinx.serialization.serializer
import org.jetbrains.kotlinx.jupyter.api.KotlinKernelVersion
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterSocket
import org.jetbrains.kotlinx.jupyter.common.getNameForUser
import org.jetbrains.kotlinx.jupyter.config.getLogger
import org.jetbrains.kotlinx.jupyter.config.readResourceAsIniFile
Expand All @@ -36,12 +37,12 @@ val defaultRuntimeProperties by lazy {
RuntimeKernelProperties(readResourceAsIniFile("runtime.properties"))
}

enum class JupyterSocketInfo(val zmqKernelType: SocketType, val zmqClientType: SocketType) {
HB(SocketType.REP, SocketType.REQ),
SHELL(SocketType.ROUTER, SocketType.REQ),
CONTROL(SocketType.ROUTER, SocketType.REQ),
STDIN(SocketType.ROUTER, SocketType.REQ),
IOPUB(SocketType.PUB, SocketType.SUB);
enum class JupyterSocketInfo(val type: JupyterSocket, val zmqKernelType: SocketType, val zmqClientType: SocketType) {
HB(JupyterSocket.HB, SocketType.REP, SocketType.REQ),
SHELL(JupyterSocket.SHELL, SocketType.ROUTER, SocketType.REQ),
CONTROL(JupyterSocket.CONTROL, SocketType.ROUTER, SocketType.REQ),
STDIN(JupyterSocket.STDIN, SocketType.ROUTER, SocketType.REQ),
IOPUB(JupyterSocket.IOPUB, SocketType.PUB, SocketType.SUB);

val nameForUser = getNameForUser(name)
}
Expand Down
80 changes: 20 additions & 60 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package org.jetbrains.kotlinx.jupyter

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
Expand All @@ -15,7 +12,6 @@ import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessage
import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessageCallback
import org.jetbrains.kotlinx.jupyter.api.libraries.header
import org.jetbrains.kotlinx.jupyter.api.libraries.type
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
import org.jetbrains.kotlinx.jupyter.messaging.InputReply
import org.jetbrains.kotlinx.jupyter.messaging.InputRequest
import org.jetbrains.kotlinx.jupyter.messaging.JupyterConnectionInternal
Expand All @@ -29,6 +25,7 @@ import org.jetbrains.kotlinx.jupyter.messaging.emptyJsonObjectStringBytes
import org.jetbrains.kotlinx.jupyter.messaging.jsonObject
import org.jetbrains.kotlinx.jupyter.messaging.makeJsonHeader
import org.jetbrains.kotlinx.jupyter.messaging.makeReplyMessage
import org.jetbrains.kotlinx.jupyter.messaging.makeSimpleMessage
import org.jetbrains.kotlinx.jupyter.messaging.sendMessage
import org.jetbrains.kotlinx.jupyter.messaging.toMessage
import org.jetbrains.kotlinx.jupyter.messaging.toRawMessage
Expand All @@ -37,10 +34,10 @@ import org.zeromq.SocketType
import org.zeromq.ZMQ
import java.io.Closeable
import java.io.IOException
import java.io.InputStream
import java.security.SignatureException
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import kotlin.concurrent.thread
import kotlin.math.min

typealias SocketMessageCallback = JupyterConnectionImpl.Socket.(Message) -> Unit
Expand Down Expand Up @@ -106,14 +103,8 @@ class JupyterConnectionImpl(
}
}

fun sendStatus(status: KernelStatus, msg: Message) {
connection.iopub.sendMessage(makeReplyMessage(msg, MessageType.STATUS, content = StatusReply(status)))
}

fun sendWrapped(incomingMessage: Message, msg: Message) {
sendStatus(KernelStatus.BUSY, incomingMessage)
sendMessage(msg)
sendStatus(KernelStatus.IDLE, incomingMessage)
doWrappedInBusyIdle(incomingMessage) { sendMessage(msg) }
}

override fun sendRawMessage(msg: RawMessage) {
Expand All @@ -140,7 +131,7 @@ class JupyterConnectionImpl(
override val connection: JupyterConnectionImpl = this@JupyterConnectionImpl
}

inner class StdinInputStream : java.io.InputStream() {
inner class StdinInputStream : InputStream() {
private var currentBuf: ByteArray? = null
private var currentBufPos = 0

Expand Down Expand Up @@ -263,57 +254,26 @@ class JupyterConnectionImpl(
send(socketName, message)
}

val stdinIn = StdinInputStream()

var contextMessage: Message? = null

private val currentExecutions = HashSet<Thread>()
private val coroutineScope = CoroutineScope(Dispatchers.Default)

data class ConnectionExecutionResult<T>(
val result: T?,
val throwable: Throwable?,
val isInterrupted: Boolean,
)

fun <T> runExecution(body: () -> T, classLoader: ClassLoader): ConnectionExecutionResult<T> {
var execRes: T? = null
var execException: Throwable? = null
val execThread = thread(contextClassLoader = classLoader) {
try {
execRes = body()
} catch (e: Throwable) {
execException = e
}
}
currentExecutions.add(execThread)
execThread.join()
currentExecutions.remove(execThread)

val exception = execException
val isInterrupted = exception is ThreadDeath ||
(exception is ReplException && exception.cause is ThreadDeath)
return ConnectionExecutionResult(execRes, exception, isInterrupted)
override fun sendStatus(status: KernelStatus, incomingMessage: Message?) {
val message = if (incomingMessage != null) makeReplyMessage(incomingMessage, MessageType.STATUS, content = StatusReply(status))
else makeSimpleMessage(MessageType.STATUS, content = StatusReply(status))
iopub.sendMessage(message)
}

/**
* We cannot use [Thread.interrupt] here because we have no way
* to control the code user executes. [Thread.interrupt] will do nothing for
* the simple calculation (like `while (true) 1`). Consider replacing with
* something more smart in the future.
*/
fun interruptExecution() {
@Suppress("deprecation")
while (currentExecutions.isNotEmpty()) {
val execution = currentExecutions.firstOrNull()
execution?.stop()
currentExecutions.remove(execution)
override fun doWrappedInBusyIdle(incomingMessage: Message?, action: () -> Unit) {
sendStatus(KernelStatus.BUSY, incomingMessage)
try {
action()
} finally {
sendStatus(KernelStatus.IDLE, incomingMessage)
}
}

fun launchJob(runnable: suspend CoroutineScope.() -> Unit) {
coroutineScope.launch(block = runnable)
}
override val stdinIn = StdinInputStream()

var contextMessage: Message? = null

override val executor: JupyterExecutor = JupyterExecutorImpl()

override fun close() {
heartbeat.close()
Expand Down Expand Up @@ -399,7 +359,7 @@ fun ZMQ.Socket.receiveRawMessage(start: ByteArray, hmac: HMAC): RawMessage {
)
}

object DisabledStdinInputStream : java.io.InputStream() {
object DisabledStdinInputStream : InputStream() {
override fun read(): Int {
throw IOException("Input from stdin is unsupported by the client")
}
Expand Down
72 changes: 72 additions & 0 deletions src/main/kotlin/org/jetbrains/kotlinx/jupyter/execution.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package org.jetbrains.kotlinx.jupyter

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
import java.util.Collections
import java.util.concurrent.ConcurrentHashMap
import kotlin.concurrent.thread

sealed interface ExecutionResult<out T> {
class Success<out T>(val result: T) : ExecutionResult<T>
class Failure(val throwable: Throwable) : ExecutionResult<Nothing>
object Interrupted : ExecutionResult<Nothing>
}

interface JupyterExecutor {
fun <T> runExecution(classLoader: ClassLoader? = null, body: () -> T): ExecutionResult<T>
fun interruptExecutions()

fun launchJob(runnable: suspend CoroutineScope.() -> Unit)
}

class JupyterExecutorImpl : JupyterExecutor {
private val currentExecutions: MutableSet<Thread> = Collections.newSetFromMap(ConcurrentHashMap())
private val coroutineScope = CoroutineScope(Dispatchers.Default)

override fun <T> runExecution(classLoader: ClassLoader?, body: () -> T): ExecutionResult<T> {
var execRes: T? = null
var execException: Throwable? = null
val execThread = thread(contextClassLoader = classLoader ?: Thread.currentThread().contextClassLoader) {
try {
execRes = body()
} catch (e: Throwable) {
execException = e
}
}
currentExecutions.add(execThread)
execThread.join()
currentExecutions.remove(execThread)

val exception = execException

return if (exception == null) {
ExecutionResult.Success(execRes!!)
} else {
val isInterrupted = exception is ThreadDeath ||
(exception is ReplException && exception.cause is ThreadDeath)
if (isInterrupted) ExecutionResult.Interrupted
else ExecutionResult.Failure(exception)
}
}

/**
* We cannot use [Thread.interrupt] here because we have no way
* to control the code user executes. [Thread.interrupt] will do nothing for
* the simple calculation (like `while (true) 1`). Consider replacing with
* something more smart in the future.
*/
override fun interruptExecutions() {
@Suppress("deprecation")
while (currentExecutions.isNotEmpty()) {
val execution = currentExecutions.firstOrNull()
execution?.stop()
currentExecutions.remove(execution)
}
}

override fun launchJob(runnable: suspend CoroutineScope.() -> Unit) {
coroutineScope.launch(block = runnable)
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jetbrains.kotlinx.jupyter.messaging

import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import org.jetbrains.kotlinx.jupyter.api.libraries.Comm
import org.jetbrains.kotlinx.jupyter.api.libraries.CommCloseCallback
import org.jetbrains.kotlinx.jupyter.api.libraries.CommManager
Expand All @@ -11,9 +12,9 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArrayList

interface CommManagerInternal : CommManager {
fun processCommOpen(target: String, id: String, data: JsonObject): Comm
fun processCommMessage(id: String, data: JsonObject)
fun processCommClose(id: String, data: JsonObject)
fun processCommOpen(message: Message, content: CommOpen): Comm?
fun processCommMessage(message: Message, content: CommMsg)
fun processCommClose(message: Message, content: CommClose)
}

class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommManagerInternal {
Expand All @@ -25,26 +26,51 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM

override fun openComm(target: String, data: JsonObject): Comm {
val id = UUID.randomUUID().toString()
val newComm = processCommOpen(target, id, data)
val newComm = registerNewComm(target, id)

// send comm_open
iopub.sendSimpleMessage(
MessageType.COMM_OPEN,
CommOpen(newComm.id, newComm.target)
CommOpen(newComm.id, newComm.target, data)
)

return newComm
}

override fun processCommOpen(target: String, id: String, data: JsonObject): Comm {
override fun processCommOpen(message: Message, content: CommOpen): Comm? {
val target = content.targetName
val id = content.commId
val data = content.data

val callback = commOpenCallbacks[target]
if (callback == null) {
// If no callback is registered, we should send `comm_close` immediately in response.
iopub.sendSimpleMessage(
MessageType.COMM_CLOSE,
CommClose(id, commFailureJson("Target $target was not registered"))
)
return null
}

val newComm = registerNewComm(target, id)
try {
callback(newComm, data)
} catch (e: Throwable) {
iopub.sendSimpleMessage(
MessageType.COMM_CLOSE,
CommClose(id, commFailureJson("Unable to crete comm $id (with target $target), exception was thrown: ${e.stackTraceToString()}"))
)
removeComm(id)
}

return newComm
}

private fun registerNewComm(target: String, id: String): Comm {
val commIds = commTargetToIds.getOrPut(target) { CopyOnWriteArrayList() }
val newComm = CommImpl(target, id)
commIds.add(id)
commIdToComm[id] = newComm

val callback = commOpenCallbacks[target]
callback?.invoke(newComm, data)

return newComm
}

Expand All @@ -53,9 +79,9 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM
comm.close(data, notifyClient = true)
}

override fun processCommClose(id: String, data: JsonObject) {
val comm = commIdToComm[id] ?: return
comm.close(data, notifyClient = false)
override fun processCommClose(message: Message, content: CommClose) {
val comm = commIdToComm[content.commId] ?: return
comm.close(content.data, notifyClient = false)
}

fun removeComm(id: String) {
Expand All @@ -73,8 +99,8 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM
}
}

override fun processCommMessage(id: String, data: JsonObject) {
commIdToComm[id]?.messageReceived(data)
override fun processCommMessage(message: Message, content: CommMsg) {
commIdToComm[content.commId]?.messageReceived(message, content.data)
}

override fun registerCommTarget(target: String, callback: (Comm, JsonObject) -> Unit) {
Expand Down Expand Up @@ -144,11 +170,24 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM
}
}

fun messageReceived(data: JsonObject) {
fun messageReceived(message: Message, data: JsonObject) {
if (closed) return
for (callback in onMessageCallbacks) {
callback(data)

connection.doWrappedInBusyIdle(message) {
for (callback in onMessageCallbacks) {
callback(data)
}
}
}
}

companion object {
private fun commFailureJson(errorMessage: String): JsonObject {
return JsonObject(
mapOf(
"error" to JsonPrimitive(errorMessage)
)
)
}
}
}
Loading

0 comments on commit b2028d6

Please sign in to comment.