diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0d1689ae23..329103e7d0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,7 +39,6 @@ repos: - types-docutils - types-requests - types-paramiko - - types-pkg_resources - types-PyYAML - types-setuptools - types-psutil diff --git a/distributed/comm/registry.py b/distributed/comm/registry.py index 00b10336a7..47ba730a7d 100644 --- a/distributed/comm/registry.py +++ b/distributed/comm/registry.py @@ -1,6 +1,29 @@ from __future__ import annotations +import importlib.metadata +import sys from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import Protocol + + +class _EntryPoints(Protocol): + def __call__(self, **kwargs: str) -> Iterable[importlib.metadata.EntryPoint]: + ... + + +if sys.version_info >= (3, 10): + # py3.10 importlib.metadata type annotations are not in mypy yet + # https://github.com/python/typeshed/pull/7331 + _entry_points: _EntryPoints = importlib.metadata.entry_points # type: ignore[assignment] +else: + + def _entry_points( + *, group: str, name: str + ) -> Iterable[importlib.metadata.EntryPoint]: + for ep in importlib.metadata.entry_points().get(group, []): + if ep.name == name: + yield ep class Backend(ABC): @@ -59,40 +82,24 @@ def get_local_address_for(self, loc): backends: dict[str, Backend] = {} -def get_backend(scheme: str, require: bool = True) -> Backend: +def get_backend(scheme: str) -> Backend: """ Get the Backend instance for the given *scheme*. It looks for matching scheme in dask's internal cache, and falls-back to package metadata for the group name ``distributed.comm.backends`` - - Parameters - ---------- - - require : bool - Verify that the backends requirements are properly installed. See - https://setuptools.readthedocs.io/en/latest/pkg_resources.html for more - information. """ backend = backends.get(scheme) - if backend is None: - import pkg_resources - - backend = None - for backend_class_ep in pkg_resources.iter_entry_points( - "distributed.comm.backends", scheme - ): - # resolve and require are equivalent to load - backend_factory = backend_class_ep.resolve() - if require: - backend_class_ep.require() - backend = backend_factory() - - if backend is None: - raise ValueError( - "unknown address scheme %r (known schemes: %s)" - % (scheme, sorted(backends)) - ) - else: - backends[scheme] = backend - return backend + if backend is not None: + return backend + + for backend_class_ep in _entry_points( + name=scheme, group="distributed.comm.backends" + ): + backend = backend_class_ep.load()() + backends[scheme] = backend + return backend + + raise ValueError( + f"unknown address scheme {scheme!r} (known schemes: {sorted(backends)})" + ) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 46b5ab8d03..f9ecf5e072 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -2,18 +2,15 @@ import os import sys import threading -import types import warnings from functools import partial -import pkg_resources import pytest from tornado import ioloop from tornado.concurrent import Future import dask -import distributed from distributed.comm import ( CommClosedError, asyncio_tcp, @@ -30,7 +27,7 @@ from distributed.comm.registry import backends, get_backend from distributed.metrics import time from distributed.protocol import Serialized, deserialize, serialize, to_serialize -from distributed.utils import get_ip, get_ipv6 +from distributed.utils import get_ip, get_ipv6, mp_context from distributed.utils_test import ( get_cert, get_client_ssl_context, @@ -1313,30 +1310,18 @@ async def test_inproc_adresses(): await check_addresses(a, b) -def test_register_backend_entrypoint(): - # Code adapted from pandas backend entry point testing - # https://github.com/pandas-dev/pandas/blob/2470690b9f0826a8feb426927694fa3500c3e8d2/pandas/tests/plotting/test_backend.py#L50-L76 +def _get_backend_on_path(path): + sys.path.append(os.fsdecode(path)) + return get_backend("udp") - dist = pkg_resources.get_distribution("distributed") - if dist.module_path not in distributed.__file__: - # We are running from a non-installed distributed, and this test is invalid - pytest.skip("Testing a non-installed distributed") - mod = types.ModuleType("dask_udp") - mod.UDPBackend = lambda: 1 - sys.modules[mod.__name__] = mod - - entry_point_name = "distributed.comm.backends" - backends_entry_map = pkg_resources.get_entry_map("distributed") - if entry_point_name not in backends_entry_map: - backends_entry_map[entry_point_name] = dict() - backends_entry_map[entry_point_name]["udp"] = pkg_resources.EntryPoint( - "udp", mod.__name__, attrs=["UDPBackend"], dist=dist +def test_register_backend_entrypoint(tmp_path): + (tmp_path / "dask_udp.py").write_bytes(b"def udp_backend():\n return 1\n") + dist_info = tmp_path / "dask_udp-0.0.0.dist-info" + dist_info.mkdir() + (dist_info / "entry_points.txt").write_bytes( + b"[distributed.comm.backends]\nudp = dask_udp:udp_backend\n" ) - - # The require is disabled here since particularly unit tests may install - # dirty or dev versions which are conflicting with backend entrypoints if - # they are demanding for exact, stable versions. This should not fail the - # test - result = get_backend("udp", require=False) - assert result == 1 + with mp_context.Pool(1) as pool: + assert pool.apply(_get_backend_on_path, args=(tmp_path,)) == 1 + pool.join() diff --git a/distributed/utils.py b/distributed/utils.py index eecb673473..afe048c2ac 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -74,8 +74,6 @@ def _initialize_mp_context(): if method == "forkserver": # Makes the test suite much faster preload = ["distributed"] - if "pkg_resources" in sys.modules: - preload.append("pkg_resources") from distributed.versions import optional_packages, required_packages diff --git a/requirements.txt b/requirements.txt index 558b057097..4d535c0ad4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,3 @@ toolz >= 0.8.2 tornado >= 6.0.3 zict >= 0.1.3 pyyaml -setuptools