Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type checks to all plugin factories #3456

Merged
merged 1 commit into from
Oct 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aiida/backends/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@
'orm.utils.loaders': ['aiida.backends.tests.orm.utils.test_loaders'],
'orm.utils.repository': ['aiida.backends.tests.orm.utils.test_repository'],
'parsers.parser': ['aiida.backends.tests.parsers.test_parser'],
'plugin_loader': ['aiida.backends.tests.test_plugin_loader'],
'plugins.entry_point': ['aiida.backends.tests.plugins.test_entry_point'],
'plugins.factories': ['aiida.backends.tests.plugins.test_factories'],
'plugins.utils': ['aiida.backends.tests.plugins.test_utils'],
'query': ['aiida.backends.tests.test_query'],
'restapi': ['aiida.backends.tests.test_restapi'],
Expand Down
4 changes: 3 additions & 1 deletion aiida/backends/tests/engine/test_process_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,11 @@ def tearDown(self):
self.assertIsNone(Process.current())

def test_properties(self):
"""Test that the `is_process_function` attributes is set."""
"""Test that the `is_process_function` and `node_class` attributes are set."""
self.assertEqual(self.function_return_input.is_process_function, True)
self.assertEqual(self.function_return_input.node_class, orm.WorkFunctionNode)
self.assertEqual(self.function_return_true.is_process_function, True)
self.assertEqual(self.function_return_true.node_class, orm.CalcFunctionNode)

def test_plugin_version(self):
"""Test the version attributes of a process function."""
Expand Down
15 changes: 15 additions & 0 deletions aiida/backends/tests/plugins/test_entry_point.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
"""Tests for the :py:mod:`~aiida.plugins.entry_point` module."""
from __future__ import absolute_import

from aiida.backends.testbase import AiidaTestCase
from aiida.plugins.entry_point import validate_registered_entry_points


class TestEntryPoint(AiidaTestCase):
"""Tests for the :py:mod:`~aiida.plugins.entry_point` module."""

@staticmethod
def test_validate_registered_entry_points():
"""Test the `validate_registered_entry_points` function."""
validate_registered_entry_points()
161 changes: 161 additions & 0 deletions aiida/backends/tests/plugins/test_factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
"""Tests for the :py:mod:`~aiida.plugins.factories` module."""
from __future__ import absolute_import

try:
from unittest.mock import patch
except ImportError:
from mock import patch

from aiida.backends.testbase import AiidaTestCase
from aiida.common.exceptions import InvalidEntryPointTypeError
from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain
from aiida.orm import Data, Node, CalcFunctionNode, WorkFunctionNode
from aiida.parsers import Parser
from aiida.plugins import factories
from aiida.schedulers import Scheduler
from aiida.transports import Transport
from aiida.tools.data.orbital import Orbital
from aiida.tools.dbimporters import DbImporter


def custom_load_entry_point(group, name):
"""Function that mocks `aiida.plugins.entry_point.load_entry_point` that is called by factories."""

@calcfunction
def calc_function():
pass

@workfunction
def work_function():
pass

entry_points = {
'aiida.calculations': {
'calc_job': CalcJob,
'calc_function': calc_function,
'work_function': work_function,
'work_chain': WorkChain
},
'aiida.data': {
'valid': Data,
'invalid': Node,
},
'aiida.tools.dbimporters': {
'valid': DbImporter,
'invalid': Node,
},
'aiida.tools.data.orbitals': {
'valid': Orbital,
'invalid': Node,
},
'aiida.parsers': {
'valid': Parser,
'invalid': Node,
},
'aiida.schedulers': {
'valid': Scheduler,
'invalid': Node,
},
'aiida.transports': {
'valid': Transport,
'invalid': Node,
},
'aiida.workflows': {
'calc_job': CalcJob,
'calc_function': calc_function,
'work_function': work_function,
'work_chain': WorkChain
}
}
return entry_points[group][name]


class TestFactories(AiidaTestCase):
"""Tests for the :py:mod:`~aiida.plugins.factories` factory classes."""

@patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point)
def test_calculation_factory(self):
"""Test the `CalculationFactory`."""
plugin = factories.CalculationFactory('calc_function')
self.assertEqual(plugin.is_process_function, True)
self.assertEqual(plugin.node_class, CalcFunctionNode)

plugin = factories.CalculationFactory('calc_job')
self.assertEqual(plugin, CalcJob)

with self.assertRaises(InvalidEntryPointTypeError):
factories.CalculationFactory('work_function')

with self.assertRaises(InvalidEntryPointTypeError):
factories.CalculationFactory('work_chain')

@patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point)
def test_workflow_factory(self):
"""Test the `WorkflowFactory`."""
plugin = factories.WorkflowFactory('work_function')
self.assertEqual(plugin.is_process_function, True)
self.assertEqual(plugin.node_class, WorkFunctionNode)

plugin = factories.WorkflowFactory('work_chain')
self.assertEqual(plugin, WorkChain)

with self.assertRaises(InvalidEntryPointTypeError):
factories.WorkflowFactory('calc_function')

with self.assertRaises(InvalidEntryPointTypeError):
factories.WorkflowFactory('calc_job')

@patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point)
def test_data_factory(self):
"""Test the `DataFactory`."""
plugin = factories.DataFactory('valid')
self.assertEqual(plugin, Data)

