Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factories: do not explicitly check type of entry point if load=False #5352

Merged
merged 2 commits into from
Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-code.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ jobs:

- name: Upgrade pip and setuptools
run: |
pip install --upgrade pip
pip install --upgrade pip setuptools
pip --version

- name: Build pymatgen with compatible numpy
Expand Down
6 changes: 6 additions & 0 deletions aiida/cmdline/commands/cmd_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def setup_code(ctx, non_interactive, **kwargs):
else:
kwargs['code_type'] = CodeBuilder.CodeType.STORE_AND_UPLOAD

# Convert entry point to its name
kwargs['input_plugin'] = kwargs['input_plugin'].name

code_builder = CodeBuilder(**kwargs)

try:
Expand Down Expand Up @@ -160,6 +163,9 @@ def code_duplicate(ctx, code, non_interactive, **kwargs):
if kwargs.pop('hide_original'):
code.hide()

# Convert entry point to its name
kwargs['input_plugin'] = kwargs['input_plugin'].name

code_builder = ctx.code_builder
for key, value in kwargs.items():
setattr(code_builder, key, value)
Expand Down
5 changes: 0 additions & 5 deletions aiida/orm/utils/builders/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import enum
import os

import importlib_metadata

from aiida.cmdline.utils.decorators import with_dbenv
from aiida.common.utils import ErrorAccumulator

Expand Down Expand Up @@ -155,9 +153,6 @@ def _set_code_attr(self, key, value):

Checks compatibility with other code attributes.
"""
if key == 'input_plugin' and isinstance(value, importlib_metadata.EntryPoint):
value = value.name

if key == 'description' and value is None:
value = ''

Expand Down
56 changes: 41 additions & 15 deletions aiida/plugins/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,11 @@ def CalculationFactory(entry_point_name: str, load: bool = True) -> Optional[Uni
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (CalcJob, calcfunction)

if (
isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, CalcJob)) or
(is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode)
):
if not load:
return entry_point

if ((isclass(entry_point) and issubclass(entry_point, CalcJob)) or
(is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode)): # type: ignore[union-attr]
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -100,6 +101,9 @@ def CalcJobImporterFactory(entry_point_name: str, load: bool = True) -> Optional
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (CalcJobImporter,)

if not load:
return entry_point

if isclass(entry_point) and issubclass(entry_point, CalcJobImporter):
return entry_point # type: ignore[return-value]

Expand All @@ -120,7 +124,10 @@ def DataFactory(entry_point_name: str, load: bool = True) -> Optional[Union[Entr
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Data,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Data)):
if not load:
return entry_point

if isclass(entry_point) and issubclass(entry_point, Data):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -140,7 +147,10 @@ def DbImporterFactory(entry_point_name: str, load: bool = True) -> Optional[Unio
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (DbImporter,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, DbImporter)):
if not load:
return entry_point

if isclass(entry_point) and issubclass(entry_point, DbImporter):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -160,7 +170,10 @@ def GroupFactory(entry_point_name: str, load: bool = True) -> Optional[Union[Ent
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Group,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Group)):
if not load:
return entry_point

if isclass(entry_point) and issubclass(entry_point, Group):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -180,7 +193,10 @@ def OrbitalFactory(entry_point_name: str, load: bool = True) -> Optional[Union[E
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Orbital,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Orbital)):
if not load:
return entry_point

if isclass(entry_point) and issubclass(entry_point, Orbital):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -200,7 +216,10 @@ def ParserFactory(entry_point_name: str, load: bool = True) -> Optional[Union[En
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Parser,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Parser)):
if not load:
return entry_point

if isclass(entry_point) and issubclass(entry_point, Parser):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -220,7 +239,10 @@ def SchedulerFactory(entry_point_name: str, load: bool = True) -> Optional[Union
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Scheduler,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Scheduler)):
if not load:
return entry_point

if isclass(entry_point) and issubclass(entry_point, Scheduler):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -239,7 +261,10 @@ def TransportFactory(entry_point_name: str, load: bool = True) -> Optional[Union
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (Transport,)

if isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, Transport)):
if not load:
return entry_point

if isclass(entry_point) and issubclass(entry_point, Transport):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -260,10 +285,11 @@ def WorkflowFactory(entry_point_name: str, load: bool = True) -> Optional[Union[
entry_point = BaseFactory(entry_point_group, entry_point_name, load=load)
valid_classes = (WorkChain, workfunction)

if (
isinstance(entry_point, EntryPoint) or (isclass(entry_point) and issubclass(entry_point, WorkChain)) or
(is_process_function(entry_point) and entry_point.node_class is WorkFunctionNode)
):
if not load:
return entry_point

if ((isclass(entry_point) and issubclass(entry_point, WorkChain)) or
(is_process_function(entry_point) and entry_point.node_class is WorkFunctionNode)): # type: ignore[union-attr]
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
6 changes: 2 additions & 4 deletions tests/plugins/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest

from aiida.common.warnings import AiidaDeprecationWarning
from aiida.plugins.entry_point import EntryPoint, get_entry_point, validate_registered_entry_points
from aiida.plugins.entry_point import get_entry_point, validate_registered_entry_points


def test_validate_registered_entry_points():
Expand Down Expand Up @@ -42,6 +42,4 @@ def test_get_entry_point_deprecated(group, name):
warning = f'The entry point `{name}` is deprecated. Please replace it with `core.{name}`.'

with pytest.warns(AiidaDeprecationWarning, match=warning):
entry_point = get_entry_point(group, name)

assert isinstance(entry_point, EntryPoint)
get_entry_point(group, name)