diff --git a/src/aiida_quantumespresso/workflows/protocols/pw/bands.yaml b/src/aiida_quantumespresso/workflows/protocols/pw/bands.yaml index 4053d413b..f7f987947 100644 --- a/src/aiida_quantumespresso/workflows/protocols/pw/bands.yaml +++ b/src/aiida_quantumespresso/workflows/protocols/pw/bands.yaml @@ -2,6 +2,21 @@ default_inputs: bands_kpoints_distance: 0.025 clean_workdir: False nbands_factor: 3.0 + scf: + pw: + parameters: + CONTROL: + calculation: scf + bands: + pw: + parameters: + CONTROL: + calculation: bands + restart_mode: from_scratch + ELECTRONS: + diagonalization: paro + diago_full_acc: True + startingpot: file default_protocol: moderate protocols: moderate: diff --git a/src/aiida_quantumespresso/workflows/protocols/pw/relax.yaml b/src/aiida_quantumespresso/workflows/protocols/pw/relax.yaml index c7a151e5c..5a670967c 100644 --- a/src/aiida_quantumespresso/workflows/protocols/pw/relax.yaml +++ b/src/aiida_quantumespresso/workflows/protocols/pw/relax.yaml @@ -11,8 +11,8 @@ default_inputs: base_final_scf: pw: parameters: - CELL: - press_conv_thr: 0.5 + CONTROL: + calculation: scf default_protocol: moderate protocols: moderate: diff --git a/src/aiida_quantumespresso/workflows/pw/bands.py b/src/aiida_quantumespresso/workflows/pw/bands.py index b5e0715b8..b7bd45b3a 100644 --- a/src/aiida_quantumespresso/workflows/pw/bands.py +++ b/src/aiida_quantumespresso/workflows/pw/bands.py @@ -3,16 +3,15 @@ from aiida import orm from aiida.common import AttributeDict from aiida.engine import ToContext, WorkChain, if_ -from aiida.plugins import WorkflowFactory from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis +from aiida_quantumespresso.calculations.pw import PwCalculation from aiida_quantumespresso.utils.mapping import prepare_process_inputs +from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain +from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain from ..protocols.utils import ProtocolMixin -PwBaseWorkChain = WorkflowFactory('quantumespresso.pw.base') -PwRelaxWorkChain = WorkflowFactory('quantumespresso.pw.relax') - def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument """Validate the inputs of the entire input namespace.""" @@ -72,6 +71,7 @@ def define(cls, spec): help='Explicit kpoints to use for the BANDS calculation. Specify either this or `bands_kpoints_distance`.') spec.input('bands_kpoints_distance', valid_type=orm.Float, required=False, help='Minimum kpoints distance for the BANDS calculation. Specify either this or `bands_kpoints`.') + spec.inputs['bands']['pw'].validator = PwCalculation.validate_inputs_base spec.inputs.validator = validate_inputs spec.outline( cls.setup, @@ -228,13 +228,12 @@ def run_scf(self): inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace='scf')) inputs.metadata.call_link_label = 'scf' inputs.pw.structure = self.ctx.current_structure - inputs.pw.parameters = inputs.pw.parameters.get_dict() - inputs.pw.parameters.setdefault('CONTROL', {})['calculation'] = 'scf' # Make sure to carry the number of bands from the relax workchain if it was run and it wasn't explicitly defined # in the inputs. One of the base workchains in the relax workchain may have changed the number automatically in # the sanity checks on band occupations. if self.ctx.current_number_of_bands: + inputs.pw.parameters = inputs.pw.parameters.get_dict() inputs.pw.parameters.setdefault('SYSTEM', {}).setdefault('nbnd', self.ctx.current_number_of_bands) inputs = prepare_process_inputs(PwBaseWorkChain, inputs) @@ -267,13 +266,6 @@ def run_bands(self): inputs.pw.parameters.setdefault('SYSTEM', {}) inputs.pw.parameters.setdefault('ELECTRONS', {}) - # The following flags always have to be set in the parameters, regardless of what caller specified in the inputs - inputs.pw.parameters['CONTROL']['calculation'] = 'bands' - - # Only set the following parameters if not directly explicitly defined in the inputs - inputs.pw.parameters['ELECTRONS'].setdefault('diagonalization', 'cg') - inputs.pw.parameters['ELECTRONS'].setdefault('diago_full_acc', True) - # If `nbands_factor` is defined in the inputs we set the `nbnd` parameter if 'nbands_factor' in self.inputs: factor = self.inputs.nbands_factor.value diff --git a/src/aiida_quantumespresso/workflows/pw/base.py b/src/aiida_quantumespresso/workflows/pw/base.py index b216677e6..5f35d5b79 100644 --- a/src/aiida_quantumespresso/workflows/pw/base.py +++ b/src/aiida_quantumespresso/workflows/pw/base.py @@ -244,17 +244,9 @@ def setup(self): self.ctx.inputs.parameters.setdefault('CONTROL', {}) self.ctx.inputs.parameters.setdefault('ELECTRONS', {}) self.ctx.inputs.parameters.setdefault('SYSTEM', {}) - self.ctx.inputs.parameters['CONTROL'].setdefault('calculation', 'scf') self.ctx.inputs.settings = self.ctx.inputs.settings.get_dict() if 'settings' in self.ctx.inputs else {} - # If a ``parent_folder`` is specified, automatically set the parameters for a ``RestartType.Full`` unless the - # ``CONTROL.restart_mode`` has explicitly been set to ``from_scratch``. In that case, the user most likely set - # that, and we do not want to override it. - restart_mode = self.ctx.inputs.parameters['CONTROL'].get('restart_mode', None) - if 'parent_folder' in self.ctx.inputs and restart_mode != 'from_scratch': - self.set_restart_type(RestartType.FULL, self.ctx.inputs.parent_folder) - def validate_kpoints(self): """Validate the inputs related to k-points. @@ -280,15 +272,6 @@ def validate_kpoints(self): self.ctx.inputs.kpoints = kpoints - def set_max_seconds(self, max_wallclock_seconds): - """Set the `max_seconds` to a fraction of `max_wallclock_seconds` option to prevent out-of-walltime problems. - - :param max_wallclock_seconds: the maximum wallclock time that will be set in the scheduler settings. - """ - max_seconds_factor = self.defaults.delta_factor_max_seconds - max_seconds = max_wallclock_seconds * max_seconds_factor - self.ctx.inputs.parameters['CONTROL']['max_seconds'] = max_seconds - def set_restart_type(self, restart_type, parent_folder=None): """Set the restart type for the next iteration.""" @@ -324,7 +307,8 @@ def prepare_process(self): max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None) if max_wallclock_seconds is not None and 'max_seconds' not in self.ctx.inputs.parameters['CONTROL']: - self.set_max_seconds(max_wallclock_seconds) + max_seconds = max_wallclock_seconds * self.defaults.delta_factor_max_seconds + self.ctx.inputs.parameters['CONTROL']['max_seconds'] = max_seconds def report_error_handled(self, calculation, action): """Report an action taken for a calculation that has failed. diff --git a/src/aiida_quantumespresso/workflows/pw/relax.py b/src/aiida_quantumespresso/workflows/pw/relax.py index d7fdd7a5a..2888a3e36 100644 --- a/src/aiida_quantumespresso/workflows/pw/relax.py +++ b/src/aiida_quantumespresso/workflows/pw/relax.py @@ -179,12 +179,11 @@ def setup(self): self.ctx.relax_inputs.pw.parameters = self.ctx.relax_inputs.pw.parameters.get_dict() self.ctx.relax_inputs.pw.parameters.setdefault('CONTROL', {}) - self.ctx.relax_inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch' # Set the meta_convergence and add it to the context self.ctx.meta_convergence = self.inputs.meta_convergence.value volume_cannot_change = ( - self.ctx.relax_inputs.pw.parameters['CONTROL']['calculation'] in ('scf', 'relax') or + self.ctx.relax_inputs.pw.parameters['CONTROL'].get('calculation', 'scf') in ('scf', 'relax') or self.ctx.relax_inputs.pw.parameters.get('CELL', {}).get('cell_dofree', None) == 'shape' ) if self.ctx.meta_convergence and volume_cannot_change: @@ -197,7 +196,7 @@ def setup(self): if 'base_final_scf' in self.inputs: self.ctx.final_scf_inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace='base_final_scf')) - if self.ctx.relax_inputs.pw.parameters['CONTROL']['calculation'] == 'scf': + if self.ctx.relax_inputs.pw.parameters['CONTROL'].get('calculation', 'scf') == 'scf': self.report( 'Work chain will not run final SCF when `calculation` is set to `scf` for the relaxation ' '`PwBaseWorkChain`.' @@ -208,9 +207,6 @@ def setup(self): self.ctx.final_scf_inputs.pw.parameters = self.ctx.final_scf_inputs.pw.parameters.get_dict() self.ctx.final_scf_inputs.pw.parameters.setdefault('CONTROL', {}) - self.ctx.final_scf_inputs.pw.parameters['CONTROL']['calculation'] = 'scf' - self.ctx.final_scf_inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch' - self.ctx.final_scf_inputs.pw.parameters.pop('CELL', None) self.ctx.final_scf_inputs.metadata.call_link_label = 'final_scf' def should_run_relax(self): diff --git a/tests/workflows/protocols/pw/test_bands.py b/tests/workflows/protocols/pw/test_bands.py index 65d172bdd..8a5fa4283 100644 --- a/tests/workflows/protocols/pw/test_bands.py +++ b/tests/workflows/protocols/pw/test_bands.py @@ -87,7 +87,7 @@ def test_options(fixture_code, generate_structure): for subspace in ( builder.relax.base.pw.metadata, - builder.scf.pw.metadata, - builder.bands.pw.metadata, + builder.scf.pw.metadata, # pylint: disable=no-member + builder.bands.pw.metadata, # pylint: disable=no-member ): assert subspace['options']['queue_name'] == queue_name, subspace diff --git a/tests/workflows/protocols/pw/test_bands/test_default.yml b/tests/workflows/protocols/pw/test_bands/test_default.yml index b1682597e..7883f8588 100644 --- a/tests/workflows/protocols/pw/test_bands/test_default.yml +++ b/tests/workflows/protocols/pw/test_bands/test_default.yml @@ -9,15 +9,19 @@ bands: withmpi: true parameters: CONTROL: - calculation: scf + calculation: bands etot_conv_thr: 2.0e-05 forc_conv_thr: 0.0001 + restart_mode: from_scratch tprnfor: true tstress: true ELECTRONS: conv_thr: 4.0e-10 + diago_full_acc: true + diagonalization: paro electron_maxstep: 80 mixing_beta: 0.4 + startingpot: file SYSTEM: degauss: 0.01 ecutrho: 240.0 diff --git a/tests/workflows/protocols/pw/test_relax/test_default.yml b/tests/workflows/protocols/pw/test_relax/test_default.yml index 3aef614a4..8182ec2b2 100644 --- a/tests/workflows/protocols/pw/test_relax/test_default.yml +++ b/tests/workflows/protocols/pw/test_relax/test_default.yml @@ -44,8 +44,6 @@ base_final_scf: num_machines: 1 withmpi: true parameters: - CELL: - press_conv_thr: 0.5 CONTROL: calculation: scf etot_conv_thr: 2.0e-05 diff --git a/tests/workflows/pw/test_base.py b/tests/workflows/pw/test_base.py index 0c4d3c97c..ce1b41a8b 100644 --- a/tests/workflows/pw/test_base.py +++ b/tests/workflows/pw/test_base.py @@ -252,15 +252,11 @@ def test_set_max_seconds(generate_workchain_pw): @pytest.mark.parametrize('restart_mode, expected', ( - (None, 'restart'), + ('restart', 'restart'), ('from_scratch', 'from_scratch'), )) -def test_parent_folder(generate_workchain_pw, generate_calc_job_node, restart_mode, expected): - """Test that ``parameters`` gets automatically updated if ``parent_folder`` in the inputs. - - Specifically, the ``parameters`` should define the ``CONTROL.restart_mode`` unless it was explicitly set to - ``from_scratch`` by the caller. - """ +def test_restart_mode(generate_workchain_pw, generate_calc_job_node, restart_mode, expected): + """Test that the ``CONTROL.restart_mode`` specified by the user is always respected.""" node = generate_calc_job_node('pw', test_name='default') inputs = generate_workchain_pw(return_inputs=True)