Skip to content

Commit

Permalink
clean up types and remove use of entrypoints
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 committed Nov 9, 2022
1 parent 9b12e91 commit 89b4753
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 94 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ repos:
- id: mypy
exclude: tests
args: ["--config-file", "pyproject.toml"]
additional_dependencies: [pyzmq, tornado, types-paramiko]
additional_dependencies: [pyzmq, tornado, types-paramiko, traitlets, "jupyter_core>=5.0", ipykernel]
stages: [manual]

- repo: https://github.com/PyCQA/doc8
Expand Down
2 changes: 1 addition & 1 deletion jupyter_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __del__(self):
self.log.debug("Destroying zmq context for %s", self)
self.context.destroy()
try:
super_del = super().__del__
super_del = super().__del__ # type:ignore[misc]
except AttributeError:
pass
else:
Expand Down
16 changes: 9 additions & 7 deletions jupyter_client/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@

def write_connection_file(
fname: Optional[str] = None,
shell_port: Union[Integer, Int, int] = 0,
iopub_port: Union[Integer, Int, int] = 0,
stdin_port: Union[Integer, Int, int] = 0,
hb_port: Union[Integer, Int, int] = 0,
control_port: Union[Integer, Int, int] = 0,
shell_port: int = 0,
iopub_port: int = 0,
stdin_port: int = 0,
hb_port: int = 0,
control_port: int = 0,
ip: str = "",
key: bytes = b"",
transport: str = "tcp",
Expand Down Expand Up @@ -334,7 +334,7 @@ def tunnel_to_kernel(
class ConnectionFileMixin(LoggingConfigurable):
"""Mixin for configurable classes that work with connection files"""

data_dir = Unicode()
data_dir: Union[str, Unicode] = Unicode()

def _data_dir_default(self):
return jupyter_data_dir()
Expand All @@ -353,7 +353,9 @@ def _data_dir_default(self):
_connection_file_written = Bool(False)

transport = CaselessStrEnum(["tcp", "ipc"], default_value="tcp", config=True)
kernel_name = Unicode()
kernel_name: Union[str, Unicode] = Unicode()

context = Instance(zmq.Context)

ip = Unicode(
config=True,
Expand Down
30 changes: 16 additions & 14 deletions jupyter_client/consoleapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import os
import signal
import sys
import typing as t
import uuid
import warnings
from typing import cast

from jupyter_core.application import base_aliases
from jupyter_core.application import base_flags
Expand Down Expand Up @@ -94,7 +94,7 @@
# Classes
# -----------------------------------------------------------------------------

classes = [KernelManager, KernelRestarter, Session]
classes: t.List[t.Type[t.Any]] = [KernelManager, KernelRestarter, Session]


class JupyterConsoleApp(ConnectionFileMixin):
Expand Down Expand Up @@ -160,7 +160,7 @@ def build_kernel_argv(self, argv: object = None) -> None:
Override in subclasses if any args should be passed to the kernel
"""
self.kernel_argv = self.extra_args
self.kernel_argv = self.extra_args # type:ignore[attr-defined]

def init_connection_file(self) -> None:
"""find the connection file, and load the info if found.
Expand All @@ -177,31 +177,32 @@ def init_connection_file(self) -> None:
After this method is called, self.connection_file contains the *full path*
to the connection file, never just its name.
"""
runtime_dir = self.runtime_dir # type:ignore[attr-defined]
if self.existing:
try:
cf = find_connection_file(self.existing, [".", self.runtime_dir])
cf = find_connection_file(self.existing, [".", runtime_dir])
except Exception:
self.log.critical(
"Could not find existing kernel connection file %s", self.existing
)
self.exit(1)
self.exit(1) # type:ignore[attr-defined]
self.log.debug("Connecting to existing kernel: %s" % cf)
self.connection_file = cf
else:
# not existing, check if we are going to write the file
# and ensure that self.connection_file is a full path, not just the shortname
try:
cf = find_connection_file(self.connection_file, [self.runtime_dir])
cf = find_connection_file(self.connection_file, [runtime_dir])
except Exception:
# file might not exist
if self.connection_file == os.path.basename(self.connection_file):
# just shortname, put it in security dir
cf = os.path.join(self.runtime_dir, self.connection_file)
cf = os.path.join(runtime_dir, self.connection_file)
else:
cf = self.connection_file
self.connection_file = cf
try:
self.connection_file = _filefind(self.connection_file, [".", self.runtime_dir])
self.connection_file = _filefind(self.connection_file, [".", runtime_dir])
except OSError:
self.log.debug("Connection File not found: %s", self.connection_file)
return
Expand All @@ -217,7 +218,7 @@ def init_connection_file(self) -> None:
self.connection_file,
exc_info=True,
)
self.exit(1)
self.exit(1) # type:ignore[attr-defined]

def init_ssh(self) -> None:
"""set up ssh tunnels, if needed."""
Expand Down Expand Up @@ -256,7 +257,7 @@ def init_ssh(self) -> None:
except: # noqa
# even catch KeyboardInterrupt
self.log.error("Could not setup tunnels", exc_info=True)
self.exit(1)
self.exit(1) # type:ignore[attr-defined]

(
self.shell_port,
Expand All @@ -280,7 +281,8 @@ def _new_connection_file(self) -> str:
# 48b node segment (12 hex chars). Users running more than 32k simultaneous
# kernels can subclass.
ident = str(uuid.uuid4()).split("-")[-1]
cf = os.path.join(self.runtime_dir, "kernel-%s.json" % ident)
runtime_dir = self.runtime_dir # type:ignore[attr-defined]
cf = os.path.join(runtime_dir, "kernel-%s.json" % ident)
# only keep if it's actually new. Protect against unlikely collision
# in 48b random search space
cf = cf if not os.path.exists(cf) else ""
Expand Down Expand Up @@ -311,9 +313,9 @@ def init_kernel_manager(self) -> None:
)
except NoSuchKernel:
self.log.critical("Could not find kernel %s", self.kernel_name)
self.exit(1)
self.exit(1) # type:ignore[attr-defined]

self.kernel_manager = cast(KernelManager, self.kernel_manager)
self.kernel_manager = t.cast(KernelManager, self.kernel_manager)
self.kernel_manager.client_factory = self.kernel_client_class
kwargs = {}
kwargs["extra_arguments"] = self.kernel_argv
Expand Down Expand Up @@ -359,7 +361,7 @@ def initialize(self, argv: object = None) -> None:
Classes which mix this class in should call:
JupyterConsoleApp.initialize(self,argv)
"""
if self._dispatching:
if self._dispatching: # type:ignore[attr-defined]
return
self.init_connection_file()
self.init_ssh()
Expand Down
8 changes: 6 additions & 2 deletions jupyter_client/ioloop/manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""A kernel manager with a tornado IOLoop"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import typing as t

import zmq
from tornado import ioloop
from traitlets import Instance
Expand Down Expand Up @@ -48,7 +50,7 @@ def _loop_default(self):
),
config=True,
)
_restarter = Instance("jupyter_client.ioloop.IOLoopKernelRestarter", allow_none=True)
_restarter: t.Any = Instance("jupyter_client.ioloop.IOLoopKernelRestarter", allow_none=True)

def start_restarter(self):
if self.autorestart and self.has_kernel:
Expand Down Expand Up @@ -87,7 +89,9 @@ def _loop_default(self):
),
config=True,
)
_restarter = Instance("jupyter_client.ioloop.AsyncIOLoopKernelRestarter", allow_none=True)
_restarter: t.Any = Instance(
"jupyter_client.ioloop.AsyncIOLoopKernelRestarter", allow_none=True
)

