Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protocols: Move all static work chain inputs to protocol #902

Merged
merged 4 commits into from
Apr 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/aiida_quantumespresso/workflows/protocols/pw/bands.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@ default_inputs:
bands_kpoints_distance: 0.025
clean_workdir: False
nbands_factor: 3.0
scf:
pw:
parameters:
CONTROL:
calculation: scf
bands:
pw:
parameters:
CONTROL:
calculation: bands
restart_mode: from_scratch
ELECTRONS:
diagonalization: paro
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This value was added in QE 6.6? Since that is the oldest version we support?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yessir, see

#830 (comment)

diago_full_acc: True
startingpot: file
default_protocol: moderate
protocols:
moderate:
Expand Down
4 changes: 2 additions & 2 deletions src/aiida_quantumespresso/workflows/protocols/pw/relax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ default_inputs:
base_final_scf:
pw:
parameters:
CELL:
press_conv_thr: 0.5
CONTROL:
calculation: scf
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not ask questions how a pressure convergence condition was the only part of the "scf" step of the PwRelaxWorkChain protocol. 😅

default_protocol: moderate
protocols:
moderate:
Expand Down
18 changes: 5 additions & 13 deletions src/aiida_quantumespresso/workflows/pw/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
from aiida import orm
from aiida.common import AttributeDict
from aiida.engine import ToContext, WorkChain, if_
from aiida.plugins import WorkflowFactory

from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis
from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain

from ..protocols.utils import ProtocolMixin

PwBaseWorkChain = WorkflowFactory('quantumespresso.pw.base')
PwRelaxWorkChain = WorkflowFactory('quantumespresso.pw.relax')


def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument
"""Validate the inputs of the entire input namespace."""
Expand Down Expand Up @@ -72,6 +71,7 @@ def define(cls, spec):
help='Explicit kpoints to use for the BANDS calculation. Specify either this or `bands_kpoints_distance`.')
spec.input('bands_kpoints_distance', valid_type=orm.Float, required=False,
help='Minimum kpoints distance for the BANDS calculation. Specify either this or `bands_kpoints`.')
spec.inputs['bands']['pw'].validator = PwCalculation.validate_inputs_base
spec.inputs.validator = validate_inputs
spec.outline(
cls.setup,
Expand Down Expand Up @@ -228,13 +228,12 @@ def run_scf(self):
inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace='scf'))
inputs.metadata.call_link_label = 'scf'
inputs.pw.structure = self.ctx.current_structure
inputs.pw.parameters = inputs.pw.parameters.get_dict()
inputs.pw.parameters.setdefault('CONTROL', {})['calculation'] = 'scf'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are removing this, maybe we move line 231 to the if self.ctx.current_number_of_bands: block, since then we only unwrap-and-wrap the parameters if we really need to update it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Want to still add this change @mbercx ? Then this can be merged

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed this comment somehow. On it! 🚀


# Make sure to carry the number of bands from the relax workchain if it was run and it wasn't explicitly defined
# in the inputs. One of the base workchains in the relax workchain may have changed the number automatically in
# the sanity checks on band occupations.
if self.ctx.current_number_of_bands:
inputs.pw.parameters = inputs.pw.parameters.get_dict()
inputs.pw.parameters.setdefault('SYSTEM', {}).setdefault('nbnd', self.ctx.current_number_of_bands)

inputs = prepare_process_inputs(PwBaseWorkChain, inputs)
Expand Down Expand Up @@ -267,13 +266,6 @@ def run_bands(self):
inputs.pw.parameters.setdefault('SYSTEM', {})
inputs.pw.parameters.setdefault('ELECTRONS', {})

# The following flags always have to be set in the parameters, regardless of what caller specified in the inputs
inputs.pw.parameters['CONTROL']['calculation'] = 'bands'

# Only set the following parameters if not directly explicitly defined in the inputs
inputs.pw.parameters['ELECTRONS'].setdefault('diagonalization', 'cg')
inputs.pw.parameters['ELECTRONS'].setdefault('diago_full_acc', True)

# If `nbands_factor` is defined in the inputs we set the `nbnd` parameter
if 'nbands_factor' in self.inputs:
factor = self.inputs.nbands_factor.value
Expand Down
20 changes: 2 additions & 18 deletions src/aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,9 @@ def setup(self):
self.ctx.inputs.parameters.setdefault('CONTROL', {})
self.ctx.inputs.parameters.setdefault('ELECTRONS', {})
self.ctx.inputs.parameters.setdefault('SYSTEM', {})
self.ctx.inputs.parameters['CONTROL'].setdefault('calculation', 'scf')

self.ctx.inputs.settings = self.ctx.inputs.settings.get_dict() if 'settings' in self.ctx.inputs else {}

# If a ``parent_folder`` is specified, automatically set the parameters for a ``RestartType.Full`` unless the
# ``CONTROL.restart_mode`` has explicitly been set to ``from_scratch``. In that case, the user most likely set
# that, and we do not want to override it.
restart_mode = self.ctx.inputs.parameters['CONTROL'].get('restart_mode', None)
if 'parent_folder' in self.ctx.inputs and restart_mode != 'from_scratch':
self.set_restart_type(RestartType.FULL, self.ctx.inputs.parent_folder)

def validate_kpoints(self):
"""Validate the inputs related to k-points.

Expand All @@ -280,15 +272,6 @@ def validate_kpoints(self):

self.ctx.inputs.kpoints = kpoints

def set_max_seconds(self, max_wallclock_seconds):
"""Set the `max_seconds` to a fraction of `max_wallclock_seconds` option to prevent out-of-walltime problems.

