From c0eb159c016d5f6a085e19d7d2240a50dbc65b12 Mon Sep 17 00:00:00 2001 From: Marnik Bercx Date: Sun, 9 Apr 2023 17:22:22 +0200 Subject: [PATCH] `PwCalculation`: Fix restart validation for `nscf`/`bands` For interrupted runs, restarting from a previous `nscf`/`bands` run with `CONTROL.restart_mode` set to `restart` is a valid use case. Here we remove the validation that raised a warning when the user tried to restart an interrupted run with this input. --- src/aiida_quantumespresso/calculations/pw.py | 18 ++++++------------ tests/calculations/test_pw.py | 12 ++---------- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/src/aiida_quantumespresso/calculations/pw.py b/src/aiida_quantumespresso/calculations/pw.py index 53095cecd..7436d822f 100644 --- a/src/aiida_quantumespresso/calculations/pw.py +++ b/src/aiida_quantumespresso/calculations/pw.py @@ -175,21 +175,16 @@ def validate_inputs_base(value, _): parameters = value['parameters'].get_dict() calculation_type = parameters.get('CONTROL', {}).get('calculation', 'scf') - # Check that the restart input parameters are set correctly - if calculation_type in ('nscf', 'bands'): - if parameters.get('ELECTRONS', {}).get('startingpot', 'file') != 'file': - return f'`startingpot` should be set to `file` for a `{calculation_type}` calculation.' - if parameters.get('CONTROL', {}).get('restart_mode', 'from_scratch') != 'from_scratch': - warnings.warn(f'`restart_mode` should be set to `from_scratch` for a `{calculation_type}` calculation.') - elif 'parent_folder' in value: + # If a `parent_folder` input is provided, make sure the inputs are set to restart + if 'parent_folder' in value and calculation_type not in ('nscf', 'bands'): if not any([ parameters.get('CONTROL', {}).get('restart_mode', None) == 'restart', parameters.get('ELECTRONS', {}).get('startingpot', None) == 'file', parameters.get('ELECTRONS', {}).get('startingwfc', None) == 'file' ]): warnings.warn( - '`parent_folder` input was provided for the `PwCalculation`, but no ' - 'input parameters are set to restart from these files.' + f'`parent_folder` input was provided for the `{calculation_type}` `PwCalculation`, but no input' + 'parameters (e.g. `restart_mode`, `startingpot`, ...) are set to restart from this folder.' ) @classmethod @@ -197,9 +192,8 @@ def validate_inputs(cls, value, port_namespace): """Validate the top level namespace. Check that the restart input parameters are set correctly. In case of 'nscf' and 'bands' calculations, this - means ``parent_folder`` is provided, ``startingpot`` is set to 'file' and ``restart_mode`` is 'from_scratch'. - For other calculations, if the ``parent_folder`` is provided, the restart settings must be set to use some of - the outputs. + means ``parent_folder`` is provided. For other calculations, if the ``parent_folder`` is provided, the restart + settings must be set to use some of the outputs. Note that the validator is split in two methods: ``validate_inputs`` and ``validate_inputs_base``. This is to facilitate work chains that wrap this calculation that will provide the ``parent_folder`` themselves and so do diff --git a/tests/calculations/test_pw.py b/tests/calculations/test_pw.py index a8a6586cb..5f87b189e 100644 --- a/tests/calculations/test_pw.py +++ b/tests/calculations/test_pw.py @@ -330,18 +330,10 @@ def test_pw_validate_inputs_restart_nscf( inputs['parent_folder'] = remote_data generate_calc_job(fixture_sandbox, entry_point_name, inputs) - # Set `startingpot` to `'atomic'` -> raise - parameters['ELECTRONS']['startingpot'] = 'atomic' - inputs['parameters'] = orm.Dict(parameters) - with pytest.raises(ValueError, match='`startingpot` should be set to `file` for a `.*` calculation.'): - generate_calc_job(fixture_sandbox, entry_point_name, inputs) - - # Set `restart_mode` to `'restart'` -> warning - parameters['ELECTRONS'].pop('startingpot') + # Set `restart_mode` to `'restart'` -> works parameters['CONTROL']['restart_mode'] = 'restart' inputs['parameters'] = orm.Dict(parameters) - with pytest.warns(Warning, match='`restart_mode` should be set to `from_scratch` for a `.*`.'): - generate_calc_job(fixture_sandbox, entry_point_name, inputs) + generate_calc_job(fixture_sandbox, entry_point_name, inputs) def test_fixed_coords(fixture_sandbox, generate_calc_job, generate_inputs_pw, file_regression):