From 557628d3f7aa4b55a64ab57038222f1ffac92d83 Mon Sep 17 00:00:00 2001 From: Marnik Bercx Date: Sat, 22 Apr 2023 14:24:47 +0200 Subject: [PATCH] =?UTF-8?q?=E2=80=BC=EF=B8=8F=20Redesign=20controllers=20a?= =?UTF-8?q?s=20`pydantic`=20`BaseModel`s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Breaking behaviour: In case no `Group` with `group_label` exists, submission controllers created the specified group upon instantiation. This means is the user simply had a typo in the label upon instatiation of the controller, a new (useless) group was created. Here we adapt this behaviour to check for the groups existence instead. This is done by converting the controller classes into `pydantic` `BaseModel`s and using a `validator`. Using `BaseModel` has a few added benefits: 1. Much cleaner constructor specification, that is inherited by sub classes. This becomes especially appreciable for more complex submission controllers. 2. Automatic validation on all constructor inputs. 3. Adds type hinting to all constructor inputs, which makes the classes much more user friendly and facilitates development. 4. `pydantic` `BaseModel`s come with built-in JSON serialization support, which makes it easier to store and recreate submission controllers from JSON files. If we want to set up reproducable HTC infrastructure, this may come in handy. --- .github/workflows/ci.yml | 8 ++-- aiida_submission_controller/base.py | 49 ++++++++++------------- aiida_submission_controller/from_group.py | 29 ++++---------- examples/add_in_batches.py | 26 +++++++----- pyproject.toml | 12 +++++- 5 files changed, 62 insertions(+), 62 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9127e0e..4d4a691 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,16 +7,16 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.11 uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.11 - name: Install python dependencies run: | - pip install pre-commit pylint==2.6.0 + pip install -e.[dev] - name: Run pre-commit run: diff --git a/aiida_submission_controller/base.py b/aiida_submission_controller/base.py index 60ef8a6..8a70be5 100644 --- a/aiida_submission_controller/base.py +++ b/aiida_submission_controller/base.py @@ -2,47 +2,42 @@ """A prototype class to submit processes in batches, avoiding to submit too many.""" import abc import logging + from aiida import engine, orm +from aiida.common import NotExistent +from pydantic import BaseModel, validator CMDLINE_LOGGER = logging.getLogger('verdi') -class BaseSubmissionController: +def validate_group_exists(value: str) -> str: + """Validator that makes sure the ``Group`` with the provided label exists.""" + try: + orm.Group.collection.get(label=value) + except NotExistent as exc: + raise ValueError( + f'Group with label `{value}` does not exist.') from exc + else: + return value + + +class BaseSubmissionController(BaseModel): """Controller to submit a maximum number of processes (workflows or calculations) at a given time. This is an abstract base class: you need to subclass it and define the abstract methods. """ - def __init__(self, group_label, max_concurrent): - """Create a new controller to manage (and limit) concurrent submissions. - - :param group_label: a group label: the group will be created at instantiation (if not existing already, - and it will be used to manage the calculations) - :param extra_unique_keys: a tuple or list of keys of extras that are used to uniquely identify - a process in the group. E.g. ('value1', 'value2'). - - :note: try to use actual values that allow for an equality comparison (strings, bools, integers), and avoid - floats, because of truncation errors. - """ - self._group_label = group_label - self._max_concurrent = max_concurrent - - # Create the group if needed - self._group, _ = orm.Group.objects.get_or_create(self.group_label) + group_label: str + """Label of the group to store the process nodes in.""" + max_concurrent: int + """Maximum concurrent active processes.""" - @property - def group_label(self): - """Return the label of the group that is managed by this class.""" - return self._group_label + _validate_group_exists = validator('group_label', + allow_reuse=True)(validate_group_exists) @property def group(self): """Return the AiiDA ORM Group instance that is managed by this class.""" - return self._group - - @property - def max_concurrent(self): - """Value of the maximum number of concurrent processes that can be run.""" - return self._max_concurrent + return orm.Group.objects.get(label=self.group_label) def get_query(self, process_projections, only_active=False): """Return a QueryBuilder object to get all processes in the group associated to this. diff --git a/aiida_submission_controller/from_group.py b/aiida_submission_controller/from_group.py index 7b59d21..f684825 100644 --- a/aiida_submission_controller/from_group.py +++ b/aiida_submission_controller/from_group.py @@ -1,7 +1,9 @@ # -*- coding: utf-8 -*- """A prototype class to submit processes in batches, avoiding to submit too many.""" from aiida import orm -from .base import BaseSubmissionController +from pydantic import validator + +from .base import BaseSubmissionController, validate_group_exists class FromGroupSubmissionController(BaseSubmissionController): # pylint: disable=abstract-method @@ -10,31 +12,16 @@ class FromGroupSubmissionController(BaseSubmissionController): # pylint: disabl This is (still) an abstract base class: you need to subclass it and define the abstract methods. """ - def __init__(self, parent_group_label, *args, **kwargs): - """Create a new controller to manage (and limit) concurrent submissions. - - :param parent_group_label: a group label: the group will be used to decide - which submissions to use. The group must already exist. Extras (in the method - `get_all_extras_to_submit`) will be returned from all extras in that group - (you need to make sure they are unique). + parent_group_label: str + """Label of the parent group from which to construct the process inputs.""" - For all other parameters, see the docstring of ``BaseSubmissionController.__init__``. - """ - super().__init__(*args, **kwargs) - self._parent_group_label = parent_group_label - # Load the group (this also ensures it exists) - self._parent_group = orm.Group.objects.get( - label=self.parent_group_label) - - @property - def parent_group_label(self): - """Return the label of the parent group that is used as a reference.""" - return self._parent_group_label + _validate_group_exists = validator('parent_group_label', + allow_reuse=True)(validate_group_exists) @property def parent_group(self): """Return the AiiDA ORM Group instance of the parent group.""" - return self._parent_group + return orm.Group.objects.get(label=self.parent_group_label) def get_parent_node_from_extras(self, extras_values): """Return the Node instance (in the parent group) from the (unique) extras identifying it.""" diff --git a/examples/add_in_batches.py b/examples/add_in_batches.py index 58400cd..8707cfa 100644 --- a/examples/add_in_batches.py +++ b/examples/add_in_batches.py @@ -1,17 +1,24 @@ # -*- coding: utf-8 -*- """An example of a SubmissionController implementation to compute a 12x12 table of additions.""" +from aiida import orm +from aiida.plugins import CalculationFactory +from pydantic import validator -from aiida import orm, plugins from aiida_submission_controller import BaseSubmissionController class AdditionTableSubmissionController(BaseSubmissionController): """The implementation of a SubmissionController to compute a 12x12 table of additions.""" - def __init__(self, code_name, *args, **kwargs): - """Pass also a code name, that should be a code associated to an `arithmetic.add` plugin.""" - super().__init__(*args, **kwargs) - self._code = orm.load_code(code_name) - self._process_class = plugins.CalculationFactory('arithmetic.add') + code_label: str + """Label of the `code.arithmetic.add` `Code`.""" + @validator('code_label') + def _check_code_plugin(cls, value): + plugin_type = orm.load_code(value).default_calc_job_plugin + if plugin_type == 'core.arithmetic.add': + return value + raise ValueError( + f'Code with label `{value}` has incorrect plugin type: `{plugin_type}`' + ) def get_extra_unique_keys(self): """Return a tuple of the keys of the unique extras that will be used to uniquely identify your workchains. @@ -37,12 +44,13 @@ def get_inputs_and_processclass_from_extras(self, extras_values): I just submit an ArithmeticAdd calculation summing the two values stored in the extras: ``left_operand + right_operand``. """ + code = orm.load_code(self.code_label) inputs = { - 'code': self._code, + 'code': code, 'x': orm.Int(extras_values[0]), 'y': orm.Int(extras_values[1]) } - return inputs, self._process_class + return inputs, CalculationFactory(code.get_input_plugin_name()) def main(): @@ -55,7 +63,7 @@ def main(): ## verdi code setup -L add --on-computer --computer=localhost -P arithmetic.add --remote-abs-path=/bin/bash -n # Create a controller controller = AdditionTableSubmissionController( - code_name='add@localhost', + code_label='add@localhost', group_label='tests/addition_table', max_concurrent=10) diff --git a/pyproject.toml b/pyproject.toml index 21bbf6a..a51e2fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,8 @@ classifiers = [ requires-python = ">=3.6" dependencies = [ - "aiida-core>=1.0" + "aiida-core>=1.0", + "pydantic~=1.10.4", ] [project.urls] @@ -37,6 +38,15 @@ Source = "https://github.com/aiidateam/aiida-submission-controller" qe = [ "aiida-quantumespresso" ] +dev = [ + "pylint-pydantic~=0.1.8" +] + +[tool.pylint.master] +load-plugins = "pylint_pydantic" + +[tool.pylint.'MESSAGES CONTROL'] +extension-pkg-whitelist = "pydantic" [tool.pylint.format] max-line-length = 120