def start_restarter(self):
if self.autorestart and self.has_kernel:
Expand Down
2 changes: 1 addition & 1 deletion jupyter_client/kernelspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def find_kernel_specs(self):

if self.ensure_native_kernel and NATIVE_KERNEL_NAME not in d:
try:
from ipykernel.kernelspec import RESOURCES # type: ignore
from ipykernel.kernelspec import RESOURCES

self.log.debug(
"Native kernel (%s) available from %s",
Expand Down
12 changes: 6 additions & 6 deletions jupyter_client/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _client_factory_default(self) -> Type:
def _client_class_changed(self, change: t.Dict[str, DottedObjectName]) -> None:
self.client_factory = import_item(str(change["new"]))

kernel_id: str = Unicode(None, allow_none=True)
kernel_id: t.Union[str, Unicode] = Unicode(None, allow_none=True)

# The kernel provisioner with which this KernelManager is communicating.
# This will generally be a LocalProvisioner instance unless the kernelspec
Expand Down Expand Up @@ -161,10 +161,10 @@ def _kernel_spec_manager_changed(self, change: t.Dict[str, Instance]) -> None:
"vary by provisioned environment.",
)

kernel_name: Unicode = Unicode(kernelspec.NATIVE_KERNEL_NAME)
kernel_name: t.Union[str, Unicode] = Unicode(kernelspec.NATIVE_KERNEL_NAME)

