Skip to content

Commit

Permalink
get_entry_point: add proper deprecation pathway for old entry points (
Browse files Browse the repository at this point in the history
#5206)

All entry points that ship with aiida-core were deprecated for v2.0
as they were changed to be properly prefixed with core.. To make sure
the old entry points would still be automatically loaded, the factories
were updated to automatically catch them, print a deprecation warning,
and load the new entry point instead.

However, the factory was not the correct place to put this logic, since
the `get_entry_point` method, which the factories call and is the lowest
function in the stack that actually retrieves the entry point, can also
be called directly, circumventing the deprecation mechanic added to the
factories. This would result in the deprecated entry points raising an
exception when being loaded, for example in the parameter types of the
command line that have support for specific entry points, such as the
`IdentifierParamType`.

The solution is to move the deprecation mechanic from the factories to
the lowest layer of `get_entry_point`.
  • Loading branch information
sphuber authored Nov 2, 2021
1 parent 2d6df12 commit 5f259bd
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 54 deletions.
46 changes: 46 additions & 0 deletions aiida/plugins/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import functools
import traceback
from typing import Any, List, Optional, Sequence, Set, Tuple
from warnings import warn

# importlib.metadata was introduced into the standard library in python 3.8,
# but was then updated in python 3.10 to use an improved API.
Expand All @@ -20,6 +21,7 @@
from importlib_metadata import entry_points as _eps

from aiida.common.exceptions import LoadingEntryPointError, MissingEntryPointError, MultipleEntryPointError
from aiida.common.warnings import AiidaDeprecationWarning

__all__ = ('load_entry_point', 'load_entry_point_from_string', 'parse_entry_point', 'get_entry_points')

Expand Down Expand Up @@ -72,6 +74,21 @@ class EntryPointFormat(enum.Enum):
'aiida.workflows': 'aiida.workflows',
}

DEPRECATED_ENTRY_POINTS_MAPPING = {
'aiida.calculations': ['arithmetic.add', 'templatereplacer'],
'aiida.data': [
'array', 'array.bands', 'array.kpoints', 'array.projection', 'array.trajectory', 'array.xy', 'base', 'bool',
'cif', 'code', 'dict', 'float', 'folder', 'int', 'list', 'numeric', 'orbital', 'remote', 'remote.stash',
'remote.stash.folder', 'singlefile', 'str', 'structure', 'upf'
],
'aiida.tools.dbimporters': ['cod', 'icsd', 'materialsproject', 'mpds', 'mpod', 'nninc', 'oqmd', 'pcod', 'tcod'],
'aiida.tools.data.orbitals': ['orbital', 'realhydrogen'],
'aiida.parsers': ['arithmetic.add', 'templatereplacer.doubler'],
'aiida.schedulers': ['direct', 'lsf', 'pbspro', 'sge', 'slurm', 'torque'],
'aiida.transports': ['local', 'ssh'],
'aiida.workflows': ['arithmetic.multiply_add', 'arithmetic.add_multiply'],
}


def parse_entry_point(group: str, spec: str) -> EntryPoint:
"""Return an entry point, given its group and spec (as formatted in the setup)"""
Expand Down Expand Up @@ -260,6 +277,8 @@ def get_entry_point(group: str, name: str) -> EntryPoint:
:raises aiida.common.MissingEntryPointError: entry point was not registered
"""
# The next line should be removed for ``aiida-core==3.0`` when the old deprecated entry points are fully removed.
name = convert_potentially_deprecated_entry_point(group, name)
found = eps().select(group=group, name=name)
if name not in found.names:
raise MissingEntryPointError(f"Entry point '{name}' not found in group '{group}'")
Expand All @@ -268,6 +287,33 @@ def get_entry_point(group: str, name: str) -> EntryPoint:
return found[name]


def convert_potentially_deprecated_entry_point(group: str, name: str) -> str:
"""Check whether the specified entry point is deprecated, in which case print warning and convert to new name.
For `aiida-core==2.0` all existing entry points where properly prefixed with ``core.`` and the old entry points were
deprecated. To provide a smooth transition these deprecated entry points are detected in ``get_entry_point``, which
is the lowest function that tries to resolve an entry point string, by calling this function.
If the entry point corresponds to a deprecated one, a warning is raised and the new corresponding entry point name
is returned.
This method should be removed in ``aiida-core==3.0``.
"""
try:
deprecated_entry_points = DEPRECATED_ENTRY_POINTS_MAPPING[group]
except KeyError:
return name
else:
if name in deprecated_entry_points:
warn(
f'The entry point `{name}` is deprecated. Please replace it with `core.{name}`.',
AiidaDeprecationWarning
)
name = f'core.{name}'

return name


@functools.lru_cache(maxsize=100)
def get_entry_point_from_class(class_module: str, class_name: str) -> Tuple[Optional[str], Optional[EntryPoint]]:
"""
Expand Down
46 changes: 1 addition & 45 deletions aiida/plugins/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# pylint: disable=invalid-name,cyclic-import
"""Definition of factories to load classes from the various plugin groups."""
from inspect import isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union

from importlib_metadata import EntryPoint

Expand All @@ -31,22 +31,6 @@
from aiida.transports import Transport


def warn_deprecated_entry_point(entry_point_name: str, deprecated_entry_points: List[str]) -> str:
"""If the ``entry_point_name`` is part of the list of ``deprecated_entry_points``, raise a warning."""
from warnings import warn

from aiida.common.warnings import AiidaDeprecationWarning

if entry_point_name in deprecated_entry_points:
warn(
f'The entry point `{entry_point_name}` is deprecated. Please replace it with `core.{entry_point_name}`.',
AiidaDeprecationWarning
)
entry_point_name = f'core.{entry_point_name}'

return entry_point_name


