Skip to content

Commit

Permalink
‼️ Redesign controllers as pydantic BaseModels
Browse files Browse the repository at this point in the history
Breaking behaviour: In case no `Group` with `group_label` exists,
submission controllers created the specified group upon instantiation.
This means is the user simply had a typo in the label upon instatiation
of the controller, a new (useless) group was created. Here we adapt
this behaviour to check for the groups existence instead.

This is done by converting the controller classes into `pydantic`
`BaseModel`s and using a `validator`. Using `BaseModel` has a few added
benefits:

1. Much cleaner constructor specification, that is inherited by sub
   classes. This becomes especially appreciable for more complex submission
   controllers.
2. Automatic validation on all constructor inputs.
3. Adds type hinting to all constructor inputs, which makes the classes
   much more user friendly and facilitates development.
4. `pydantic` `BaseModel`s come with built-in JSON serialization
   support, which makes it easier to store and recreate submission
   controllers from JSON files. If we want to set up reproducable HTC
   infrastructure, this may come in handy.
  • Loading branch information
mbercx committed Apr 22, 2023
1 parent 5302078 commit 8cc3f3c
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 62 deletions.
16 changes: 12 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,24 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v3

- name: Set up Python 3.8
- name: Cache Python dependencies
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: pip-pre-commit-${{ hashFiles('**/setup.json') }}
restore-keys:
pip-pre-commit-

- name: Set up Python 3.11
uses: actions/setup-python@v1
with:
python-version: 3.8
python-version: 3.11

- name: Install python dependencies
run: |
pip install pre-commit pylint==2.6.0
pip install -e.[dev]
- name: Run pre-commit
run:
Expand Down
49 changes: 22 additions & 27 deletions aiida_submission_controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,42 @@
"""A prototype class to submit processes in batches, avoiding to submit too many."""
import abc
import logging

from aiida import engine, orm
from aiida.common import NotExistent
from pydantic import BaseModel, validator

CMDLINE_LOGGER = logging.getLogger('verdi')


class BaseSubmissionController:
def validate_group_exists(value: str) -> str:
"""Validator that makes sure the ``Group`` with the provided label exists."""
try:
orm.Group.collection.get(label=value)
except NotExistent as exc:
raise ValueError(
f'Group with label `{value}` does not exist.') from exc
else:
return value


class BaseSubmissionController(BaseModel):
"""Controller to submit a maximum number of processes (workflows or calculations) at a given time.
This is an abstract base class: you need to subclass it and define the abstract methods.
"""
def __init__(self, group_label, max_concurrent):
"""Create a new controller to manage (and limit) concurrent submissions.
:param group_label: a group label: the group will be created at instantiation (if not existing already,
and it will be used to manage the calculations)
:param extra_unique_keys: a tuple or list of keys of extras that are used to uniquely identify
a process in the group. E.g. ('value1', 'value2').
:note: try to use actual values that allow for an equality comparison (strings, bools, integers), and avoid
floats, because of truncation errors.
"""
self._group_label = group_label
self._max_concurrent = max_concurrent

# Create the group if needed
self._group, _ = orm.Group.objects.get_or_create(self.group_label)
group_label: str
"""Label of the group to store the process nodes in."""
max_concurrent: int
"""Maximum concurrent active processes."""

@property
def group_label(self):
"""Return the label of the group that is managed by this class."""
return self._group_label
_validate_group_exists = validator('group_label',
allow_reuse=True)(validate_group_exists)

@property
def group(self):
"""Return the AiiDA ORM Group instance that is managed by this class."""
return self._group

@property
def max_concurrent(self):
"""Value of the maximum number of concurrent processes that can be run."""
return self._max_concurrent
return orm.Group.objects.get(label=self.group_label)

