diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py index 97d8067010..8c14cb1cda 100644 --- a/aiida/plugins/entry_point.py +++ b/aiida/plugins/entry_point.py @@ -12,6 +12,7 @@ import functools import traceback from typing import Any, List, Optional, Sequence, Set, Tuple +from warnings import warn # importlib.metadata was introduced into the standard library in python 3.8, # but was then updated in python 3.10 to use an improved API. @@ -20,6 +21,7 @@ from importlib_metadata import entry_points as _eps from aiida.common.exceptions import LoadingEntryPointError, MissingEntryPointError, MultipleEntryPointError +from aiida.common.warnings import AiidaDeprecationWarning __all__ = ('load_entry_point', 'load_entry_point_from_string', 'parse_entry_point', 'get_entry_points') @@ -72,6 +74,21 @@ class EntryPointFormat(enum.Enum): 'aiida.workflows': 'aiida.workflows', } +DEPRECATED_ENTRY_POINTS_MAPPING = { + 'aiida.calculations': ['arithmetic.add', 'templatereplacer'], + 'aiida.data': [ + 'array', 'array.bands', 'array.kpoints', 'array.projection', 'array.trajectory', 'array.xy', 'base', 'bool', + 'cif', 'code', 'dict', 'float', 'folder', 'int', 'list', 'numeric', 'orbital', 'remote', 'remote.stash', + 'remote.stash.folder', 'singlefile', 'str', 'structure', 'upf' + ], + 'aiida.tools.dbimporters': ['cod', 'icsd', 'materialsproject', 'mpds', 'mpod', 'nninc', 'oqmd', 'pcod', 'tcod'], + 'aiida.tools.data.orbitals': ['orbital', 'realhydrogen'], + 'aiida.parsers': ['arithmetic.add', 'templatereplacer.doubler'], + 'aiida.schedulers': ['direct', 'lsf', 'pbspro', 'sge', 'slurm', 'torque'], + 'aiida.transports': ['local', 'ssh'], + 'aiida.workflows': ['arithmetic.multiply_add', 'arithmetic.add_multiply'], +} + def parse_entry_point(group: str, spec: str) -> EntryPoint: """Return an entry point, given its group and spec (as formatted in the setup)""" @@ -260,6 +277,8 @@ def get_entry_point(group: str, name: str) -> EntryPoint: :raises aiida.common.MissingEntryPointError: entry point was not registered """ + # The next line should be removed for ``aiida-core==3.0`` when the old deprecated entry points are fully removed. + name = convert_potentially_deprecated_entry_point(group, name) found = eps().select(group=group, name=name) if name not in found.names: raise MissingEntryPointError(f"Entry point '{name}' not found in group '{group}'") @@ -268,6 +287,33 @@ def get_entry_point(group: str, name: str) -> EntryPoint: return found[name] +def convert_potentially_deprecated_entry_point(group: str, name: str) -> str: + """Check whether the specified entry point is deprecated, in which case print warning and convert to new name. + + For `aiida-core==2.0` all existing entry points where properly prefixed with ``core.`` and the old entry points were + deprecated. To provide a smooth transition these deprecated entry points are detected in ``get_entry_point``, which + is the lowest function that tries to resolve an entry point string, by calling this function. + + If the entry point corresponds to a deprecated one, a warning is raised and the new corresponding entry point name + is returned. + + This method should be removed in ``aiida-core==3.0``. + """ + try: + deprecated_entry_points = DEPRECATED_ENTRY_POINTS_MAPPING[group] + except KeyError: + return name + else: + if name in deprecated_entry_points: + warn( + f'The entry point `{name}` is deprecated. Please replace it with `core.{name}`.', + AiidaDeprecationWarning + ) + name = f'core.{name}' + + return name + + @functools.lru_cache(maxsize=100) def get_entry_point_from_class(class_module: str, class_name: str) -> Tuple[Optional[str], Optional[EntryPoint]]: """ diff --git a/aiida/plugins/factories.py b/aiida/plugins/factories.py index 079ab29c19..d55ea3a17a 100644 --- a/aiida/plugins/factories.py +++ b/aiida/plugins/factories.py @@ -10,7 +10,7 @@ # pylint: disable=invalid-name,cyclic-import """Definition of factories to load classes from the various plugin groups.""" from inspect import isclass -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union from importlib_metadata import EntryPoint @@ -31,22 +31,6 @@ from aiida.transports import Transport -def warn_deprecated_entry_point(entry_point_name: str, deprecated_entry_points: List[str]) -> str: - """If the ``entry_point_name`` is part of the list of ``deprecated_entry_points``, raise a warning.""" - from warnings import warn - - from aiida.common.warnings import AiidaDeprecationWarning - - if entry_point_name in deprecated_entry_points: - warn( - f'The entry point `{entry_point_name}` is deprecated. Please replace it with `core.{entry_point_name}`.', - AiidaDeprecationWarning - ) - entry_point_name = f'core.{entry_point_name}' - - return entry_point_name - - def raise_invalid_type_error(entry_point_name: str, entry_point_group: str, valid_classes: Tuple[Any, ...]) -> None: """Raise an `InvalidEntryPointTypeError` with formatted message. @@ -90,9 +74,6 @@ def CalculationFactory(entry_point_name: str, load: bool = True) -> Optional[Uni from aiida.engine import CalcJob, calcfunction, is_process_function from aiida.orm import CalcFunctionNode - deprecated_entry_points = ['arithmetic.add', 'templatereplacer'] - entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points) - entry_point_group = 'aiida.calculations' entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (CalcJob, calcfunction) @@ -135,13 +116,6 @@ def DataFactory(entry_point_name: str, load: bool = True) -> Optional[Union[Entr """ from aiida.orm import Data - deprecated_entry_points = [ - 'array', 'array.bands', 'array.kpoints', 'array.projection', 'array.trajectory', 'array.xy', 'base', 'bool', - 'cif', 'code', 'dict', 'float', 'folder', 'int', 'list', 'numeric', 'orbital', 'remote', 'remote.stash', - 'remote.stash.folder', 'singlefile', 'str', 'structure', 'upf' - ] - entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points) - entry_point_group = 'aiida.data' entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Data,) @@ -162,9 +136,6 @@ def DbImporterFactory(entry_point_name: str, load: bool = True) -> Optional[Unio """ from aiida.tools.dbimporters import DbImporter - deprecated_entry_points = ['cod', 'icsd', 'materialsproject', 'mpds', 'mpod', 'nninc', 'oqmd', 'pcod', 'tcod'] - entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points) - entry_point_group = 'aiida.tools.dbimporters' entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (DbImporter,) @@ -205,9 +176,6 @@ def OrbitalFactory(entry_point_name: str, load: bool = True) -> Optional[Union[E """ from aiida.tools.data.orbital import Orbital - deprecated_entry_points = ['orbital', 'realhydrogen'] - entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points) - entry_point_group = 'aiida.tools.data.orbitals' entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Orbital,) @@ -228,9 +196,6 @@ def ParserFactory(entry_point_name: str, load: bool = True) -> Optional[Union[En """ from aiida.parsers import Parser - deprecated_entry_points = ['arithmetic.add', 'templatereplacer.doubler'] - entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points) - entry_point_group = 'aiida.parsers' entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Parser,) @@ -251,9 +216,6 @@ def SchedulerFactory(entry_point_name: str, load: bool = True) -> Optional[Union """ from aiida.schedulers import Scheduler - deprecated_entry_points = ['direct', 'lsf', 'pbspro', 'sge', 'slurm', 'torque'] - entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points) - entry_point_group = 'aiida.schedulers' entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Scheduler,) @@ -273,9 +235,6 @@ def TransportFactory(entry_point_name: str, load: bool = True) -> Optional[Union """ from aiida.transports import Transport - deprecated_entry_points = ['local', 'ssh'] - entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points) - entry_point_group = 'aiida.transports' entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (Transport,) @@ -297,9 +256,6 @@ def WorkflowFactory(entry_point_name: str, load: bool = True) -> Optional[Union[ from aiida.engine import WorkChain, is_process_function, workfunction from aiida.orm import WorkFunctionNode - deprecated_entry_points = ['arithmetic.multiply_add', 'arithmetic.add_multiply'] - entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points) - entry_point_group = 'aiida.workflows' entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) valid_classes = (WorkChain, workfunction) diff --git a/tests/plugins/test_entry_point.py b/tests/plugins/test_entry_point.py index ba177e10ba..cc7d2c463f 100644 --- a/tests/plugins/test_entry_point.py +++ b/tests/plugins/test_entry_point.py @@ -7,16 +7,41 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Tests for the :py:mod:`~aiida.plugins.entry_point` module.""" +"""Tests for the :mod:`~aiida.plugins.entry_point` module.""" +import pytest -from aiida.backends.testbase import AiidaTestCase -from aiida.plugins.entry_point import validate_registered_entry_points +from aiida.common.warnings import AiidaDeprecationWarning +from aiida.plugins.entry_point import EntryPoint, get_entry_point, validate_registered_entry_points -class TestEntryPoint(AiidaTestCase): - """Tests for the :py:mod:`~aiida.plugins.entry_point` module.""" +def test_validate_registered_entry_points(): + """Test the ``validate_registered_entry_points`` function.""" + validate_registered_entry_points() - @staticmethod - def test_validate_registered_entry_points(): - """Test the `validate_registered_entry_points` function.""" - validate_registered_entry_points() + +@pytest.mark.parametrize( + 'group, name', ( + ('aiida.calculations', 'arithmetic.add'), + ('aiida.data', 'array'), + ('aiida.tools.dbimporters', 'cod'), + ('aiida.tools.data.orbitals', 'orbital'), + ('aiida.parsers', 'arithmetic.add'), + ('aiida.schedulers', 'direct'), + ('aiida.transports', 'local'), + ('aiida.workflows', 'arithmetic.multiply_add'), + ) +) +def test_get_entry_point_deprecated(group, name): + """Test the ``get_entry_point`` method for a deprecated entry point. + + The entry points in the parametrization were deprecated in ``aiida-core==2.0``. To provide a deprecation pathway, + the ``get_entry_point`` method was patched to go through the factories, which would automatically load the new entry + point and issue a deprecation warning. This is what we are testing here. This test can be removed once the + deprecated entry points are removed in ``aiida-core==3.0``. + """ + warning = f'The entry point `{name}` is deprecated. Please replace it with `core.{name}`.' + + with pytest.warns(AiidaDeprecationWarning, match=warning): + entry_point = get_entry_point(group, name) + + assert isinstance(entry_point, EntryPoint)