Skip to content

Commit

Permalink
Protocols: Move all static work chain inputs to protocol (#902)
Browse files Browse the repository at this point in the history
Some "default" inputs are set inside one of the steps of several work 
chains, which:

1. Can be confusing for the user, since the expected inputs are not
   there in the input files.
2. Can be frustrating for the user when a different value is desirable,
   for a use case that may not immediately be obvious.
3. Means default values are specified _both_ in the work chain logic
   and protocol, making it more difficult to get a clear overview of the
   input parameters.

Here we move the specification of all static default inputs to the
corresponding protocol file. Additionally:

* Since the default protocol now correctly sets the calculation type to
  `bands`, the `validate_inputs` validator of the `PwCalculation` will
  raise a warning because a `parent_folder` has not been initially
  provided. Hence, we set the `validate_inputs_base` validator for the
  `pw` port of the `bands` namespace, as is also done for e.g. the
  `nscf` of the `PdosWorkchain`.
* We fix a bug in the `base_final_scf` part of the `PwRelaxWorkChain`
  protocol, that erroneously specified `CELL.press_conv_thr`.
* We move the logic in `PwBaseWorkChain.set_max_seconds` to
  `PwBaseWorkChain.prepare_process`. Having a separate method for these
  few lines of code seemed unnecessary, especially since it's the only
  step that is executed in `prepare_process`.
* We remove the code added in 0aba276
  that automatically set `RestartType.FULL` in case the `parent_folder`
  input is provided for the `PwBaseWorkChain`. There are simply too many
  different use cases of restarting I think it's difficult if not
  impossible to consider all of them, and in several known use cases
  this code addition would cause more harm than good.
  • Loading branch information
mbercx authored Apr 18, 2023
1 parent 9291f84 commit 01f1470
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 51 deletions.
15 changes: 15 additions & 0 deletions src/aiida_quantumespresso/workflows/protocols/pw/bands.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@ default_inputs:
bands_kpoints_distance: 0.025
clean_workdir: False
nbands_factor: 3.0
scf:
pw:
parameters:
CONTROL:
calculation: scf
bands:
pw:
parameters:
CONTROL:
calculation: bands
restart_mode: from_scratch
ELECTRONS:
diagonalization: paro
diago_full_acc: True
startingpot: file
default_protocol: moderate
protocols:
moderate:
Expand Down
4 changes: 2 additions & 2 deletions src/aiida_quantumespresso/workflows/protocols/pw/relax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ default_inputs:
base_final_scf:
pw:
parameters:
CELL:
press_conv_thr: 0.5
CONTROL:
calculation: scf
default_protocol: moderate
protocols:
moderate:
Expand Down
18 changes: 5 additions & 13 deletions src/aiida_quantumespresso/workflows/pw/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
from aiida import orm
from aiida.common import AttributeDict
from aiida.engine import ToContext, WorkChain, if_
from aiida.plugins import WorkflowFactory

from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis
from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain

from ..protocols.utils import ProtocolMixin

PwBaseWorkChain = WorkflowFactory('quantumespresso.pw.base')
PwRelaxWorkChain = WorkflowFactory('quantumespresso.pw.relax')


def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument
"""Validate the inputs of the entire input namespace."""
Expand Down Expand Up @@ -72,6 +71,7 @@ def define(cls, spec):
help='Explicit kpoints to use for the BANDS calculation. Specify either this or `bands_kpoints_distance`.')
spec.input('bands_kpoints_distance', valid_type=orm.Float, required=False,
help='Minimum kpoints distance for the BANDS calculation. Specify either this or `bands_kpoints`.')
spec.inputs['bands']['pw'].validator = PwCalculation.validate_inputs_base
spec.inputs.validator = validate_inputs
spec.outline(
cls.setup,
Expand Down Expand Up @@ -228,13 +228,12 @@ def run_scf(self):
inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace='scf'))
inputs.metadata.call_link_label = 'scf'
inputs.pw.structure = self.ctx.current_structure
inputs.pw.parameters = inputs.pw.parameters.get_dict()
inputs.pw.parameters.setdefault('CONTROL', {})['calculation'] = 'scf'

