Skip to content

Commit

Permalink
Remove infamous code and move inputs to protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
mbercx committed Apr 17, 2023
1 parent e95ae2c commit a193fa1
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 34 deletions.
4 changes: 2 additions & 2 deletions src/aiida_quantumespresso/workflows/protocols/pw/relax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ default_inputs:
base_final_scf:
pw:
parameters:
CELL:
press_conv_thr: 0.5
CONTROL:
calculation: scf
default_protocol: moderate
protocols:
moderate:
Expand Down
19 changes: 2 additions & 17 deletions src/aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,6 @@ def setup(self):

self.ctx.inputs.settings = self.ctx.inputs.settings.get_dict() if 'settings' in self.ctx.inputs else {}

# If a ``parent_folder`` is specified, automatically set the parameters for a ``RestartType.Full`` unless the
# ``CONTROL.restart_mode`` has explicitly been set to ``from_scratch``. In that case, the user most likely set
# that, and we do not want to override it.
restart_mode = self.ctx.inputs.parameters['CONTROL'].get('restart_mode', None)
if 'parent_folder' in self.ctx.inputs and restart_mode != 'from_scratch':
self.set_restart_type(RestartType.FULL, self.ctx.inputs.parent_folder)

def validate_kpoints(self):
"""Validate the inputs related to k-points.
Expand All @@ -279,15 +272,6 @@ def validate_kpoints(self):

self.ctx.inputs.kpoints = kpoints

def set_max_seconds(self, max_wallclock_seconds):
"""Set the `max_seconds` to a fraction of `max_wallclock_seconds` option to prevent out-of-walltime problems.
:param max_wallclock_seconds: the maximum wallclock time that will be set in the scheduler settings.
"""
max_seconds_factor = self.defaults.delta_factor_max_seconds
max_seconds = max_wallclock_seconds * max_seconds_factor
self.ctx.inputs.parameters['CONTROL']['max_seconds'] = max_seconds

def set_restart_type(self, restart_type, parent_folder=None):
"""Set the restart type for the next iteration."""

Expand Down Expand Up @@ -323,7 +307,8 @@ def prepare_process(self):
max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None)

if max_wallclock_seconds is not None and 'max_seconds' not in self.ctx.inputs.parameters['CONTROL']:
self.set_max_seconds(max_wallclock_seconds)
max_seconds = max_wallclock_seconds * self.defaults.delta_factor_max_seconds
self.ctx.inputs.parameters['CONTROL']['max_seconds'] = max_seconds

def report_error_handled(self, calculation, action):
"""Report an action taken for a calculation that has failed.
Expand Down
8 changes: 2 additions & 6 deletions src/aiida_quantumespresso/workflows/pw/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,11 @@ def setup(self):
self.ctx.relax_inputs.pw.parameters = self.ctx.relax_inputs.pw.parameters.get_dict()

self.ctx.relax_inputs.pw.parameters.setdefault('CONTROL', {})
self.ctx.relax_inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch'

# Set the meta_convergence and add it to the context
self.ctx.meta_convergence = self.inputs.meta_convergence.value
volume_cannot_change = (
self.ctx.relax_inputs.pw.parameters['CONTROL']['calculation'] in ('scf', 'relax') or
self.ctx.relax_inputs.pw.parameters['CONTROL'].get('calculation', 'scf') in ('scf', 'relax') or
self.ctx.relax_inputs.pw.parameters.get('CELL', {}).get('cell_dofree', None) == 'shape'
)
if self.ctx.meta_convergence and volume_cannot_change:
Expand All @@ -197,7 +196,7 @@ def setup(self):
if 'base_final_scf' in self.inputs:
self.ctx.final_scf_inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace='base_final_scf'))

if self.ctx.relax_inputs.pw.parameters['CONTROL']['calculation'] == 'scf':
if self.ctx.relax_inputs.pw.parameters['CONTROL'].get('calculation', 'scf') == 'scf':
self.report(
'Work chain will not run final SCF when `calculation` is set to `scf` for the relaxation '
'`PwBaseWorkChain`.'
Expand All @@ -208,9 +207,6 @@ def setup(self):
self.ctx.final_scf_inputs.pw.parameters = self.ctx.final_scf_inputs.pw.parameters.get_dict()

self.ctx.final_scf_inputs.pw.parameters.setdefault('CONTROL', {})
self.ctx.final_scf_inputs.pw.parameters['CONTROL']['calculation'] = 'scf'
self.ctx.final_scf_inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch'
self.ctx.final_scf_inputs.pw.parameters.pop('CELL', None)
self.ctx.final_scf_inputs.metadata.call_link_label = 'final_scf'

def should_run_relax(self):
Expand Down
2 changes: 0 additions & 2 deletions tests/workflows/protocols/pw/test_relax/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ base_final_scf:
num_machines: 1
withmpi: true
parameters:
CELL:
press_conv_thr: 0.5
CONTROL:
calculation: scf
etot_conv_thr: 2.0e-05
Expand Down
10 changes: 3 additions & 7 deletions tests/workflows/pw/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,11 @@ def test_set_max_seconds(generate_workchain_pw):


@pytest.mark.parametrize('restart_mode, expected', (
(None, 'restart'),
('restart', 'restart'),
('from_scratch', 'from_scratch'),
))
def test_parent_folder(generate_workchain_pw, generate_calc_job_node, restart_mode, expected):
"""Test that ``parameters`` gets automatically updated if ``parent_folder`` in the inputs.
Specifically, the ``parameters`` should define the ``CONTROL.restart_mode`` unless it was explicitly set to
``from_scratch`` by the caller.
"""
def test_restart_mode(generate_workchain_pw, generate_calc_job_node, restart_mode, expected):
"""Test that the ``CONTROL.restart_mode`` specified by the user is always respected."""
node = generate_calc_job_node('pw', test_name='default')

inputs = generate_workchain_pw(return_inputs=True)
Expand Down

0 comments on commit a193fa1

Please sign in to comment.