From b2028d632ac239d39f52f1ac7da872792b26fadb Mon Sep 17 00:00:00 2001 From: Ilya Muradyan Date: Thu, 23 Jun 2022 04:21:13 +0300 Subject: [PATCH] Refactor execution model --- .../org/jetbrains/kotlinx/jupyter/config.kt | 13 +- .../jetbrains/kotlinx/jupyter/connection.kt | 80 +++-------- .../jetbrains/kotlinx/jupyter/execution.kt | 72 ++++++++++ .../jupyter/messaging/CommManagerImpl.kt | 75 +++++++--- .../messaging/JupyterConnectionInternal.kt | 12 +- .../jupyter/messaging/message_types.kt | 4 +- .../kotlinx/jupyter/messaging/protocol.kt | 129 +++++++++--------- .../repl/creating/MockJupyterConnection.kt | 16 +++ .../kotlinx/jupyter/test/executeTests.kt | 14 +- 9 files changed, 260 insertions(+), 155 deletions(-) create mode 100644 src/main/kotlin/org/jetbrains/kotlinx/jupyter/execution.kt diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/config.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/config.kt index c6c2c24ca..b93c04944 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/config.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/config.kt @@ -13,6 +13,7 @@ import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.decodeFromJsonElement import kotlinx.serialization.serializer import org.jetbrains.kotlinx.jupyter.api.KotlinKernelVersion +import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterSocket import org.jetbrains.kotlinx.jupyter.common.getNameForUser import org.jetbrains.kotlinx.jupyter.config.getLogger import org.jetbrains.kotlinx.jupyter.config.readResourceAsIniFile @@ -36,12 +37,12 @@ val defaultRuntimeProperties by lazy { RuntimeKernelProperties(readResourceAsIniFile("runtime.properties")) } -enum class JupyterSocketInfo(val zmqKernelType: SocketType, val zmqClientType: SocketType) { - HB(SocketType.REP, SocketType.REQ), - SHELL(SocketType.ROUTER, SocketType.REQ), - CONTROL(SocketType.ROUTER, SocketType.REQ), - STDIN(SocketType.ROUTER, SocketType.REQ), - IOPUB(SocketType.PUB, SocketType.SUB); +enum class JupyterSocketInfo(val type: JupyterSocket, val zmqKernelType: SocketType, val zmqClientType: SocketType) { + HB(JupyterSocket.HB, SocketType.REP, SocketType.REQ), + SHELL(JupyterSocket.SHELL, SocketType.ROUTER, SocketType.REQ), + CONTROL(JupyterSocket.CONTROL, SocketType.ROUTER, SocketType.REQ), + STDIN(JupyterSocket.STDIN, SocketType.ROUTER, SocketType.REQ), + IOPUB(JupyterSocket.IOPUB, SocketType.PUB, SocketType.SUB); val nameForUser = getNameForUser(name) } diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt index a955ff294..6b1f4bf28 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt @@ -1,8 +1,5 @@ package org.jetbrains.kotlinx.jupyter -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.launch import kotlinx.serialization.decodeFromString import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json @@ -15,7 +12,6 @@ import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessage import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessageCallback import org.jetbrains.kotlinx.jupyter.api.libraries.header import org.jetbrains.kotlinx.jupyter.api.libraries.type -import org.jetbrains.kotlinx.jupyter.exceptions.ReplException import org.jetbrains.kotlinx.jupyter.messaging.InputReply import org.jetbrains.kotlinx.jupyter.messaging.InputRequest import org.jetbrains.kotlinx.jupyter.messaging.JupyterConnectionInternal @@ -29,6 +25,7 @@ import org.jetbrains.kotlinx.jupyter.messaging.emptyJsonObjectStringBytes import org.jetbrains.kotlinx.jupyter.messaging.jsonObject import org.jetbrains.kotlinx.jupyter.messaging.makeJsonHeader import org.jetbrains.kotlinx.jupyter.messaging.makeReplyMessage +import org.jetbrains.kotlinx.jupyter.messaging.makeSimpleMessage import org.jetbrains.kotlinx.jupyter.messaging.sendMessage import org.jetbrains.kotlinx.jupyter.messaging.toMessage import org.jetbrains.kotlinx.jupyter.messaging.toRawMessage @@ -37,10 +34,10 @@ import org.zeromq.SocketType import org.zeromq.ZMQ import java.io.Closeable import java.io.IOException +import java.io.InputStream import java.security.SignatureException import javax.crypto.Mac import javax.crypto.spec.SecretKeySpec -import kotlin.concurrent.thread import kotlin.math.min typealias SocketMessageCallback = JupyterConnectionImpl.Socket.(Message) -> Unit @@ -106,14 +103,8 @@ class JupyterConnectionImpl( } } - fun sendStatus(status: KernelStatus, msg: Message) { - connection.iopub.sendMessage(makeReplyMessage(msg, MessageType.STATUS, content = StatusReply(status))) - } - fun sendWrapped(incomingMessage: Message, msg: Message) { - sendStatus(KernelStatus.BUSY, incomingMessage) - sendMessage(msg) - sendStatus(KernelStatus.IDLE, incomingMessage) + doWrappedInBusyIdle(incomingMessage) { sendMessage(msg) } } override fun sendRawMessage(msg: RawMessage) { @@ -140,7 +131,7 @@ class JupyterConnectionImpl( override val connection: JupyterConnectionImpl = this@JupyterConnectionImpl } - inner class StdinInputStream : java.io.InputStream() { + inner class StdinInputStream : InputStream() { private var currentBuf: ByteArray? = null private var currentBufPos = 0 @@ -263,57 +254,26 @@ class JupyterConnectionImpl( send(socketName, message) } - val stdinIn = StdinInputStream() - - var contextMessage: Message? = null - - private val currentExecutions = HashSet() - private val coroutineScope = CoroutineScope(Dispatchers.Default) - - data class ConnectionExecutionResult( - val result: T?, - val throwable: Throwable?, - val isInterrupted: Boolean, - ) - - fun runExecution(body: () -> T, classLoader: ClassLoader): ConnectionExecutionResult { - var execRes: T? = null - var execException: Throwable? = null - val execThread = thread(contextClassLoader = classLoader) { - try { - execRes = body() - } catch (e: Throwable) { - execException = e - } - } - currentExecutions.add(execThread) - execThread.join() - currentExecutions.remove(execThread) - - val exception = execException - val isInterrupted = exception is ThreadDeath || - (exception is ReplException && exception.cause is ThreadDeath) - return ConnectionExecutionResult(execRes, exception, isInterrupted) + override fun sendStatus(status: KernelStatus, incomingMessage: Message?) { + val message = if (incomingMessage != null) makeReplyMessage(incomingMessage, MessageType.STATUS, content = StatusReply(status)) + else makeSimpleMessage(MessageType.STATUS, content = StatusReply(status)) + iopub.sendMessage(message) } - /** - * We cannot use [Thread.interrupt] here because we have no way - * to control the code user executes. [Thread.interrupt] will do nothing for - * the simple calculation (like `while (true) 1`). Consider replacing with - * something more smart in the future. - */ - fun interruptExecution() { - @Suppress("deprecation") - while (currentExecutions.isNotEmpty()) { - val execution = currentExecutions.firstOrNull() - execution?.stop() - currentExecutions.remove(execution) + override fun doWrappedInBusyIdle(incomingMessage: Message?, action: () -> Unit) { + sendStatus(KernelStatus.BUSY, incomingMessage) + try { + action() + } finally { + sendStatus(KernelStatus.IDLE, incomingMessage) } } - fun launchJob(runnable: suspend CoroutineScope.() -> Unit) { - coroutineScope.launch(block = runnable) - } + override val stdinIn = StdinInputStream() + + var contextMessage: Message? = null + + override val executor: JupyterExecutor = JupyterExecutorImpl() override fun close() { heartbeat.close() @@ -399,7 +359,7 @@ fun ZMQ.Socket.receiveRawMessage(start: ByteArray, hmac: HMAC): RawMessage { ) } -object DisabledStdinInputStream : java.io.InputStream() { +object DisabledStdinInputStream : InputStream() { override fun read(): Int { throw IOException("Input from stdin is unsupported by the client") } diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/execution.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/execution.kt new file mode 100644 index 000000000..a6f70dad2 --- /dev/null +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/execution.kt @@ -0,0 +1,72 @@ +package org.jetbrains.kotlinx.jupyter + +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import org.jetbrains.kotlinx.jupyter.exceptions.ReplException +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap +import kotlin.concurrent.thread + +sealed interface ExecutionResult { + class Success(val result: T) : ExecutionResult + class Failure(val throwable: Throwable) : ExecutionResult + object Interrupted : ExecutionResult +} + +interface JupyterExecutor { + fun runExecution(classLoader: ClassLoader? = null, body: () -> T): ExecutionResult + fun interruptExecutions() + + fun launchJob(runnable: suspend CoroutineScope.() -> Unit) +} + +class JupyterExecutorImpl : JupyterExecutor { + private val currentExecutions: MutableSet = Collections.newSetFromMap(ConcurrentHashMap()) + private val coroutineScope = CoroutineScope(Dispatchers.Default) + + override fun runExecution(classLoader: ClassLoader?, body: () -> T): ExecutionResult { + var execRes: T? = null + var execException: Throwable? = null + val execThread = thread(contextClassLoader = classLoader ?: Thread.currentThread().contextClassLoader) { + try { + execRes = body() + } catch (e: Throwable) { + execException = e + } + } + currentExecutions.add(execThread) + execThread.join() + currentExecutions.remove(execThread) + + val exception = execException + + return if (exception == null) { + ExecutionResult.Success(execRes!!) + } else { + val isInterrupted = exception is ThreadDeath || + (exception is ReplException && exception.cause is ThreadDeath) + if (isInterrupted) ExecutionResult.Interrupted + else ExecutionResult.Failure(exception) + } + } + + /** + * We cannot use [Thread.interrupt] here because we have no way + * to control the code user executes. [Thread.interrupt] will do nothing for + * the simple calculation (like `while (true) 1`). Consider replacing with + * something more smart in the future. + */ + override fun interruptExecutions() { + @Suppress("deprecation") + while (currentExecutions.isNotEmpty()) { + val execution = currentExecutions.firstOrNull() + execution?.stop() + currentExecutions.remove(execution) + } + } + + override fun launchJob(runnable: suspend CoroutineScope.() -> Unit) { + coroutineScope.launch(block = runnable) + } +} diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/CommManagerImpl.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/CommManagerImpl.kt index 354a74093..937604bfc 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/CommManagerImpl.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/CommManagerImpl.kt @@ -1,6 +1,7 @@ package org.jetbrains.kotlinx.jupyter.messaging import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive import org.jetbrains.kotlinx.jupyter.api.libraries.Comm import org.jetbrains.kotlinx.jupyter.api.libraries.CommCloseCallback import org.jetbrains.kotlinx.jupyter.api.libraries.CommManager @@ -11,9 +12,9 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.CopyOnWriteArrayList interface CommManagerInternal : CommManager { - fun processCommOpen(target: String, id: String, data: JsonObject): Comm - fun processCommMessage(id: String, data: JsonObject) - fun processCommClose(id: String, data: JsonObject) + fun processCommOpen(message: Message, content: CommOpen): Comm? + fun processCommMessage(message: Message, content: CommMsg) + fun processCommClose(message: Message, content: CommClose) } class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommManagerInternal { @@ -25,26 +26,51 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM override fun openComm(target: String, data: JsonObject): Comm { val id = UUID.randomUUID().toString() - val newComm = processCommOpen(target, id, data) + val newComm = registerNewComm(target, id) // send comm_open iopub.sendSimpleMessage( MessageType.COMM_OPEN, - CommOpen(newComm.id, newComm.target) + CommOpen(newComm.id, newComm.target, data) ) return newComm } - override fun processCommOpen(target: String, id: String, data: JsonObject): Comm { + override fun processCommOpen(message: Message, content: CommOpen): Comm? { + val target = content.targetName + val id = content.commId + val data = content.data + + val callback = commOpenCallbacks[target] + if (callback == null) { + // If no callback is registered, we should send `comm_close` immediately in response. + iopub.sendSimpleMessage( + MessageType.COMM_CLOSE, + CommClose(id, commFailureJson("Target $target was not registered")) + ) + return null + } + + val newComm = registerNewComm(target, id) + try { + callback(newComm, data) + } catch (e: Throwable) { + iopub.sendSimpleMessage( + MessageType.COMM_CLOSE, + CommClose(id, commFailureJson("Unable to crete comm $id (with target $target), exception was thrown: ${e.stackTraceToString()}")) + ) + removeComm(id) + } + + return newComm + } + + private fun registerNewComm(target: String, id: String): Comm { val commIds = commTargetToIds.getOrPut(target) { CopyOnWriteArrayList() } val newComm = CommImpl(target, id) commIds.add(id) commIdToComm[id] = newComm - - val callback = commOpenCallbacks[target] - callback?.invoke(newComm, data) - return newComm } @@ -53,9 +79,9 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM comm.close(data, notifyClient = true) } - override fun processCommClose(id: String, data: JsonObject) { - val comm = commIdToComm[id] ?: return - comm.close(data, notifyClient = false) + override fun processCommClose(message: Message, content: CommClose) { + val comm = commIdToComm[content.commId] ?: return + comm.close(content.data, notifyClient = false) } fun removeComm(id: String) { @@ -73,8 +99,8 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM } } - override fun processCommMessage(id: String, data: JsonObject) { - commIdToComm[id]?.messageReceived(data) + override fun processCommMessage(message: Message, content: CommMsg) { + commIdToComm[content.commId]?.messageReceived(message, content.data) } override fun registerCommTarget(target: String, callback: (Comm, JsonObject) -> Unit) { @@ -144,11 +170,24 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM } } - fun messageReceived(data: JsonObject) { + fun messageReceived(message: Message, data: JsonObject) { if (closed) return - for (callback in onMessageCallbacks) { - callback(data) + + connection.doWrappedInBusyIdle(message) { + for (callback in onMessageCallbacks) { + callback(data) + } } } } + + companion object { + private fun commFailureJson(errorMessage: String): JsonObject { + return JsonObject( + mapOf( + "error" to JsonPrimitive(errorMessage) + ) + ) + } + } } diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/JupyterConnectionInternal.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/JupyterConnectionInternal.kt index 2db83fdd3..89a471f96 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/JupyterConnectionInternal.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/JupyterConnectionInternal.kt @@ -1,7 +1,9 @@ package org.jetbrains.kotlinx.jupyter.messaging +import org.jetbrains.kotlinx.jupyter.JupyterExecutor import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterConnection import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessage +import java.io.InputStream interface JupyterConnectionInternal : JupyterConnection { val heartbeat: JupyterServerSocket @@ -13,6 +15,12 @@ interface JupyterConnectionInternal : JupyterConnection { val messageId: List val sessionId: String val username: String + + val executor: JupyterExecutor + val stdinIn: InputStream + + fun sendStatus(status: KernelStatus, incomingMessage: Message? = null) + fun doWrappedInBusyIdle(incomingMessage: Message? = null, action: () -> Unit) } fun JupyterConnectionInternal.makeDefaultHeader(msgType: MessageType): MessageHeader { @@ -39,8 +47,8 @@ fun JupyterServerSocket.sendMessage(msg: Message) { sendRawMessage(msg.toRawMessage()) } -fun JupyterServerSocket.sendOut(msg: Message, stream: JupyterOutType, text: String) { - sendMessage(makeReplyMessage(msg, header = makeHeader(MessageType.STREAM, msg), content = StreamResponse(stream.optionName(), text))) +fun JupyterConnectionInternal.sendOut(msg: Message, stream: JupyterOutType, text: String) { + iopub.sendMessage(makeReplyMessage(msg, header = makeHeader(MessageType.STREAM, msg), content = StreamResponse(stream.optionName(), text))) } fun JupyterServerSocket.sendSimpleMessage(msgType: MessageType, content: MessageContent) { diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/message_types.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/message_types.kt index a29f50f2f..601c44d1e 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/message_types.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/message_types.kt @@ -39,7 +39,7 @@ enum class MessageType(val contentClass: KClass) { EXECUTE_REQUEST(ExecuteRequest::class), EXECUTE_REPLY(ExecuteReply::class), EXECUTE_INPUT(ExecutionInputReply::class), - EXECUTE_RESULT(ExecutionResult::class), + EXECUTE_RESULT(ExecutionResultMessage::class), INSPECT_REQUEST(InspectRequest::class), INSPECT_REPLY(InspectReply::class), @@ -377,7 +377,7 @@ class ExecutionInputReply( ) : MessageContent() @Serializable -class ExecutionResult( +class ExecutionResultMessage( val data: JsonElement, val metadata: JsonElement, diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/protocol.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/protocol.kt index 897f71e12..84fab6964 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/protocol.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/protocol.kt @@ -7,6 +7,7 @@ import kotlinx.serialization.json.encodeToJsonElement import org.jetbrains.annotations.TestOnly import org.jetbrains.kotlinx.jupyter.DisabledStdinInputStream import org.jetbrains.kotlinx.jupyter.EvalRequestData +import org.jetbrains.kotlinx.jupyter.ExecutionResult import org.jetbrains.kotlinx.jupyter.JupyterConnectionImpl import org.jetbrains.kotlinx.jupyter.JupyterSocketInfo import org.jetbrains.kotlinx.jupyter.LoggingManagement.disableLogging @@ -61,10 +62,10 @@ abstract class Response( fun send(socket: JupyterConnectionImpl.Socket, requestCount: Long, requestMsg: Message, startedTime: String) { if (stdOut != null && stdOut.isNotEmpty()) { - socket.connection.iopub.sendOut(requestMsg, JupyterOutType.STDOUT, stdOut) + socket.connection.sendOut(requestMsg, JupyterOutType.STDOUT, stdOut) } if (stdErr != null && stdErr.isNotEmpty()) { - socket.connection.iopub.sendOut(requestMsg, JupyterOutType.STDERR, stdErr) + socket.connection.sendOut(requestMsg, JupyterOutType.STDERR, stdErr) } sendBody(socket, requestCount, requestMsg, startedTime) } @@ -86,7 +87,7 @@ class OkResponseWithMessage( makeReplyMessage( requestMsg, MessageType.EXECUTE_RESULT, - content = ExecutionResult( + content = ExecutionResultMessage( executionCount = requestCount, data = resultJson["data"]!!, metadata = resultJson["metadata"]!! @@ -219,7 +220,7 @@ class ErrorResponseWithMessage( fun JupyterConnectionImpl.Socket.controlMessagesHandler(msg: Message, repl: ReplForJupyter?) { when (msg.content) { is InterruptRequest -> { - connection.interruptExecution() + connection.executor.interruptExecutions() sendMessage(makeReplyMessage(msg, MessageType.INTERRUPT_REPLY, content = msg.content)) } is ShutdownRequest -> { @@ -292,34 +293,32 @@ fun JupyterConnectionImpl.Socket.shellMessagesHandler(msg: Message, repl: ReplFo } val startedTime = ISO8601DateNow - connection.iopub.sendStatus(KernelStatus.BUSY, msg) - - val code = content.code - connection.iopub.sendMessage( - makeReplyMessage( - msg, - MessageType.EXECUTE_INPUT, - content = ExecutionInputReply(code, count) + connection.doWrappedInBusyIdle(msg) { + val code = content.code + connection.iopub.sendMessage( + makeReplyMessage( + msg, + MessageType.EXECUTE_INPUT, + content = ExecutionInputReply(code, count) + ) ) - ) - val res: Response = if (looksLikeReplCommand(code)) { - runCommand(code, repl) - } else { - connection.evalWithIO(repl, msg) { - repl.eval( - EvalRequestData( - code, - count.toInt(), - content.storeHistory, - content.silent, + val res: Response = if (looksLikeReplCommand(code)) { + runCommand(code, repl) + } else { + connection.evalWithIO(repl, msg) { + repl.eval( + EvalRequestData( + code, + count.toInt(), + content.storeHistory, + content.silent, + ) ) - ) + } } - } - res.send(this, count, msg, startedTime) - - connection.iopub.sendStatus(KernelStatus.IDLE, msg) + res.send(this, count, msg, startedTime) + } connection.contextMessage = null } is CommInfoRequest -> { @@ -328,23 +327,29 @@ fun JupyterConnectionImpl.Socket.shellMessagesHandler(msg: Message, repl: ReplFo sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_INFO_REPLY, content = CommInfoReply(replyMap))) } is CommOpen -> { - commManager.processCommOpen(content.targetName, content.commId, content.data) + connection.executor.runExecution { + commManager.processCommOpen(msg, content) + } } is CommClose -> { - commManager.processCommClose(content.commId, content.data) + connection.executor.runExecution { + commManager.processCommClose(msg, content) + } } is CommMsg -> { - commManager.processCommMessage(content.commId, content.data) + connection.executor.runExecution { + commManager.processCommMessage(msg, content) + } } is CompleteRequest -> { - connection.launchJob { + connection.executor.launchJob { repl.complete(content.code, content.cursorPos) { result -> sendWrapped(msg, makeReplyMessage(msg, MessageType.COMPLETE_REPLY, content = result.message)) } } } is ListErrorsRequest -> { - connection.launchJob { + connection.executor.launchJob { repl.listErrors(content.code) { result -> sendWrapped(msg, makeReplyMessage(msg, MessageType.LIST_ERRORS_REPLY, content = result.message)) } @@ -455,7 +460,7 @@ fun Any?.toDisplayResult(notebook: Notebook): DisplayResult? = when (this) { else -> textResult(this.toString()) } -fun JupyterConnectionImpl.evalWithIO(repl: ReplForJupyter, srcMessage: Message, body: () -> EvalResult?): Response { +fun JupyterConnectionInternal.evalWithIO(repl: ReplForJupyter, incomingMessage: Message, body: () -> EvalResult?): Response { val config = repl.outputConfig val out = System.out val err = System.err @@ -469,7 +474,7 @@ fun JupyterConnectionImpl.evalWithIO(repl: ReplForJupyter, srcMessage: Message, captureOutput ) { text -> cell()?.appendStreamOutput(text) - this.iopub.sendOut(srcMessage, outType, text) + this.sendOut(incomingMessage, outType, text) } } @@ -493,45 +498,39 @@ fun JupyterConnectionImpl.evalWithIO(repl: ReplForJupyter, srcMessage: Message, System.setErr(printForkedErr) val `in` = System.`in` - val allowStdIn = (srcMessage.content as? ExecuteRequest)?.allowStdin ?: true + val allowStdIn = (incomingMessage.content as? ExecuteRequest)?.allowStdin ?: true System.setIn(if (allowStdIn) stdinIn else DisabledStdinInputStream) try { - return try { - val (exec, execException, executionInterrupted) = runExecution(body, repl.currentClassLoader) - when { - executionInterrupted -> { - flushStreams() - AbortResponseWithMessage("The execution was interrupted") - } - execException != null -> { - throw execException - } - exec == null -> { + return when (val res = executor.runExecution(repl.currentClassLoader, body)) { + is ExecutionResult.Success -> { + if (res.result == null) { AbortResponseWithMessage("NO REPL!") - } - else -> { - flushStreams() + } else { try { - exec.toResponse(repl.notebook) + res.result.toResponse(repl.notebook) } catch (e: Exception) { AbortResponseWithMessage("error: Unable to convert result to a string: $e") } } } - } catch (ex: ReplException) { - flushStreams() - - (ex as? ReplEvalRuntimeException)?.cause?.let { originalThrowable -> - repl.throwableRenderersProcessor.renderThrowable(originalThrowable) - }?.let { renderedThrowable -> - rawToResponse(renderedThrowable, repl.notebook) - } ?: ErrorResponseWithMessage( - ex.render(), - ex.javaClass.canonicalName, - ex.message ?: "", - ex.stackTrace.map { it.toString() }, - ex.getAdditionalInfoJson() ?: Json.EMPTY - ) + is ExecutionResult.Failure -> { + val ex = res.throwable + if (ex !is ReplException) throw ex + (ex as? ReplEvalRuntimeException)?.cause?.let { originalThrowable -> + repl.throwableRenderersProcessor.renderThrowable(originalThrowable) + }?.let { renderedThrowable -> + rawToResponse(renderedThrowable, repl.notebook) + } ?: ErrorResponseWithMessage( + ex.render(), + ex.javaClass.canonicalName, + ex.message ?: "", + ex.stackTrace.map { it.toString() }, + ex.getAdditionalInfoJson() ?: Json.EMPTY + ) + } + ExecutionResult.Interrupted -> { + AbortResponseWithMessage("The execution was interrupted") + } } } finally { flushStreams() diff --git a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/creating/MockJupyterConnection.kt b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/creating/MockJupyterConnection.kt index c89f4c616..9e9ee2114 100644 --- a/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/creating/MockJupyterConnection.kt +++ b/src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/creating/MockJupyterConnection.kt @@ -1,11 +1,15 @@ package org.jetbrains.kotlinx.jupyter.repl.creating import kotlinx.serialization.json.JsonObject +import org.jetbrains.kotlinx.jupyter.JupyterExecutor import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterSocket import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessage import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessageCallback import org.jetbrains.kotlinx.jupyter.messaging.JupyterConnectionInternal import org.jetbrains.kotlinx.jupyter.messaging.JupyterServerSocket +import org.jetbrains.kotlinx.jupyter.messaging.KernelStatus +import org.jetbrains.kotlinx.jupyter.messaging.Message +import java.io.InputStream object MockJupyterConnection : JupyterConnectionInternal { override val heartbeat: JupyterServerSocket @@ -24,6 +28,18 @@ object MockJupyterConnection : JupyterConnectionInternal { get() = throw NotImplementedError() override val username: String get() = throw NotImplementedError() + override val executor: JupyterExecutor + get() = throw NotImplementedError() + override val stdinIn: InputStream + get() = throw NotImplementedError() + + override fun sendStatus(status: KernelStatus, incomingMessage: Message?) { + throw NotImplementedError() + } + + override fun doWrappedInBusyIdle(incomingMessage: Message?, action: () -> Unit) { + throw NotImplementedError() + } override fun addMessageCallback(callback: RawMessageCallback): RawMessageCallback { throw NotImplementedError() diff --git a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/executeTests.kt b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/executeTests.kt index acded323a..7edb062d9 100644 --- a/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/executeTests.kt +++ b/src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/executeTests.kt @@ -21,7 +21,7 @@ import org.jetbrains.kotlinx.jupyter.messaging.CommMsg import org.jetbrains.kotlinx.jupyter.messaging.CommOpen import org.jetbrains.kotlinx.jupyter.messaging.ExecuteReply import org.jetbrains.kotlinx.jupyter.messaging.ExecuteRequest -import org.jetbrains.kotlinx.jupyter.messaging.ExecutionResult +import org.jetbrains.kotlinx.jupyter.messaging.ExecutionResultMessage import org.jetbrains.kotlinx.jupyter.messaging.InputReply import org.jetbrains.kotlinx.jupyter.messaging.IsCompleteReply import org.jetbrains.kotlinx.jupyter.messaging.IsCompleteRequest @@ -122,7 +122,7 @@ class ExecuteTests : KernelServerTestsBase() { var response: Any? = null if (hasResult) { msg = ioPub.receiveMessage() - val content = msg.content as ExecutionResult + val content = msg.content as ExecutionResultMessage assertEquals(MessageType.EXECUTE_RESULT, msg.type) response = content.data } @@ -446,10 +446,20 @@ class ExecuteTests : KernelServerTestsBase() { ) ) + iopub.receiveMessage().apply { + val c = content.shouldBeTypeOf() + c.status shouldBe KernelStatus.BUSY + } + iopub.receiveMessage().apply { val c = content.shouldBeTypeOf() c.commId shouldBe commId c.data["y"]!!.jsonPrimitive.content shouldBe "received: 4321" } + + iopub.receiveMessage().apply { + val c = content.shouldBeTypeOf() + c.status shouldBe KernelStatus.IDLE + } } }