Skip to content

Commit

Permalink
Merge connection info into existing connection file if it already exi…
Browse files Browse the repository at this point in the history
…sts (#1133)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jasongrout and pre-commit-ci[bot] authored Jul 25, 2023
1 parent 58e9d15 commit 09c3c35
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 10 deletions.
26 changes: 17 additions & 9 deletions ipykernel/kernelapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
)
from IPython.core.profiledir import ProfileDir
from IPython.core.shellapp import InteractiveShellApp, shell_aliases, shell_flags
from jupyter_client import write_connection_file
from jupyter_client.connect import ConnectionFileMixin
from jupyter_client.session import Session, session_aliases, session_flags
from jupyter_core.paths import jupyter_runtime_dir
Expand All @@ -44,10 +43,11 @@
from traitlets.utils.importstring import import_item
from zmq.eventloop.zmqstream import ZMQStream

from .control import ControlThread
from .heartbeat import Heartbeat
from .connect import get_connection_info, write_connection_file

# local imports
from .control import ControlThread
from .heartbeat import Heartbeat
from .iostream import IOPubThread
from .ipkernel import IPythonKernel
from .parentpoller import ParentPollerUnix, ParentPollerWindows
Expand Down Expand Up @@ -260,12 +260,7 @@ def _bind_socket(self, s, port):
def write_connection_file(self):
"""write connection info to JSON file"""
cf = self.abs_connection_file
if os.path.exists(cf):
self.log.debug("Connection file %s already exists", cf)
return
self.log.debug("Writing connection file: %s", cf)
write_connection_file(
cf,
connection_info = dict(
ip=self.ip,
key=self.session.key,
transport=self.transport,
Expand All @@ -275,6 +270,19 @@ def write_connection_file(self):
iopub_port=self.iopub_port,
control_port=self.control_port,
)
if os.path.exists(cf):
# If the file exists, merge our info into it. For example, if the
# original file had port number 0, we update with the actual port
# used.
existing_connection_info = get_connection_info(cf, unpack=True)
connection_info = dict(existing_connection_info, **connection_info)
if connection_info == existing_connection_info:
self.log.debug("Connection file %s with current information already exists", cf)
return

self.log.debug("Writing connection file: %s", cf)

write_connection_file(cf, **connection_info)

def cleanup_connection_file(self):
"""Clean up our connection file."""
Expand Down
2 changes: 1 addition & 1 deletion ipykernel/tests/test_ipkernel_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class user_mod:
__dict__ = {}


async def test_properities(ipkernel: IPythonKernel) -> None:
async def test_properties(ipkernel: IPythonKernel) -> None:
ipkernel.user_module = user_mod()
ipkernel.user_ns = {}

Expand Down
71 changes: 71 additions & 0 deletions ipykernel/tests/test_kernelapp.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import json
import os
import threading
import time
from unittest.mock import patch

import pytest
from jupyter_core.paths import secure_write
from traitlets.config.loader import Config

from ipykernel.kernelapp import IPKernelApp

from .conftest import MockKernel
from .utils import TemporaryWorkingDirectory

try:
import trio
Expand Down Expand Up @@ -47,6 +51,73 @@ def trigger_stop():
app.close()


@pytest.mark.skipif(os.name == "nt", reason="permission errors on windows")
def test_merge_connection_file():
cfg = Config()
with TemporaryWorkingDirectory() as d:
cfg.ProfileDir.location = d
cf = os.path.join(d, "kernel.json")
initial_connection_info = {
"ip": "*",
"transport": "tcp",
"shell_port": 0,
"hb_port": 0,
"iopub_port": 0,
"stdin_port": 0,
"control_port": 53555,
"key": "abc123",
"signature_scheme": "hmac-sha256",
"kernel_name": "My Kernel",
}
# We cannot use connect.write_connection_file since
# it replaces port number 0 with a random port
# and we want IPKernelApp to do that replacement.
with secure_write(cf) as f:
json.dump(initial_connection_info, f)
assert os.path.exists(cf)

app = IPKernelApp(config=cfg, connection_file=cf)

# Calling app.initialize() does not work in the test, so we call the relevant functions that initialize() calls
# We must pass in an empty argv, otherwise the default is to try to parse the test runner's argv
super(IPKernelApp, app).initialize(argv=[""])
app.init_connection_file()
app.init_sockets()
app.init_heartbeat()
app.write_connection_file()

# Initialize should have merged the actual connection info
# with the connection info in the file
assert cf == app.abs_connection_file
assert os.path.exists(cf)

with open(cf) as f:
new_connection_info = json.load(f)

# ports originally set as 0 have been replaced
for port in ("shell", "hb", "iopub", "stdin"):
key = f"{port}_port"
# We initially had the port as 0
assert initial_connection_info[key] == 0
# the port is not 0 now
assert new_connection_info[key] > 0
# the port matches the port the kernel actually used
assert new_connection_info[key] == getattr(app, key), f"{key}"
del new_connection_info[key]
del initial_connection_info[key]

# The wildcard ip address was also replaced
assert new_connection_info["ip"] != "*"
del new_connection_info["ip"]
del initial_connection_info["ip"]

# everything else in the connection file is the same
assert initial_connection_info == new_connection_info

app.close()
os.remove(cf)


@pytest.mark.skipif(trio is None, reason="requires trio")
def test_trio_loop():
app = IPKernelApp(trio_loop=True)
Expand Down

0 comments on commit 09c3c35

Please sign in to comment.