with self.assertRaises(InvalidEntryPointTypeError):
factories.DataFactory('invalid')

@patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point)
def test_db_importer_factory(self):
"""Test the `DbImporterFactory`."""
plugin = factories.DbImporterFactory('valid')
self.assertEqual(plugin, DbImporter)

with self.assertRaises(InvalidEntryPointTypeError):
factories.DbImporterFactory('invalid')

@patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point)
def test_orbital_factory(self):
"""Test the `OrbitalFactory`."""
plugin = factories.OrbitalFactory('valid')
self.assertEqual(plugin, Orbital)

with self.assertRaises(InvalidEntryPointTypeError):
factories.OrbitalFactory('invalid')

@patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point)
def test_parser_factory(self):
"""Test the `ParserFactory`."""
plugin = factories.ParserFactory('valid')
self.assertEqual(plugin, Parser)

with self.assertRaises(InvalidEntryPointTypeError):
factories.ParserFactory('invalid')

@patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point)
def test_scheduler_factory(self):
"""Test the `SchedulerFactory`."""
plugin = factories.SchedulerFactory('valid')
self.assertEqual(plugin, Scheduler)

with self.assertRaises(InvalidEntryPointTypeError):
factories.SchedulerFactory('invalid')

@patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point)
def test_transport_factory(self):
"""Test the `TransportFactory`."""
plugin = factories.TransportFactory('valid')
self.assertEqual(plugin, Transport)

with self.assertRaises(InvalidEntryPointTypeError):
factories.TransportFactory('invalid')
104 changes: 0 additions & 104 deletions aiida/backends/tests/test_plugin_loader.py

This file was deleted.

15 changes: 15 additions & 0 deletions aiida/cmdline/commands/cmd_devel.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ def devel_run_daemon():
start_daemon()


@verdi_devel.command('validate-plugins')
@decorators.with_dbenv()
def devel_validate_plugins():
"""Validate all plugins by checking they can be loaded."""
from aiida.common.exceptions import EntryPointError
from aiida.plugins.entry_point import validate_registered_entry_points

try:
validate_registered_entry_points()
except EntryPointError as exception:
echo.echo_critical(str(exception))

echo.echo_success('all registered plugins could successfully loaded.')


@verdi_devel.command('tests')
@click.argument('paths', nargs=-1, type=TestModuleParamType(), required=False)
@options.VERBOSE(help='Print the class and function name for each test.')
Expand Down
13 changes: 9 additions & 4 deletions aiida/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
'AiidaException', 'NotExistent', 'MultipleObjectsError', 'RemoteOperationError', 'ContentNotExistent',
'FailedError', 'StoringNotAllowed', 'ModificationNotAllowed', 'IntegrityError', 'UniquenessError',
'EntryPointError', 'MissingEntryPointError', 'MultipleEntryPointError', 'LoadingEntryPointError',
'InvalidOperation', 'ParsingError', 'InternalError', 'PluginInternalError', 'ValidationError', 'ConfigurationError',
'ProfileConfigurationError', 'MissingConfigurationError', 'ConfigurationVersionError', 'DbContentError',
'InputValidationError', 'FeatureNotAvailable', 'FeatureDisabled', 'LicensingException', 'TestsNotAllowedError',
'UnsupportedSpeciesError', 'TransportTaskException', 'OutputParsingError'
'InvalidEntryPointTypeError', 'InvalidOperation', 'ParsingError', 'InternalError', 'PluginInternalError',
'ValidationError', 'ConfigurationError', 'ProfileConfigurationError', 'MissingConfigurationError',
'ConfigurationVersionError', 'DbContentError', 'InputValidationError', 'FeatureNotAvailable', 'FeatureDisabled',
'LicensingException', 'TestsNotAllowedError', 'UnsupportedSpeciesError', 'TransportTaskException',
'OutputParsingError'
)


Expand Down Expand Up @@ -109,6 +110,10 @@ class LoadingEntryPointError(EntryPointError):
"""Raised when the resource corresponding to requested entry point cannot be imported."""


class InvalidEntryPointTypeError(EntryPointError):
"""Raised when a loaded entry point has a type that is not supported by the corresponding entry point group."""


class InvalidOperation(AiidaException):
"""
The allowed operation is not valid (e.g., when trying to add a non-internal attribute
Expand Down
3 changes: 2 additions & 1 deletion aiida/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@

from .launch import *
from .processes import *
from .utils import *

__all__ = (launch.__all__ + processes.__all__)
__all__ = (launch.__all__ + processes.__all__ + utils.__all__)
1 change: 1 addition & 0 deletions aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def decorated_function(*args, **kwargs):
decorated_function.run_get_pk = run_get_pk
decorated_function.run_get_node = run_get_node
decorated_function.is_process_function = True
decorated_function.node_class = node_class

return decorated_function

Expand Down
2 changes: 1 addition & 1 deletion aiida/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import tornado.ioloop
from tornado import concurrent, gen

__all__ = ('interruptable_task', 'InterruptableFuture')
__all__ = ('interruptable_task', 'InterruptableFuture', 'is_process_function')

LOGGER = logging.getLogger(__name__)
PROCESS_STATE_CHANGE_KEY = 'process|state_change|{}'
Expand Down
Loading