diff --git a/aiida_submission_controller/base.py b/aiida_submission_controller/base.py index 60ef8a6..8a70be5 100644 --- a/aiida_submission_controller/base.py +++ b/aiida_submission_controller/base.py @@ -2,47 +2,42 @@ """A prototype class to submit processes in batches, avoiding to submit too many.""" import abc import logging + from aiida import engine, orm +from aiida.common import NotExistent +from pydantic import BaseModel, validator CMDLINE_LOGGER = logging.getLogger('verdi') -class BaseSubmissionController: +def validate_group_exists(value: str) -> str: + """Validator that makes sure the ``Group`` with the provided label exists.""" + try: + orm.Group.collection.get(label=value) + except NotExistent as exc: + raise ValueError( + f'Group with label `{value}` does not exist.') from exc + else: + return value + + +class BaseSubmissionController(BaseModel): """Controller to submit a maximum number of processes (workflows or calculations) at a given time. This is an abstract base class: you need to subclass it and define the abstract methods. """ - def __init__(self, group_label, max_concurrent): - """Create a new controller to manage (and limit) concurrent submissions. - - :param group_label: a group label: the group will be created at instantiation (if not existing already, - and it will be used to manage the calculations) - :param extra_unique_keys: a tuple or list of keys of extras that are used to uniquely identify - a process in the group. E.g. ('value1', 'value2'). - - :note: try to use actual values that allow for an equality comparison (strings, bools, integers), and avoid - floats, because of truncation errors. - """ - self._group_label = group_label - self._max_concurrent = max_concurrent - - # Create the group if needed - self._group, _ = orm.Group.objects.get_or_create(self.group_label) + group_label: str + """Label of the group to store the process nodes in.""" + max_concurrent: int + """Maximum concurrent active processes.""" - @property - def group_label(self): - """Return the label of the group that is managed by this class.""" - return self._group_label + _validate_group_exists = validator('group_label', + allow_reuse=True)(validate_group_exists) @property def group(self): """Return the AiiDA ORM Group instance that is managed by this class.""" - return self._group - - @property - def max_concurrent(self): - """Value of the maximum number of concurrent processes that can be run.""" - return self._max_concurrent + return orm.Group.objects.get(label=self.group_label) def get_query(self, process_projections, only_active=False): """Return a QueryBuilder object to get all processes in the group associated to this. diff --git a/aiida_submission_controller/from_group.py b/aiida_submission_controller/from_group.py index 7b59d21..f684825 100644 --- a/aiida_submission_controller/from_group.py +++ b/aiida_submission_controller/from_group.py @@ -1,7 +1,9 @@ # -*- coding: utf-8 -*- """A prototype class to submit processes in batches, avoiding to submit too many.""" from aiida import orm -from .base import BaseSubmissionController +from pydantic import validator + +from .base import BaseSubmissionController, validate_group_exists class FromGroupSubmissionController(BaseSubmissionController): # pylint: disable=abstract-method @@ -10,31 +12,16 @@ class FromGroupSubmissionController(BaseSubmissionController): # pylint: disabl This is (still) an abstract base class: you need to subclass it and define the abstract methods. """ - def __init__(self, parent_group_label, *args, **kwargs): - """Create a new controller to manage (and limit) concurrent submissions. - - :param parent_group_label: a group label: the group will be used to decide - which submissions to use. The group must already exist. Extras (in the method - `get_all_extras_to_submit`) will be returned from all extras in that group - (you need to make sure they are unique). + parent_group_label: str + """Label of the parent group from which to construct the process inputs.""" - For all other parameters, see the docstring of ``BaseSubmissionController.__init__``. - """ - super().__init__(*args, **kwargs) - self._parent_group_label = parent_group_label - # Load the group (this also ensures it exists) - self._parent_group = orm.Group.objects.get( - label=self.parent_group_label) - - @property - def parent_group_label(self): - """Return the label of the parent group that is used as a reference.""" - return self._parent_group_label + _validate_group_exists = validator('parent_group_label', + allow_reuse=True)(validate_group_exists) @property def parent_group(self): """Return the AiiDA ORM Group instance of the parent group.""" - return self._parent_group + return orm.Group.objects.get(label=self.parent_group_label) def get_parent_node_from_extras(self, extras_values): """Return the Node instance (in the parent group) from the (unique) extras identifying it.""" diff --git a/examples/add_in_batches.py b/examples/add_in_batches.py index 58400cd..8707cfa 100644 --- a/examples/add_in_batches.py +++ b/examples/add_in_batches.py @@ -1,17 +1,24 @@ # -*- coding: utf-8 -*- """An example of a SubmissionController implementation to compute a 12x12 table of additions.""" +from aiida import orm +from aiida.plugins import CalculationFactory +from pydantic import validator -from aiida import orm, plugins from aiida_submission_controller import BaseSubmissionController class AdditionTableSubmissionController(BaseSubmissionController): """The implementation of a SubmissionController to compute a 12x12 table of additions.""" - def __init__(self, code_name, *args, **kwargs): - """Pass also a code name, that should be a code associated to an `arithmetic.add` plugin.""" - super().__init__(*args, **kwargs) - self._code = orm.load_code(code_name) - self._process_class = plugins.CalculationFactory('arithmetic.add') + code_label: str + """Label of the `code.arithmetic.add` `Code`.""" + @validator('code_label') + def _check_code_plugin(cls, value): + plugin_type = orm.load_code(value).default_calc_job_plugin + if plugin_type == 'core.arithmetic.add': + return value + raise ValueError( + f'Code with label `{value}` has incorrect plugin type: `{plugin_type}`' + ) def get_extra_unique_keys(self): """Return a tuple of the keys of the unique extras that will be used to uniquely identify your workchains. @@ -37,12 +44,13 @@ def get_inputs_and_processclass_from_extras(self, extras_values): I just submit an ArithmeticAdd calculation summing the two values stored in the extras: ``left_operand + right_operand``. """ + code = orm.load_code(self.code_label) inputs = { - 'code': self._code, + 'code': code, 'x': orm.Int(extras_values[0]), 'y': orm.Int(extras_values[1]) } - return inputs, self._process_class + return inputs, CalculationFactory(code.get_input_plugin_name()) def main(): @@ -55,7 +63,7 @@ def main(): ## verdi code setup -L add --on-computer --computer=localhost -P arithmetic.add --remote-abs-path=/bin/bash -n # Create a controller controller = AdditionTableSubmissionController( - code_name='add@localhost', + code_label='add@localhost', group_label='tests/addition_table', max_concurrent=10) diff --git a/pyproject.toml b/pyproject.toml index 21bbf6a..557406e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,9 @@ classifiers = [ requires-python = ">=3.6" dependencies = [ - "aiida-core>=1.0" + "aiida-core>=1.0", + "pydantic~=1.10.4", + "pylint-pydantic~=0.1.8" ] [project.urls] @@ -38,6 +40,12 @@ qe = [ "aiida-quantumespresso" ] +[tool.pylint.master] +load-plugins = "pylint_pydantic" + +[tool.pylint.'MESSAGES CONTROL'] +extension-pkg-whitelist = "pydantic" + [tool.pylint.format] max-line-length = 120