Skip to content

Commit

Permalink
PdosWorkChain: Fix constrained magnetization case
Browse files Browse the repository at this point in the history
When using constrained magnetization via the `SYSTEM.tot_magnetization` input of `pw.x`,
the calculation outputs two different fermi levels for the two spin channels:
`fermi_energy_up` and `fermi_energy_down`. In this case the `PdosWorkChain` would except
since it extracts the `fermi_energy` output in the `inspect_nscf` step of the outline.

Here the step will check if the `fermi_energy` value is in the output, and look for the
two values of the up/down spin channels in case it isn't. The fermi level of the NSCF
step is then set to the maximum of the two levels.

The `fixture_code` fixture is updated to use a query instead of `load_code`, since
using `pymark.parametrize` with the hierarchy of fixtures can apparently create multiple
instances of codes with the same label in the testing profile.

Co-authored-by: Marnik Bercx <mbercx@gmail.com>
  • Loading branch information
AndresOrtegaGuerrero and mbercx authored Dec 12, 2023
1 parent b9c7517 commit a68e1e1
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
7 changes: 6 additions & 1 deletion src/aiida_quantumespresso/workflows/pdos.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,12 @@ def inspect_nscf(self):
self.ctx.nscf_emin = workchain.outputs.output_band.get_array('bands').min()
self.ctx.nscf_emax = workchain.outputs.output_band.get_array('bands').max()
self.ctx.nscf_parent_folder = workchain.outputs.remote_folder
self.ctx.nscf_fermi = workchain.outputs.output_parameters.dict.fermi_energy
if 'fermi_energy' in workchain.outputs.output_parameters.dict:
self.ctx.nscf_fermi = workchain.outputs.output_parameters.dict.fermi_energy
else:
fermi_energy_up = workchain.outputs.output_parameters.dict.fermi_energy_up
fermi_energy_down = workchain.outputs.output_parameters.dict.fermi_energy_down
self.ctx.nscf_fermi = max(fermi_energy_down, fermi_energy_up)

def _generate_dos_inputs(self):
"""Run DOS calculation, to generate total Densities of State."""
Expand Down
13 changes: 8 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# pylint: disable=redefined-outer-name,too-many-statements
# pylint: disable=redefined-outer-name,too-many-statements,unsubscriptable-object
"""Initialise a text database and profile for pytest."""
from collections.abc import Mapping
import io
Expand Down Expand Up @@ -51,14 +51,17 @@ def fixture_code(fixture_localhost):
"""Return an ``InstalledCode`` instance configured to run calculations of given entry point on localhost."""

def _fixture_code(entry_point_name):
from aiida.common import exceptions
from aiida.orm import InstalledCode, load_code
from aiida.orm import InstalledCode, QueryBuilder

label = f'test.{entry_point_name}'

query = QueryBuilder().append(
InstalledCode,
filters={'label': label},
)
try:
return load_code(label=label)
except exceptions.NotExistent:
return query.first()[0]
except TypeError:
return InstalledCode(
label=label,
computer=fixture_localhost,
Expand Down
18 changes: 16 additions & 2 deletions tests/workflows/test_pdos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# pylint: disable=unused-argument
# pylint: disable=unused-argument,too-many-statements
"""Tests for the `PdosWorkChain` class."""
from __future__ import absolute_import

Expand All @@ -8,6 +8,7 @@
from aiida.engine.utils import instantiate_process
from aiida.manage.manager import get_manager
from plumpy import ProcessState
import pytest

from aiida_quantumespresso.calculations.helpers import pw_input_helper

Expand All @@ -26,6 +27,17 @@ def instantiate_process_cls(process_cls, inputs):
return instantiate_process(runner, process_cls, **inputs)


@pytest.mark.parametrize(
'nscf_output_parameters', [
{
'fermi_energy': 6.9
},
{
'fermi_energy_down': 5.9,
'fermi_energy_up': 6.9
},
]
)
def test_default(
generate_workchain_pdos,
generate_workchain_pw,
Expand All @@ -35,6 +47,7 @@ def test_default(
generate_calc_job_node,
fixture_sandbox,
generate_bands_data,
nscf_output_parameters,
):
"""Test instantiating the WorkChain, then mock its process, by calling methods in the ``spec.outline``."""

Expand Down Expand Up @@ -76,7 +89,7 @@ def test_default(
remote.store()
remote.base.links.add_incoming(mock_wknode, link_type=LinkType.RETURN, link_label='remote_folder')

result = orm.Dict({'fermi_energy': 6.9029595890428})
result = orm.Dict(nscf_output_parameters)
result.store()
result.base.links.add_incoming(mock_wknode, link_type=LinkType.RETURN, link_label='output_parameters')

Expand All @@ -87,6 +100,7 @@ def test_default(
wkchain.ctx.workchain_nscf = mock_wknode

assert wkchain.inspect_nscf() is None
assert 'nscf_fermi' in wkchain.ctx

# mock run dos and projwfc, and check that their inputs are acceptable
dos_inputs, projwfc_inputs = wkchain.run_pdos_parallel()
Expand Down

0 comments on commit a68e1e1

Please sign in to comment.