# Make sure to carry the number of bands from the relax workchain if it was run and it wasn't explicitly defined
# in the inputs. One of the base workchains in the relax workchain may have changed the number automatically in
# the sanity checks on band occupations.
if self.ctx.current_number_of_bands:
inputs.pw.parameters = inputs.pw.parameters.get_dict()
inputs.pw.parameters.setdefault('SYSTEM', {}).setdefault('nbnd', self.ctx.current_number_of_bands)

inputs = prepare_process_inputs(PwBaseWorkChain, inputs)
Expand Down Expand Up @@ -267,13 +266,6 @@ def run_bands(self):
inputs.pw.parameters.setdefault('SYSTEM', {})
inputs.pw.parameters.setdefault('ELECTRONS', {})

# The following flags always have to be set in the parameters, regardless of what caller specified in the inputs
inputs.pw.parameters['CONTROL']['calculation'] = 'bands'

# Only set the following parameters if not directly explicitly defined in the inputs
inputs.pw.parameters['ELECTRONS'].setdefault('diagonalization', 'cg')
inputs.pw.parameters['ELECTRONS'].setdefault('diago_full_acc', True)

# If `nbands_factor` is defined in the inputs we set the `nbnd` parameter
if 'nbands_factor' in self.inputs:
factor = self.inputs.nbands_factor.value
Expand Down
20 changes: 2 additions & 18 deletions src/aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,9 @@ def setup(self):
self.ctx.inputs.parameters.setdefault('CONTROL', {})
self.ctx.inputs.parameters.setdefault('ELECTRONS', {})
self.ctx.inputs.parameters.setdefault('SYSTEM', {})
self.ctx.inputs.parameters['CONTROL'].setdefault('calculation', 'scf')

self.ctx.inputs.settings = self.ctx.inputs.settings.get_dict() if 'settings' in self.ctx.inputs else {}

# If a ``parent_folder`` is specified, automatically set the parameters for a ``RestartType.Full`` unless the
# ``CONTROL.restart_mode`` has explicitly been set to ``from_scratch``. In that case, the user most likely set
# that, and we do not want to override it.
restart_mode = self.ctx.inputs.parameters['CONTROL'].get('restart_mode', None)
if 'parent_folder' in self.ctx.inputs and restart_mode != 'from_scratch':
self.set_restart_type(RestartType.FULL, self.ctx.inputs.parent_folder)

def validate_kpoints(self):
"""Validate the inputs related to k-points.
Expand All @@ -280,15 +272,6 @@ def validate_kpoints(self):

self.ctx.inputs.kpoints = kpoints

def set_max_seconds(self, max_wallclock_seconds):
"""Set the `max_seconds` to a fraction of `max_wallclock_seconds` option to prevent out-of-walltime problems.
:param max_wallclock_seconds: the maximum wallclock time that will be set in the scheduler settings.
"""
max_seconds_factor = self.defaults.delta_factor_max_seconds
max_seconds = max_wallclock_seconds * max_seconds_factor
self.ctx.inputs.parameters['CONTROL']['max_seconds'] = max_seconds

def set_restart_type(self, restart_type, parent_folder=None):
"""Set the restart type for the next iteration."""

Expand Down Expand Up @@ -324,7 +307,8 @@ def prepare_process(self):
max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None)

if max_wallclock_seconds is not None and 'max_seconds' not in self.ctx.inputs.parameters['CONTROL']:
self.set_max_seconds(max_wallclock_seconds)
max_seconds = max_wallclock_seconds * self.defaults.delta_factor_max_seconds
self.ctx.inputs.parameters['CONTROL']['max_seconds'] = max_seconds

def report_error_handled(self, calculation, action):
"""Report an action taken for a calculation that has failed.
Expand Down
8 changes: 2 additions & 6 deletions src/aiida_quantumespresso/workflows/pw/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,11 @@ def setup(self):
self.ctx.relax_inputs.pw.parameters = self.ctx.relax_inputs.pw.parameters.get_dict()