def raise_invalid_type_error(entry_point_name: str, entry_point_group: str, valid_classes: Tuple[Any, ...]) -> None:
"""Raise an `InvalidEntryPointTypeError` with formatted message.
Expand Down Expand Up @@ -90,9 +74,6 @@ def CalculationFactory(entry_point_name: str, load: bool = True) -> Optional[Uni
from aiida.engine import CalcJob, calcfunction, is_process_function
from aiida.orm import CalcFunctionNode

deprecated_entry_points = ['arithmetic.add', 'templatereplacer']
entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points)

entry_point_group = 'aiida.calculations'
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (CalcJob, calcfunction)
Expand Down Expand Up @@ -135,13 +116,6 @@ def DataFactory(entry_point_name: str, load: bool = True) -> Optional[Union[Entr
"""
from aiida.orm import Data

deprecated_entry_points = [
'array', 'array.bands', 'array.kpoints', 'array.projection', 'array.trajectory', 'array.xy', 'base', 'bool',
'cif', 'code', 'dict', 'float', 'folder', 'int', 'list', 'numeric', 'orbital', 'remote', 'remote.stash',
'remote.stash.folder', 'singlefile', 'str', 'structure', 'upf'
]
entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points)

entry_point_group = 'aiida.data'
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Data,)
Expand All @@ -162,9 +136,6 @@ def DbImporterFactory(entry_point_name: str, load: bool = True) -> Optional[Unio
"""
from aiida.tools.dbimporters import DbImporter

deprecated_entry_points = ['cod', 'icsd', 'materialsproject', 'mpds', 'mpod', 'nninc', 'oqmd', 'pcod', 'tcod']
entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points)

entry_point_group = 'aiida.tools.dbimporters'
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (DbImporter,)
Expand Down Expand Up @@ -205,9 +176,6 @@ def OrbitalFactory(entry_point_name: str, load: bool = True) -> Optional[Union[E
"""
from aiida.tools.data.orbital import Orbital

deprecated_entry_points = ['orbital', 'realhydrogen']
entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points)

entry_point_group = 'aiida.tools.data.orbitals'
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Orbital,)
Expand All @@ -228,9 +196,6 @@ def ParserFactory(entry_point_name: str, load: bool = True) -> Optional[Union[En
"""
from aiida.parsers import Parser

deprecated_entry_points = ['arithmetic.add', 'templatereplacer.doubler']
entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points)

entry_point_group = 'aiida.parsers'
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Parser,)
Expand All @@ -251,9 +216,6 @@ def SchedulerFactory(entry_point_name: str, load: bool = True) -> Optional[Union
"""
from aiida.schedulers import Scheduler

deprecated_entry_points = ['direct', 'lsf', 'pbspro', 'sge', 'slurm', 'torque']
entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points)

entry_point_group = 'aiida.schedulers'
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Scheduler,)
Expand All @@ -273,9 +235,6 @@ def TransportFactory(entry_point_name: str, load: bool = True) -> Optional[Union
"""
from aiida.transports import Transport

deprecated_entry_points = ['local', 'ssh']
entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points)

entry_point_group = 'aiida.transports'
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Transport,)
Expand All @@ -297,9 +256,6 @@ def WorkflowFactory(entry_point_name: str, load: bool = True) -> Optional[Union[
from aiida.engine import WorkChain, is_process_function, workfunction
from aiida.orm import WorkFunctionNode

deprecated_entry_points = ['arithmetic.multiply_add', 'arithmetic.add_multiply']
entry_point_name = warn_deprecated_entry_point(entry_point_name, deprecated_entry_points)

entry_point_group = 'aiida.workflows'
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (WorkChain, workfunction)
Expand Down
43 changes: 34 additions & 9 deletions tests/plugins/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,41 @@
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for the :py:mod:`~aiida.plugins.entry_point` module."""
"""Tests for the :mod:`~aiida.plugins.entry_point` module."""
import pytest

from aiida.backends.testbase import AiidaTestCase
from aiida.plugins.entry_point import validate_registered_entry_points
from aiida.common.warnings import AiidaDeprecationWarning
from aiida.plugins.entry_point import EntryPoint, get_entry_point, validate_registered_entry_points


class TestEntryPoint(AiidaTestCase):
"""Tests for the :py:mod:`~aiida.plugins.entry_point` module."""
def test_validate_registered_entry_points():
"""Test the ``validate_registered_entry_points`` function."""
validate_registered_entry_points()

@staticmethod
def test_validate_registered_entry_points():
"""Test the `validate_registered_entry_points` function."""
validate_registered_entry_points()

@pytest.mark.parametrize(
'group, name', (
('aiida.calculations', 'arithmetic.add'),
('aiida.data', 'array'),
('aiida.tools.dbimporters', 'cod'),
('aiida.tools.data.orbitals', 'orbital'),
('aiida.parsers', 'arithmetic.add'),
('aiida.schedulers', 'direct'),
('aiida.transports', 'local'),
('aiida.workflows', 'arithmetic.multiply_add'),
)
)
def test_get_entry_point_deprecated(group, name):
"""Test the ``get_entry_point`` method for a deprecated entry point.
The entry points in the parametrization were deprecated in ``aiida-core==2.0``. To provide a deprecation pathway,
the ``get_entry_point`` method was patched to go through the factories, which would automatically load the new entry
point and issue a deprecation warning. This is what we are testing here. This test can be removed once the
deprecated entry points are removed in ``aiida-core==3.0``.
"""
warning = f'The entry point `{name}` is deprecated. Please replace it with `core.{name}`.'

with pytest.warns(AiidaDeprecationWarning, match=warning):
entry_point = get_entry_point(group, name)

assert isinstance(entry_point, EntryPoint)

0 comments on commit 5f259bd

Please sign in to comment.