:param max_wallclock_seconds: the maximum wallclock time that will be set in the scheduler settings.
"""
max_seconds_factor = self.defaults.delta_factor_max_seconds
max_seconds = max_wallclock_seconds * max_seconds_factor
self.ctx.inputs.parameters['CONTROL']['max_seconds'] = max_seconds

def set_restart_type(self, restart_type, parent_folder=None):
"""Set the restart type for the next iteration."""

Expand Down Expand Up @@ -324,7 +307,8 @@ def prepare_process(self):
max_wallclock_seconds = self.ctx.inputs.metadata.options.get('max_wallclock_seconds', None)

if max_wallclock_seconds is not None and 'max_seconds' not in self.ctx.inputs.parameters['CONTROL']:
self.set_max_seconds(max_wallclock_seconds)
max_seconds = max_wallclock_seconds * self.defaults.delta_factor_max_seconds
self.ctx.inputs.parameters['CONTROL']['max_seconds'] = max_seconds

def report_error_handled(self, calculation, action):
"""Report an action taken for a calculation that has failed.
Expand Down
8 changes: 2 additions & 6 deletions src/aiida_quantumespresso/workflows/pw/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,11 @@ def setup(self):
self.ctx.relax_inputs.pw.parameters = self.ctx.relax_inputs.pw.parameters.get_dict()

self.ctx.relax_inputs.pw.parameters.setdefault('CONTROL', {})
self.ctx.relax_inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch'

# Set the meta_convergence and add it to the context
self.ctx.meta_convergence = self.inputs.meta_convergence.value
volume_cannot_change = (
self.ctx.relax_inputs.pw.parameters['CONTROL']['calculation'] in ('scf', 'relax') or
self.ctx.relax_inputs.pw.parameters['CONTROL'].get('calculation', 'scf') in ('scf', 'relax') or
self.ctx.relax_inputs.pw.parameters.get('CELL', {}).get('cell_dofree', None) == 'shape'
)
if self.ctx.meta_convergence and volume_cannot_change:
Expand All @@ -197,7 +196,7 @@ def setup(self):
if 'base_final_scf' in self.inputs:
self.ctx.final_scf_inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace='base_final_scf'))

if self.ctx.relax_inputs.pw.parameters['CONTROL']['calculation'] == 'scf':
if self.ctx.relax_inputs.pw.parameters['CONTROL'].get('calculation', 'scf') == 'scf':
self.report(
'Work chain will not run final SCF when `calculation` is set to `scf` for the relaxation '
'`PwBaseWorkChain`.'
Expand All @@ -208,9 +207,6 @@ def setup(self):
self.ctx.final_scf_inputs.pw.parameters = self.ctx.final_scf_inputs.pw.parameters.get_dict()

self.ctx.final_scf_inputs.pw.parameters.setdefault('CONTROL', {})
self.ctx.final_scf_inputs.pw.parameters['CONTROL']['calculation'] = 'scf'
self.ctx.final_scf_inputs.pw.parameters['CONTROL']['restart_mode'] = 'from_scratch'
self.ctx.final_scf_inputs.pw.parameters.pop('CELL', None)
self.ctx.final_scf_inputs.metadata.call_link_label = 'final_scf'

def should_run_relax(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/workflows/protocols/pw/test_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_options(fixture_code, generate_structure):

for subspace in (
builder.relax.base.pw.metadata,
builder.scf.pw.metadata,
builder.bands.pw.metadata,
builder.scf.pw.metadata, # pylint: disable=no-member
builder.bands.pw.metadata, # pylint: disable=no-member
):
assert subspace['options']['queue_name'] == queue_name, subspace
6 changes: 5 additions & 1 deletion tests/workflows/protocols/pw/test_bands/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@ bands:
withmpi: true
parameters:
CONTROL:
calculation: scf
calculation: bands
etot_conv_thr: 2.0e-05
forc_conv_thr: 0.0001
restart_mode: from_scratch
tprnfor: true
tstress: true
ELECTRONS:
conv_thr: 4.0e-10
diago_full_acc: true
diagonalization: paro
electron_maxstep: 80
mixing_beta: 0.4
startingpot: file
SYSTEM:
degauss: 0.01
ecutrho: 240.0
Expand Down
2 changes: 0 additions & 2 deletions tests/workflows/protocols/pw/test_relax/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ base_final_scf:
num_machines: 1
withmpi: true
parameters:
CELL:
press_conv_thr: 0.5
CONTROL:
calculation: scf
etot_conv_thr: 2.0e-05
Expand Down
10 changes: 3 additions & 7 deletions tests/workflows/pw/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,11 @@ def test_set_max_seconds(generate_workchain_pw):


@pytest.mark.parametrize('restart_mode, expected', (
(None, 'restart'),
('restart', 'restart'),
('from_scratch', 'from_scratch'),
))
def test_parent_folder(generate_workchain_pw, generate_calc_job_node, restart_mode, expected):
"""Test that ``parameters`` gets automatically updated if ``parent_folder`` in the inputs.

Specifically, the ``parameters`` should define the ``CONTROL.restart_mode`` unless it was explicitly set to
``from_scratch`` by the caller.
"""
def test_restart_mode(generate_workchain_pw, generate_calc_job_node, restart_mode, expected):
"""Test that the ``CONTROL.restart_mode`` specified by the user is always respected."""
node = generate_calc_job_node('pw', test_name='default')

inputs = generate_workchain_pw(return_inputs=True)
Expand Down