def get_query(self, process_projections, only_active=False):
"""Return a QueryBuilder object to get all processes in the group associated to this.
Expand Down
29 changes: 8 additions & 21 deletions aiida_submission_controller/from_group.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
"""A prototype class to submit processes in batches, avoiding to submit too many."""
from aiida import orm
from .base import BaseSubmissionController
from pydantic import validator

from .base import BaseSubmissionController, validate_group_exists


class FromGroupSubmissionController(BaseSubmissionController): # pylint: disable=abstract-method
Expand All @@ -10,31 +12,16 @@ class FromGroupSubmissionController(BaseSubmissionController): # pylint: disabl
This is (still) an abstract base class: you need to subclass it
and define the abstract methods.
"""
def __init__(self, parent_group_label, *args, **kwargs):
"""Create a new controller to manage (and limit) concurrent submissions.
:param parent_group_label: a group label: the group will be used to decide
which submissions to use. The group must already exist. Extras (in the method
`get_all_extras_to_submit`) will be returned from all extras in that group
(you need to make sure they are unique).
parent_group_label: str
"""Label of the parent group from which to construct the process inputs."""

For all other parameters, see the docstring of ``BaseSubmissionController.__init__``.
"""
super().__init__(*args, **kwargs)
self._parent_group_label = parent_group_label
# Load the group (this also ensures it exists)
self._parent_group = orm.Group.objects.get(
label=self.parent_group_label)

@property
def parent_group_label(self):
"""Return the label of the parent group that is used as a reference."""
return self._parent_group_label
_validate_group_exists = validator('parent_group_label',
allow_reuse=True)(validate_group_exists)

@property
def parent_group(self):
"""Return the AiiDA ORM Group instance of the parent group."""
return self._parent_group
return orm.Group.objects.get(label=self.parent_group_label)

def get_parent_node_from_extras(self, extras_values):
"""Return the Node instance (in the parent group) from the (unique) extras identifying it."""
Expand Down
26 changes: 17 additions & 9 deletions examples/add_in_batches.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# -*- coding: utf-8 -*-
"""An example of a SubmissionController implementation to compute a 12x12 table of additions."""
from aiida import orm
from aiida.plugins import CalculationFactory
from pydantic import validator

from aiida import orm, plugins
from aiida_submission_controller import BaseSubmissionController


class AdditionTableSubmissionController(BaseSubmissionController):
"""The implementation of a SubmissionController to compute a 12x12 table of additions."""
def __init__(self, code_name, *args, **kwargs):
"""Pass also a code name, that should be a code associated to an `arithmetic.add` plugin."""
super().__init__(*args, **kwargs)
self._code = orm.load_code(code_name)
self._process_class = plugins.CalculationFactory('arithmetic.add')
code_label: str
"""Label of the `code.arithmetic.add` `Code`."""
@validator('code_label')
def _check_code_plugin(cls, value):
plugin_type = orm.load_code(value).default_calc_job_plugin
if plugin_type == 'core.arithmetic.add':
return value
raise ValueError(
f'Code with label `{value}` has incorrect plugin type: `{plugin_type}`'
)

def get_extra_unique_keys(self):
"""Return a tuple of the keys of the unique extras that will be used to uniquely identify your workchains.
Expand All @@ -37,12 +44,13 @@ def get_inputs_and_processclass_from_extras(self, extras_values):
I just submit an ArithmeticAdd calculation summing the two values stored in the extras:
``left_operand + right_operand``.
"""
code = orm.load_code(self.code_label)
inputs = {
'code': self._code,
'code': code,
'x': orm.Int(extras_values[0]),
'y': orm.Int(extras_values[1])
}
return inputs, self._process_class
return inputs, CalculationFactory(code.get_input_plugin_name())


def main():
Expand All @@ -55,7 +63,7 @@ def main():
## verdi code setup -L add --on-computer --computer=localhost -P arithmetic.add --remote-abs-path=/bin/bash -n
# Create a controller
controller = AdditionTableSubmissionController(
code_name='add@localhost',
code_label='add@localhost',
group_label='tests/addition_table',
max_concurrent=10)

Expand Down
13 changes: 12 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ classifiers = [
requires-python = ">=3.6"

dependencies = [
"aiida-core>=1.0"
"aiida-core>=1.0",
"pydantic~=1.10.4",
]

[project.urls]
Expand All @@ -37,6 +38,16 @@ Source = "https://github.com/aiidateam/aiida-submission-controller"
qe = [
"aiida-quantumespresso"
]
dev = [
"pre-commit~=2.17.0",
"pylint-pydantic~=0.1.8"
]

[tool.pylint.master]
load-plugins = "pylint_pydantic"

[tool.pylint.'MESSAGES CONTROL']
extension-pkg-whitelist = "pydantic"

[tool.pylint.format]
max-line-length = 120
Expand Down

0 comments on commit 8cc3f3c

Please sign in to comment.