Skip to content

Commit

Permalink
PwBandsWorkChain: Move inputs to protocol
Browse files Browse the repository at this point in the history
Several "default" inputs of the `PwBandsWorkChain` are set inside one of the steps of the work chain, which:

1. Can be confusing for the user, since the expected inputs are not there in the input files.
2. Can be frustrating for the user when a different value is desirable, for a use case that may not immediately be obvious.
3. Means default values are specified _both_ in the work chain logic and protocol, making it more difficult to get a clear overview of the input parameters.

Here we move the specification of the default inputs to the protocol file of the `PwBandsWorkChain` (`bands.yaml`).

Since the default protocol now correctly sets the calculation type to `bands`, the `validate_inputs` validator of the `PwCalculation` will raise a warning because a `parent_folder` has not been initially provided. Hence, we set the `validate_inputs_base` validator for the `pw` port of the `bands` namespace, as is also done for e.g. the `nscf` of the `PdosWorkchain`.
  • Loading branch information
mbercx committed Apr 6, 2023
1 parent 7f53c96 commit 4a61d66
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 15 deletions.
13 changes: 13 additions & 0 deletions src/aiida_quantumespresso/workflows/protocols/pw/bands.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@ default_inputs:
bands_kpoints_distance: 0.025
clean_workdir: True
nbands_factor: 3.0
scf:
pw:
parameters:
CONTROL:
calculation: scf
bands:
pw:
parameters:
CONTROL:
calculation: bands
ELECTRONS:
diagonalization: paro
diago_full_acc: True
default_protocol: moderate
protocols:
moderate:
Expand Down
16 changes: 4 additions & 12 deletions src/aiida_quantumespresso/workflows/pw/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
from aiida import orm
from aiida.common import AttributeDict
from aiida.engine import ToContext, WorkChain, if_
from aiida.plugins import WorkflowFactory

from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis
from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain

from ..protocols.utils import ProtocolMixin

PwBaseWorkChain = WorkflowFactory('quantumespresso.pw.base')
PwRelaxWorkChain = WorkflowFactory('quantumespresso.pw.relax')


def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument
"""Validate the inputs of the entire input namespace."""
Expand Down Expand Up @@ -72,6 +71,7 @@ def define(cls, spec):
help='Explicit kpoints to use for the BANDS calculation. Specify either this or `bands_kpoints_distance`.')
spec.input('bands_kpoints_distance', valid_type=orm.Float, required=False,
help='Minimum kpoints distance for the BANDS calculation. Specify either this or `bands_kpoints`.')
spec.inputs['bands']['pw'].validator = PwCalculation.validate_inputs_base
spec.inputs.validator = validate_inputs
spec.outline(
cls.setup,
Expand Down Expand Up @@ -229,7 +229,6 @@ def run_scf(self):
inputs.metadata.call_link_label = 'scf'
inputs.pw.structure = self.ctx.current_structure
inputs.pw.parameters = inputs.pw.parameters.get_dict()
inputs.pw.parameters.setdefault('CONTROL', {})['calculation'] = 'scf'

# Make sure to carry the number of bands from the relax workchain if it was run and it wasn't explicitly defined
# in the inputs. One of the base workchains in the relax workchain may have changed the number automatically in
Expand Down Expand Up @@ -267,13 +266,6 @@ def run_bands(self):
inputs.pw.parameters.setdefault('SYSTEM', {})
inputs.pw.parameters.setdefault('ELECTRONS', {})

# The following flags always have to be set in the parameters, regardless of what caller specified in the inputs
inputs.pw.parameters['CONTROL']['calculation'] = 'bands'

# Only set the following parameters if not directly explicitly defined in the inputs
inputs.pw.parameters['ELECTRONS'].setdefault('diagonalization', 'cg')
inputs.pw.parameters['ELECTRONS'].setdefault('diago_full_acc', True)

# If `nbands_factor` is defined in the inputs we set the `nbnd` parameter
if 'nbands_factor' in self.inputs:
factor = self.inputs.nbands_factor.value
Expand Down
4 changes: 2 additions & 2 deletions tests/workflows/protocols/pw/test_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_options(fixture_code, generate_structure):

for subspace in (
builder.relax.base.pw.metadata,
builder.scf.pw.metadata,
builder.bands.pw.metadata,
builder.scf.pw.metadata, # pylint: disable=no-member
builder.bands.pw.metadata, # pylint: disable=no-member
):
assert subspace['options']['queue_name'] == queue_name, subspace
4 changes: 3 additions & 1 deletion tests/workflows/protocols/pw/test_bands/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ bands:
withmpi: true
parameters:
CONTROL:
calculation: scf
calculation: bands
etot_conv_thr: 2.0e-05
forc_conv_thr: 0.0001
tprnfor: true
tstress: true
ELECTRONS:
conv_thr: 4.0e-10
diago_full_acc: true
diagonalization: paro
electron_maxstep: 80
mixing_beta: 0.4
SYSTEM:
Expand Down

0 comments on commit 4a61d66

Please sign in to comment.