self.ctx.relax_inputs.pw.parameters.setdefault('CONTROL', {})
self.ctx.relax_inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch'

# Set the meta_convergence and add it to the context
self.ctx.meta_convergence = self.inputs.meta_convergence.value
volume_cannot_change = (
self.ctx.relax_inputs.pw.parameters['CONTROL']['calculation'] in ('scf', 'relax') or
self.ctx.relax_inputs.pw.parameters['CONTROL'].get('calculation', 'scf') in ('scf', 'relax') or
self.ctx.relax_inputs.pw.parameters.get('CELL', {}).get('cell_dofree', None) == 'shape'
)
if self.ctx.meta_convergence and volume_cannot_change:
Expand All @@ -197,7 +196,7 @@ def setup(self):
if 'base_final_scf' in self.inputs:
self.ctx.final_scf_inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace='base_final_scf'))

if self.ctx.relax_inputs.pw.parameters['CONTROL']['calculation'] == 'scf':
if self.ctx.relax_inputs.pw.parameters['CONTROL'].get('calculation', 'scf') == 'scf':
self.report(
'Work chain will not run final SCF when `calculation` is set to `scf` for the relaxation '
'`PwBaseWorkChain`.'
Expand All @@ -208,9 +207,6 @@ def setup(self):
self.ctx.final_scf_inputs.pw.parameters = self.ctx.final_scf_inputs.pw.parameters.get_dict()

self.ctx.final_scf_inputs.pw.parameters.setdefault('CONTROL', {})
self.ctx.final_scf_inputs.pw.parameters['CONTROL']['calculation'] = 'scf'
self.ctx.final_scf_inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch'
self.ctx.final_scf_inputs.pw.parameters.pop('CELL', None)
self.ctx.final_scf_inputs.metadata.call_link_label = 'final_scf'

def should_run_relax(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/workflows/protocols/pw/test_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_options(fixture_code, generate_structure):

for subspace in (
builder.relax.base.pw.metadata,
builder.scf.pw.metadata,
builder.bands.pw.metadata,
builder.scf.pw.metadata, # pylint: disable=no-member
builder.bands.pw.metadata, # pylint: disable=no-member
):
assert subspace['options']['queue_name'] == queue_name, subspace
6 changes: 5 additions & 1 deletion tests/workflows/protocols/pw/test_bands/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@ bands:
withmpi: true
parameters:
CONTROL:
calculation: scf
calculation: bands
etot_conv_thr: 2.0e-05
forc_conv_thr: 0.0001
restart_mode: from_scratch
tprnfor: true
tstress: true
ELECTRONS:
conv_thr: 4.0e-10
diago_full_acc: true
diagonalization: paro
electron_maxstep: 80
mixing_beta: 0.4
startingpot: file
SYSTEM:
degauss: 0.01
ecutrho: 240.0
Expand Down
2 changes: 0 additions & 2 deletions tests/workflows/protocols/pw/test_relax/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ base_final_scf:
num_machines: 1
withmpi: true
parameters:
CELL:
press_conv_thr: 0.5
CONTROL:
calculation: scf
etot_conv_thr: 2.0e-05
Expand Down
10 changes: 3 additions & 7 deletions tests/workflows/pw/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,11 @@ def test_set_max_seconds(generate_workchain_pw):


@pytest.mark.parametrize('restart_mode, expected', (
(None, 'restart'),
('restart', 'restart'),
('from_scratch', 'from_scratch'),
))
def test_parent_folder(generate_workchain_pw, generate_calc_job_node, restart_mode, expected):
"""Test that ``parameters`` gets automatically updated if ``parent_folder`` in the inputs.
Specifically, the ``parameters`` should define the ``CONTROL.restart_mode`` unless it was explicitly set to
``from_scratch`` by the caller.
"""
def test_restart_mode(generate_workchain_pw, generate_calc_job_node, restart_mode, expected):
"""Test that the ``CONTROL.restart_mode`` specified by the user is always respected."""
node = generate_calc_job_node('pw', test_name='default')

inputs = generate_workchain_pw(return_inputs=True)
Expand Down

0 comments on commit 01f1470

Please sign in to comment.