@observe("kernel_name") # type:ignore[misc]
def _kernel_name_changed(self, change: t.Dict[str, Unicode]) -> None:
def _kernel_name_changed(self, change: t.Dict[str, str]) -> None:
self._kernel_spec = None
if change["new"] == "python":
self.kernel_name = kernelspec.NATIVE_KERNEL_NAME
Expand Down Expand Up @@ -240,7 +240,7 @@ def remove_restart_callback(self, callback: t.Callable, event: str = "restart")

def client(self, **kwargs: t.Any) -> KernelClient:
"""Create a client configured to connect to our kernel"""
kw = {}
kw: dict = {}
kw.update(self.get_connection_info(session=True))
kw.update(
dict(
Expand Down Expand Up @@ -684,7 +684,7 @@ def start_new_kernel(
kc = km.client()
kc.start_channels()
try:
kc.wait_for_ready(timeout=startup_timeout)
kc.wait_for_ready(timeout=startup_timeout) # type:ignore[attr-defined]
except RuntimeError:
kc.stop_channels()
km.shutdown_kernel()
Expand All @@ -702,7 +702,7 @@ async def start_new_async_kernel(
kc = km.client()
kc.start_channels()
try:
await kc.wait_for_ready(timeout=startup_timeout)
await kc.wait_for_ready(timeout=startup_timeout) # type:ignore[attr-defined]
except RuntimeError:
kc.stop_channels()
await km.shutdown_kernel()
Expand Down
2 changes: 1 addition & 1 deletion jupyter_client/multikernelmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __del__(self):
self.log.debug("Destroying zmq context for %s", self)
self.context.destroy()
try:
super_del = super().__del__
super_del = super().__del__ # type:ignore[misc]
except AttributeError:
pass
else:
Expand Down
74 changes: 39 additions & 35 deletions jupyter_client/provisioning/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import glob
import sys
from os import getenv
from os import path
from typing import Any
from typing import Dict
from typing import List

from entrypoints import EntryPoint
from entrypoints import get_group_all
from entrypoints import get_single
from entrypoints import NoSuchEntryPoint

# See compatibility note on `group` keyword in https://docs.python.org/3/library/importlib.metadata.html#entry-points
if sys.version_info < (3, 10): # pragma: no cover
from importlib_metadata import entry_points, EntryPoint
else: # pragma: no cover
from importlib.metadata import entry_points, EntryPoint

from traitlets.config import default
from traitlets.config import SingletonConfigurable
from traitlets.config import Unicode
Expand Down Expand Up @@ -119,7 +123,7 @@ def _check_availability(self, provisioner_name: str) -> bool:
try:
ep = self._get_provisioner(provisioner_name)
self.provisioners[provisioner_name] = ep # Update cache
except NoSuchEntryPoint:
except Exception:
is_available = False
return is_available

Expand Down Expand Up @@ -161,41 +165,41 @@ def get_provisioner_entries(self) -> Dict[str, str]:
"""
entries = {}
for name, ep in self.provisioners.items():
entries[name] = f"{ep.module_name}:{ep.object_name}"
entries[name] = ep.value
return entries

@staticmethod
def _get_all_provisioners() -> List[EntryPoint]:
"""Wrapper around entrypoints.get_group_all() - primarily to facilitate testing."""
return get_group_all(KernelProvisionerFactory.GROUP_NAME)
return entry_points(group=KernelProvisionerFactory.GROUP_NAME)

def _get_provisioner(self, name: str) -> EntryPoint:
"""Wrapper around entrypoints.get_single() - primarily to facilitate testing."""
try:
ep = get_single(KernelProvisionerFactory.GROUP_NAME, name)
except NoSuchEntryPoint:
# Check if the entrypoint name is 'local-provisioner'. Although this should never
# happen, we have seen cases where the previous distribution of jupyter_client has
# remained which doesn't include kernel-provisioner entrypoints (so 'local-provisioner'
# is deemed not found even though its definition is in THIS package). In such cass,
# the entrypoints package uses what it first finds - which is the older distribution
# resulting in a violation of a supposed invariant condition. To address this scenario,
# we will log a warning message indicating this situation, then build the entrypoint
# instance ourselves - since we have that information.
if name == 'local-provisioner':
distros = glob.glob(f"{path.dirname(path.dirname(__file__))}-*")
self.log.warning(
f"Kernel Provisioning: The 'local-provisioner' is not found. This is likely "
f"due to the presence of multiple jupyter_client distributions and a previous "
f"distribution is being used as the source for entrypoints - which does not "
f"include 'local-provisioner'. That distribution should be removed such that "
f"only the version-appropriate distribution remains (version >= 7). Until "
f"then, a 'local-provisioner' entrypoint will be automatically constructed "
f"and used.\nThe candidate distribution locations are: {distros}"
)
ep = EntryPoint(
'local-provisioner', 'jupyter_client.provisioning', 'LocalProvisioner'
)
else:
raise
return ep
eps = entry_points(group=KernelProvisionerFactory.GROUP_NAME, name=name)
if eps:
return eps[0]

# Check if the entrypoint name is 'local-provisioner'. Although this should never
# happen, we have seen cases where the previous distribution of jupyter_client has
# remained which doesn't include kernel-provisioner entrypoints (so 'local-provisioner'
# is deemed not found even though its definition is in THIS package). In such cass,
# the entrypoints package uses what it first finds - which is the older distribution
# resulting in a violation of a supposed invariant condition. To address this scenario,
# we will log a warning message indicating this situation, then build the entrypoint
# instance ourselves - since we have that information.
if name == 'local-provisioner':
distros = glob.glob(f"{path.dirname(path.dirname(__file__))}-*")
self.log.warning(
f"Kernel Provisioning: The 'local-provisioner' is not found. This is likely "
f"due to the presence of multiple jupyter_client distributions and a previous "
f"distribution is being used as the source for entrypoints - which does not "
f"include 'local-provisioner'. That distribution should be removed such that "
f"only the version-appropriate distribution remains (version >= 7). Until "
f"then, a 'local-provisioner' entrypoint will be automatically constructed "
f"and used.\nThe candidate distribution locations are: {distros}"
)
return EntryPoint(
'local-provisioner', 'jupyter_client.provisioning', 'LocalProvisioner'
)

raise
Loading

0 comments on commit 89b4753

Please sign in to comment.