diff --git a/aiida/backends/tests/__init__.py b/aiida/backends/tests/__init__.py index 5545bb0e52..73bf2a8c96 100644 --- a/aiida/backends/tests/__init__.py +++ b/aiida/backends/tests/__init__.py @@ -142,7 +142,8 @@ 'orm.utils.loaders': ['aiida.backends.tests.orm.utils.test_loaders'], 'orm.utils.repository': ['aiida.backends.tests.orm.utils.test_repository'], 'parsers.parser': ['aiida.backends.tests.parsers.test_parser'], - 'plugin_loader': ['aiida.backends.tests.test_plugin_loader'], + 'plugins.entry_point': ['aiida.backends.tests.plugins.test_entry_point'], + 'plugins.factories': ['aiida.backends.tests.plugins.test_factories'], 'plugins.utils': ['aiida.backends.tests.plugins.test_utils'], 'query': ['aiida.backends.tests.test_query'], 'restapi': ['aiida.backends.tests.test_restapi'], diff --git a/aiida/backends/tests/engine/test_process_function.py b/aiida/backends/tests/engine/test_process_function.py index 544e6a84ff..baf39cd914 100644 --- a/aiida/backends/tests/engine/test_process_function.py +++ b/aiida/backends/tests/engine/test_process_function.py @@ -122,9 +122,11 @@ def tearDown(self): self.assertIsNone(Process.current()) def test_properties(self): - """Test that the `is_process_function` attributes is set.""" + """Test that the `is_process_function` and `node_class` attributes are set.""" self.assertEqual(self.function_return_input.is_process_function, True) + self.assertEqual(self.function_return_input.node_class, orm.WorkFunctionNode) self.assertEqual(self.function_return_true.is_process_function, True) + self.assertEqual(self.function_return_true.node_class, orm.CalcFunctionNode) def test_plugin_version(self): """Test the version attributes of a process function.""" diff --git a/aiida/backends/tests/plugins/test_entry_point.py b/aiida/backends/tests/plugins/test_entry_point.py new file mode 100644 index 0000000000..9668dbd07c --- /dev/null +++ b/aiida/backends/tests/plugins/test_entry_point.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +"""Tests for the :py:mod:`~aiida.plugins.entry_point` module.""" +from __future__ import absolute_import + +from aiida.backends.testbase import AiidaTestCase +from aiida.plugins.entry_point import validate_registered_entry_points + + +class TestEntryPoint(AiidaTestCase): + """Tests for the :py:mod:`~aiida.plugins.entry_point` module.""" + + @staticmethod + def test_validate_registered_entry_points(): + """Test the `validate_registered_entry_points` function.""" + validate_registered_entry_points() diff --git a/aiida/backends/tests/plugins/test_factories.py b/aiida/backends/tests/plugins/test_factories.py new file mode 100644 index 0000000000..15f25f0836 --- /dev/null +++ b/aiida/backends/tests/plugins/test_factories.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +"""Tests for the :py:mod:`~aiida.plugins.factories` module.""" +from __future__ import absolute_import + +try: + from unittest.mock import patch +except ImportError: + from mock import patch + +from aiida.backends.testbase import AiidaTestCase +from aiida.common.exceptions import InvalidEntryPointTypeError +from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain +from aiida.orm import Data, Node, CalcFunctionNode, WorkFunctionNode +from aiida.parsers import Parser +from aiida.plugins import factories +from aiida.schedulers import Scheduler +from aiida.transports import Transport +from aiida.tools.data.orbital import Orbital +from aiida.tools.dbimporters import DbImporter + + +def custom_load_entry_point(group, name): + """Function that mocks `aiida.plugins.entry_point.load_entry_point` that is called by factories.""" + + @calcfunction + def calc_function(): + pass + + @workfunction + def work_function(): + pass + + entry_points = { + 'aiida.calculations': { + 'calc_job': CalcJob, + 'calc_function': calc_function, + 'work_function': work_function, + 'work_chain': WorkChain + }, + 'aiida.data': { + 'valid': Data, + 'invalid': Node, + }, + 'aiida.tools.dbimporters': { + 'valid': DbImporter, + 'invalid': Node, + }, + 'aiida.tools.data.orbitals': { + 'valid': Orbital, + 'invalid': Node, + }, + 'aiida.parsers': { + 'valid': Parser, + 'invalid': Node, + }, + 'aiida.schedulers': { + 'valid': Scheduler, + 'invalid': Node, + }, + 'aiida.transports': { + 'valid': Transport, + 'invalid': Node, + }, + 'aiida.workflows': { + 'calc_job': CalcJob, + 'calc_function': calc_function, + 'work_function': work_function, + 'work_chain': WorkChain + } + } + return entry_points[group][name] + + +class TestFactories(AiidaTestCase): + """Tests for the :py:mod:`~aiida.plugins.factories` factory classes.""" + + @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + def test_calculation_factory(self): + """Test the `CalculationFactory`.""" + plugin = factories.CalculationFactory('calc_function') + self.assertEqual(plugin.is_process_function, True) + self.assertEqual(plugin.node_class, CalcFunctionNode) + + plugin = factories.CalculationFactory('calc_job') + self.assertEqual(plugin, CalcJob) + + with self.assertRaises(InvalidEntryPointTypeError): + factories.CalculationFactory('work_function') + + with self.assertRaises(InvalidEntryPointTypeError): + factories.CalculationFactory('work_chain') + + @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + def test_workflow_factory(self): + """Test the `WorkflowFactory`.""" + plugin = factories.WorkflowFactory('work_function') + self.assertEqual(plugin.is_process_function, True) + self.assertEqual(plugin.node_class, WorkFunctionNode) + + plugin = factories.WorkflowFactory('work_chain') + self.assertEqual(plugin, WorkChain) + + with self.assertRaises(InvalidEntryPointTypeError): + factories.WorkflowFactory('calc_function') + + with self.assertRaises(InvalidEntryPointTypeError): + factories.WorkflowFactory('calc_job') + + @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + def test_data_factory(self): + """Test the `DataFactory`.""" + plugin = factories.DataFactory('valid') + self.assertEqual(plugin, Data) + + with self.assertRaises(InvalidEntryPointTypeError): + factories.DataFactory('invalid') + + @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + def test_db_importer_factory(self): + """Test the `DbImporterFactory`.""" + plugin = factories.DbImporterFactory('valid') + self.assertEqual(plugin, DbImporter) + + with self.assertRaises(InvalidEntryPointTypeError): + factories.DbImporterFactory('invalid') + + @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + def test_orbital_factory(self): + """Test the `OrbitalFactory`.""" + plugin = factories.OrbitalFactory('valid') + self.assertEqual(plugin, Orbital) + + with self.assertRaises(InvalidEntryPointTypeError): + factories.OrbitalFactory('invalid') + + @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + def test_parser_factory(self): + """Test the `ParserFactory`.""" + plugin = factories.ParserFactory('valid') + self.assertEqual(plugin, Parser) + + with self.assertRaises(InvalidEntryPointTypeError): + factories.ParserFactory('invalid') + + @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + def test_scheduler_factory(self): + """Test the `SchedulerFactory`.""" + plugin = factories.SchedulerFactory('valid') + self.assertEqual(plugin, Scheduler) + + with self.assertRaises(InvalidEntryPointTypeError): + factories.SchedulerFactory('invalid') + + @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + def test_transport_factory(self): + """Test the `TransportFactory`.""" + plugin = factories.TransportFactory('valid') + self.assertEqual(plugin, Transport) + + with self.assertRaises(InvalidEntryPointTypeError): + factories.TransportFactory('invalid') diff --git a/aiida/backends/tests/test_plugin_loader.py b/aiida/backends/tests/test_plugin_loader.py deleted file mode 100644 index ea9bdad666..0000000000 --- a/aiida/backends/tests/test_plugin_loader.py +++ /dev/null @@ -1,104 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -from __future__ import division -from __future__ import print_function -from __future__ import absolute_import - -from aiida.backends.testbase import AiidaTestCase -from aiida.engine import CalcJob, WorkChain -from aiida.orm import Data -from aiida.parsers import Parser -from aiida.plugins import factories -from aiida.plugins.entry_point import get_entry_points -from aiida.schedulers import Scheduler -from aiida.transports import Transport -from aiida.tools.dbimporters import DbImporter - - -class TestExistingPlugins(AiidaTestCase): - """ - Test the get_entry_points function and the plugin Factories. - - Will fail when: - - * If get_entry_points returns something other than a list - * Any of the plugins, distributed with aiida or installed - from external plugin repositories, fail to load - """ - - def test_existing_calculations(self): - """Test listing all preinstalled calculations.""" - entry_points = get_entry_points('aiida.calculations') - self.assertIsInstance(entry_points, list) - - for entry_point in entry_points: - cls = factories.CalculationFactory(entry_point.name) - self.assertTrue(issubclass(cls, CalcJob), - 'Calculation plugin class {} is not subclass of {}'.format(cls, CalcJob)) - - def test_existing_data(self): - """Test listing all preinstalled data classes.""" - entry_points = get_entry_points('aiida.data') - self.assertIsInstance(entry_points, list) - - for entry_point in entry_points: - cls = factories.DataFactory(entry_point.name) - self.assertTrue(issubclass(cls, Data), - 'Data plugin class {} is not subclass of {}'.format(cls, Data)) - - def test_existing_parsers(self): - """Test listing all preinstalled parsers.""" - entry_points = get_entry_points('aiida.parsers') - self.assertIsInstance(entry_points, list) - - for entry_point in entry_points: - cls = factories.ParserFactory(entry_point.name) - self.assertTrue(issubclass(cls, Parser), - 'Parser plugin class {} is not subclass of {}'.format(cls, Parser)) - - def test_existing_schedulers(self): - """Test listing all preinstalled schedulers.""" - entry_points = get_entry_points('aiida.schedulers') - self.assertIsInstance(entry_points, list) - - for entry_point in entry_points: - cls = factories.SchedulerFactory(entry_point.name) - self.assertTrue(issubclass(cls, Scheduler), - 'Scheduler plugin class {} is not subclass of {}'.format(cls, Scheduler)) - - def test_existing_transports(self): - """Test listing all preinstalled transports.""" - entry_points = get_entry_points('aiida.transports') - self.assertIsInstance(entry_points, list) - - for entry_point in entry_points: - cls = factories.TransportFactory(entry_point.name) - self.assertTrue(issubclass(cls, Transport), - 'Transport plugin class {} is not subclass of {}'.format(cls, Transport)) - - def test_existing_workflows(self): - """Test listing all preinstalled workflows.""" - entry_points = get_entry_points('aiida.workflows') - self.assertIsInstance(entry_points, list) - - for entry_point in entry_points: - cls = factories.WorkflowFactory(entry_point.name) - self.assertTrue(issubclass(cls, WorkChain), - 'Workflow plugin class {} is not a subclass of {}'.format(cls, WorkChain)) - - def test_existing_dbimporters(self): - """Test listing all preinstalled dbimporter plugins.""" - entry_points = get_entry_points('aiida.tools.dbimporters') - self.assertIsInstance(entry_points, list) - - for entry_point in entry_points: - cls = factories.DbImporterFactory(entry_point.name) - self.assertTrue(issubclass(cls, DbImporter), - 'DbImporter plugin class {} is not subclass of {}'.format(cls, DbImporter)) diff --git a/aiida/cmdline/commands/cmd_devel.py b/aiida/cmdline/commands/cmd_devel.py index 0bb9c46f00..bda51c68ed 100644 --- a/aiida/cmdline/commands/cmd_devel.py +++ b/aiida/cmdline/commands/cmd_devel.py @@ -92,6 +92,21 @@ def devel_run_daemon(): start_daemon() +@verdi_devel.command('validate-plugins') +@decorators.with_dbenv() +def devel_validate_plugins(): + """Validate all plugins by checking they can be loaded.""" + from aiida.common.exceptions import EntryPointError + from aiida.plugins.entry_point import validate_registered_entry_points + + try: + validate_registered_entry_points() + except EntryPointError as exception: + echo.echo_critical(str(exception)) + + echo.echo_success('all registered plugins could successfully loaded.') + + @verdi_devel.command('tests') @click.argument('paths', nargs=-1, type=TestModuleParamType(), required=False) @options.VERBOSE(help='Print the class and function name for each test.') diff --git a/aiida/common/exceptions.py b/aiida/common/exceptions.py index cdd893aea2..6e6074e658 100644 --- a/aiida/common/exceptions.py +++ b/aiida/common/exceptions.py @@ -16,10 +16,11 @@ 'AiidaException', 'NotExistent', 'MultipleObjectsError', 'RemoteOperationError', 'ContentNotExistent', 'FailedError', 'StoringNotAllowed', 'ModificationNotAllowed', 'IntegrityError', 'UniquenessError', 'EntryPointError', 'MissingEntryPointError', 'MultipleEntryPointError', 'LoadingEntryPointError', - 'InvalidOperation', 'ParsingError', 'InternalError', 'PluginInternalError', 'ValidationError', 'ConfigurationError', - 'ProfileConfigurationError', 'MissingConfigurationError', 'ConfigurationVersionError', 'DbContentError', - 'InputValidationError', 'FeatureNotAvailable', 'FeatureDisabled', 'LicensingException', 'TestsNotAllowedError', - 'UnsupportedSpeciesError', 'TransportTaskException', 'OutputParsingError' + 'InvalidEntryPointTypeError', 'InvalidOperation', 'ParsingError', 'InternalError', 'PluginInternalError', + 'ValidationError', 'ConfigurationError', 'ProfileConfigurationError', 'MissingConfigurationError', + 'ConfigurationVersionError', 'DbContentError', 'InputValidationError', 'FeatureNotAvailable', 'FeatureDisabled', + 'LicensingException', 'TestsNotAllowedError', 'UnsupportedSpeciesError', 'TransportTaskException', + 'OutputParsingError' ) @@ -109,6 +110,10 @@ class LoadingEntryPointError(EntryPointError): """Raised when the resource corresponding to requested entry point cannot be imported.""" +class InvalidEntryPointTypeError(EntryPointError): + """Raised when a loaded entry point has a type that is not supported by the corresponding entry point group.""" + + class InvalidOperation(AiidaException): """ The allowed operation is not valid (e.g., when trying to add a non-internal attribute diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py index 1d5d05e2ee..0460084da0 100644 --- a/aiida/engine/__init__.py +++ b/aiida/engine/__init__.py @@ -15,5 +15,6 @@ from .launch import * from .processes import * +from .utils import * -__all__ = (launch.__all__ + processes.__all__) +__all__ = (launch.__all__ + processes.__all__ + utils.__all__) diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 8f5dd2c2d7..2e8c7fccaf 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -196,6 +196,7 @@ def decorated_function(*args, **kwargs): decorated_function.run_get_pk = run_get_pk decorated_function.run_get_node = run_get_node decorated_function.is_process_function = True + decorated_function.node_class = node_class return decorated_function diff --git a/aiida/engine/utils.py b/aiida/engine/utils.py index 3c2286ffa9..311b3f47ee 100644 --- a/aiida/engine/utils.py +++ b/aiida/engine/utils.py @@ -20,7 +20,7 @@ import tornado.ioloop from tornado import concurrent, gen -__all__ = ('interruptable_task', 'InterruptableFuture') +__all__ = ('interruptable_task', 'InterruptableFuture', 'is_process_function') LOGGER = logging.getLogger(__name__) PROCESS_STATE_CHANGE_KEY = 'process|state_change|{}' diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py index df9f24523b..667e6a761f 100644 --- a/aiida/plugins/entry_point.py +++ b/aiida/plugins/entry_point.py @@ -68,6 +68,34 @@ class EntryPointFormat(enum.Enum): } +def validate_registered_entry_points(): + """Validate all registered entry points by loading them with the corresponding factory. + + :raises EntryPointError: if any of the registered entry points cannot be loaded. This can happen if: + * The entry point cannot uniquely be resolved + * The resource registered at the entry point cannot be imported + * The resource's type is incompatible with the entry point group that it is defined in. + + """ + from . import factories + + factory_mapping = { + 'aiida.calculations': factories.CalculationFactory, + 'aiida.data': factories.DataFactory, + 'aiida.parsers': factories.ParserFactory, + 'aiida.schedulers': factories.SchedulerFactory, + 'aiida.transports': factories.TransportFactory, + 'aiida.tools.dbimporters': factories.DbImporterFactory, + 'aiida.tools.data.orbital': factories.OrbitalFactory, + 'aiida.workflows': factories.WorkflowFactory, + } + + for entry_point_group, factory in factory_mapping.items(): + entry_points = get_entry_points(entry_point_group) + for entry_point in entry_points: + factory(entry_point.name) + + def format_entry_point_string(group, name, fmt=EntryPointFormat.FULL): """ Format an entry point string for a given entry point group and name, based on the specified format diff --git a/aiida/plugins/factories.py b/aiida/plugins/factories.py index 0ffe71c4be..b51b487272 100644 --- a/aiida/plugins/factories.py +++ b/aiida/plugins/factories.py @@ -7,13 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name +# pylint: disable=invalid-name,inconsistent-return-statements """Definition of factories to load classes from the various plugin groups.""" from __future__ import division from __future__ import print_function from __future__ import absolute_import -from .entry_point import load_entry_point +from aiida.common.exceptions import InvalidEntryPointTypeError __all__ = ( 'BaseFactory', 'CalculationFactory', 'DataFactory', 'DbImporterFactory', 'OrbitalFactory', 'ParserFactory', @@ -21,6 +21,19 @@ ) +def raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes): + """Raise an `InvalidEntryPointTypeError` with formatted message. + + :param entry_point_name: name of the entry point + :param entry_point_group: name of the entry point group + :param valid_classes: tuple of valid classes for the given entry point group + :raises aiida.common.InvalidEntryPointTypeError: always + """ + template = 'entry point `{}` registered in group `{}` is invalid because its type is not one of: {}' + args = (entry_point_name, entry_point_group, ', '.join([e.__name__ for e in valid_classes])) + raise InvalidEntryPointTypeError(template.format(*args)) + + def BaseFactory(group, name): """Return the plugin class registered under a given entry point group and name. @@ -31,76 +44,165 @@ def BaseFactory(group, name): :raises aiida.common.MultipleEntryPointError: entry point could not be uniquely resolved :raises aiida.common.LoadingEntryPointError: entry point could not be loaded """ + from .entry_point import load_entry_point return load_entry_point(group, name) -def CalculationFactory(entry_point): +def CalculationFactory(entry_point_name): """Return the `CalcJob` sub class registered under the given entry point. - :param entry_point: the entry point name + :param entry_point_name: the entry point name :return: sub class of :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ - return BaseFactory('aiida.calculations', entry_point) + from aiida.engine import CalcJob, calcfunction, is_process_function + from aiida.orm import CalcFunctionNode + + entry_point_group = 'aiida.calculations' + entry_point = BaseFactory(entry_point_group, entry_point_name) + valid_classes = (CalcJob, calcfunction) + + if issubclass(entry_point, CalcJob): + return entry_point + + if is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode: + return entry_point + + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) -def DataFactory(entry_point): +def DataFactory(entry_point_name): """Return the `Data` sub class registered under the given entry point. - :param entry_point: the entry point name + :param entry_point_name: the entry point name :return: sub class of :py:class:`~aiida.orm.nodes.data.data.Data` + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ - return BaseFactory('aiida.data', entry_point) + from aiida.orm import Data + entry_point_group = 'aiida.data' + entry_point = BaseFactory(entry_point_group, entry_point_name) + valid_classes = (Data,) -def DbImporterFactory(entry_point): + if issubclass(entry_point, Data): + return entry_point + + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) + + +def DbImporterFactory(entry_point_name): """Return the `DbImporter` sub class registered under the given entry point. - :param entry_point: the entry point name + :param entry_point_name: the entry point name :return: sub class of :py:class:`~aiida.tools.dbimporters.baseclasses.DbImporter` + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ - return BaseFactory('aiida.tools.dbimporters', entry_point) + from aiida.tools.dbimporters import DbImporter + + entry_point_group = 'aiida.tools.dbimporters' + entry_point = BaseFactory(entry_point_group, entry_point_name) + valid_classes = (DbImporter,) + if issubclass(entry_point, DbImporter): + return entry_point -def OrbitalFactory(entry_point): + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) + + +def OrbitalFactory(entry_point_name): """Return the `Orbital` sub class registered under the given entry point. - :param entry_point: the entry point name + :param entry_point_name: the entry point name :return: sub class of :py:class:`~aiida.tools.data.orbital.orbital.Orbital` + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ - return BaseFactory('aiida.tools.data.orbitals', entry_point) + from aiida.tools.data.orbital import Orbital + + entry_point_group = 'aiida.tools.data.orbitals' + entry_point = BaseFactory(entry_point_group, entry_point_name) + valid_classes = (Orbital,) + if issubclass(entry_point, Orbital): + return entry_point -def ParserFactory(entry_point): + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) + + +def ParserFactory(entry_point_name): """Return the `Parser` sub class registered under the given entry point. - :param entry_point: the entry point name + :param entry_point_name: the entry point name :return: sub class of :py:class:`~aiida.parsers.parser.Parser` + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ - return BaseFactory('aiida.parsers', entry_point) + from aiida.parsers import Parser + + entry_point_group = 'aiida.parsers' + entry_point = BaseFactory(entry_point_group, entry_point_name) + valid_classes = (Parser,) + + if issubclass(entry_point, Parser): + return entry_point + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) -def SchedulerFactory(entry_point): + +def SchedulerFactory(entry_point_name): """Return the `Scheduler` sub class registered under the given entry point. - :param entry_point: the entry point name + :param entry_point_name: the entry point name :return: sub class of :py:class:`~aiida.schedulers.scheduler.Scheduler` + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ - return BaseFactory('aiida.schedulers', entry_point) + from aiida.schedulers import Scheduler + + entry_point_group = 'aiida.schedulers' + entry_point = BaseFactory(entry_point_group, entry_point_name) + valid_classes = (Scheduler,) + + if issubclass(entry_point, Scheduler): + return entry_point + + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) -def TransportFactory(entry_point): +def TransportFactory(entry_point_name): """Return the `Transport` sub class registered under the given entry point. - :param entry_point: the entry point name + :param entry_point_name: the entry point name :return: sub class of :py:class:`~aiida.transports.transport.Transport` + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ - return BaseFactory('aiida.transports', entry_point) + from aiida.transports import Transport + entry_point_group = 'aiida.transports' + entry_point = BaseFactory(entry_point_group, entry_point_name) + valid_classes = (Transport,) -def WorkflowFactory(entry_point): + if issubclass(entry_point, Transport): + return entry_point + + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) + + +def WorkflowFactory(entry_point_name): """Return the `WorkChain` sub class registered under the given entry point. - :param entry_point: the entry point name - :return: sub class of :py:class:`~aiida.engine.processes.workchains.workchain.WorkChain` + :param entry_point_name: the entry point name + :return: sub class of :py:class:`~aiida.engine.processes.workchains.workchain.WorkChain` or a `workfunction` + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. """ - return BaseFactory('aiida.workflows', entry_point) + from aiida.engine import WorkChain, is_process_function, workfunction + from aiida.orm import WorkFunctionNode + + entry_point_group = 'aiida.workflows' + entry_point = BaseFactory(entry_point_group, entry_point_name) + valid_classes = (WorkChain, workfunction) + + if issubclass(entry_point, WorkChain): + return entry_point + + if is_process_function(entry_point) and entry_point.node_class is WorkFunctionNode: + return entry_point + + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) diff --git a/docs/requirements_for_rtd.txt b/docs/requirements_for_rtd.txt index 419b245b80..1d4af84456 100644 --- a/docs/requirements_for_rtd.txt +++ b/docs/requirements_for_rtd.txt @@ -31,6 +31,7 @@ kiwipy[rmq]==0.5.1 markupsafe==1.1.1 marshmallow-sqlalchemy==0.19.0 mock==3.0.5 +mock==3.0.5; python_version<'3.3' monty==2.0.4 numpy==1.16.4 paramiko==2.6.0 diff --git a/docs/source/verdi/verdi_user_guide.rst b/docs/source/verdi/verdi_user_guide.rst index 5308351db6..b3f00ebd01 100644 --- a/docs/source/verdi/verdi_user_guide.rst +++ b/docs/source/verdi/verdi_user_guide.rst @@ -370,9 +370,10 @@ Below is a list with all available subcommands. --help Show this message and exit. Commands: - check-load-time Check for common indicators that slowdown `verdi`. - run_daemon Run a daemon instance in the current interpreter. - tests Run the unittest suite or parts of it. + check-load-time Check for common indicators that slowdown `verdi`. + run_daemon Run a daemon instance in the current interpreter. + tests Run the unittest suite or parts of it. + validate-plugins Validate all plugins by checking they can be loaded. .. _verdi_export: diff --git a/setup.json b/setup.json index 9c4be50f7d..c1802e012c 100644 --- a/setup.json +++ b/setup.json @@ -111,6 +111,7 @@ "pg8000<1.13.0", "pgtest==1.3.1", "pytest==4.6.6", + "mock==3.0.5; python_version<'3.3'", "sqlalchemy-diff==0.1.3", "unittest2==1.1.0; python_version<'3.5'" ],