diff --git a/src/aiida_quantumespresso/workflows/protocols/pw/bands.yaml b/src/aiida_quantumespresso/workflows/protocols/pw/bands.yaml index 69981d1f9..fe4ab2d0c 100644 --- a/src/aiida_quantumespresso/workflows/protocols/pw/bands.yaml +++ b/src/aiida_quantumespresso/workflows/protocols/pw/bands.yaml @@ -2,6 +2,21 @@ default_inputs: bands_kpoints_distance: 0.025 clean_workdir: True nbands_factor: 3.0 + scf: + pw: + parameters: + CONTROL: + calculation: scf + bands: + pw: + parameters: + CONTROL: + calculation: bands + restart_mode: from_scratch + ELECTRONS: + diagonalization: paro + diago_full_acc: True + startingpot: file default_protocol: moderate protocols: moderate: diff --git a/src/aiida_quantumespresso/workflows/pw/bands.py b/src/aiida_quantumespresso/workflows/pw/bands.py index b5e0715b8..275b1cce2 100644 --- a/src/aiida_quantumespresso/workflows/pw/bands.py +++ b/src/aiida_quantumespresso/workflows/pw/bands.py @@ -3,16 +3,15 @@ from aiida import orm from aiida.common import AttributeDict from aiida.engine import ToContext, WorkChain, if_ -from aiida.plugins import WorkflowFactory from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis +from aiida_quantumespresso.calculations.pw import PwCalculation from aiida_quantumespresso.utils.mapping import prepare_process_inputs +from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain +from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain from ..protocols.utils import ProtocolMixin -PwBaseWorkChain = WorkflowFactory('quantumespresso.pw.base') -PwRelaxWorkChain = WorkflowFactory('quantumespresso.pw.relax') - def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument """Validate the inputs of the entire input namespace.""" @@ -72,6 +71,7 @@ def define(cls, spec): help='Explicit kpoints to use for the BANDS calculation. Specify either this or `bands_kpoints_distance`.') spec.input('bands_kpoints_distance', valid_type=orm.Float, required=False, help='Minimum kpoints distance for the BANDS calculation. Specify either this or `bands_kpoints`.') + spec.inputs['bands']['pw'].validator = PwCalculation.validate_inputs_base spec.inputs.validator = validate_inputs spec.outline( cls.setup, @@ -229,7 +229,6 @@ def run_scf(self): inputs.metadata.call_link_label = 'scf' inputs.pw.structure = self.ctx.current_structure inputs.pw.parameters = inputs.pw.parameters.get_dict() - inputs.pw.parameters.setdefault('CONTROL', {})['calculation'] = 'scf' # Make sure to carry the number of bands from the relax workchain if it was run and it wasn't explicitly defined # in the inputs. One of the base workchains in the relax workchain may have changed the number automatically in @@ -267,13 +266,6 @@ def run_bands(self): inputs.pw.parameters.setdefault('SYSTEM', {}) inputs.pw.parameters.setdefault('ELECTRONS', {}) - # The following flags always have to be set in the parameters, regardless of what caller specified in the inputs - inputs.pw.parameters['CONTROL']['calculation'] = 'bands' - - # Only set the following parameters if not directly explicitly defined in the inputs - inputs.pw.parameters['ELECTRONS'].setdefault('diagonalization', 'cg') - inputs.pw.parameters['ELECTRONS'].setdefault('diago_full_acc', True) - # If `nbands_factor` is defined in the inputs we set the `nbnd` parameter if 'nbands_factor' in self.inputs: factor = self.inputs.nbands_factor.value diff --git a/tests/workflows/protocols/pw/test_bands.py b/tests/workflows/protocols/pw/test_bands.py index 65d172bdd..8a5fa4283 100644 --- a/tests/workflows/protocols/pw/test_bands.py +++ b/tests/workflows/protocols/pw/test_bands.py @@ -87,7 +87,7 @@ def test_options(fixture_code, generate_structure): for subspace in ( builder.relax.base.pw.metadata, - builder.scf.pw.metadata, - builder.bands.pw.metadata, + builder.scf.pw.metadata, # pylint: disable=no-member + builder.bands.pw.metadata, # pylint: disable=no-member ): assert subspace['options']['queue_name'] == queue_name, subspace diff --git a/tests/workflows/protocols/pw/test_bands/test_default.yml b/tests/workflows/protocols/pw/test_bands/test_default.yml index ac2896f26..fd545229e 100644 --- a/tests/workflows/protocols/pw/test_bands/test_default.yml +++ b/tests/workflows/protocols/pw/test_bands/test_default.yml @@ -9,13 +9,15 @@ bands: withmpi: true parameters: CONTROL: - calculation: scf + calculation: bands etot_conv_thr: 2.0e-05 forc_conv_thr: 0.0001 tprnfor: true tstress: true ELECTRONS: conv_thr: 4.0e-10 + diago_full_acc: true + diagonalization: paro electron_maxstep: 80 mixing_beta: 0.4 SYSTEM: