diff --git a/aiidalab_qe/node_view.py b/aiidalab_qe/node_view.py index 07b58535d..411729e8b 100644 --- a/aiidalab_qe/node_view.py +++ b/aiidalab_qe/node_view.py @@ -17,7 +17,7 @@ import traitlets from aiida.cmdline.utils.common import get_workchain_report from aiida.common import LinkType -from aiida.orm import CalcJobNode, Node, WorkChainNode +from aiida.orm import CalcJobNode, Node, ProjectionData, WorkChainNode from aiidalab_widgets_base import ProcessMonitor, register_viewer_widget from aiidalab_widgets_base.viewers import StructureDataViewer from ase import Atoms @@ -124,27 +124,44 @@ def export_bands_data(work_chain_node): def export_pdos_data(work_chain_node): if "dos" in work_chain_node.outputs: fermi_energy = work_chain_node.outputs.nscf_parameters["fermi_energy"] - x_label, energy_dos, energy_units = work_chain_node.outputs.dos.get_x() + _, energy_dos, energy_units = work_chain_node.outputs.dos.get_x() tdos_values = { f"{n} | {u}": v for n, v, u in work_chain_node.outputs.dos.get_y() } pdos_orbitals = [] - for orbital, pdos, energy in work_chain_node.outputs.projections.get_pdos(): - orbital_data = orbital.get_orbital_dict() - kind_name = orbital_data["kind_name"] - orbital_name = orbital.get_name_from_quantum_numbers( - orbital_data["angular_momentum"], orbital_data["magnetic_number"] - ) - pdos_orbitals.append( - { - "kind": kind_name, - "orbital": orbital_name, - "energy | eV": energy, - "pdos | states/eV": pdos, - } - ) + if "projections" in work_chain_node.outputs: + projection_list = [ + (work_chain_node.outputs.projections, None), + ] + else: + projection_list = [ + (work_chain_node.outputs.projections_up, "up"), + (work_chain_node.outputs.projections_down, "dn"), + ] + tdos_values["dos | states/eV"] = tdos_values.pop( + "dos_spin_up | states/eV" + ) + tdos_values.pop("dos_spin_down | states/eV") + + for projections, suffix in projection_list: # type: ProjectionData, str + for orbital, pdos, energy in projections.get_pdos(): + orbital_data = orbital.get_orbital_dict() + kind_name = orbital_data["kind_name"] + orbital_name = orbital.get_name_from_quantum_numbers( + orbital_data["angular_momentum"], orbital_data["magnetic_number"] + ) + if suffix is not None: + orbital_name += f"-{suffix}" + + pdos_orbitals.append( + { + "kind": kind_name, + "orbital": orbital_name, + "energy | eV": energy, + "pdos | states/eV": pdos, + } + ) data_dict = { "fermi_energy": fermi_energy, @@ -155,19 +172,6 @@ def export_pdos_data(work_chain_node): # And this is why we shouldn't use special encoders... return json.loads(json.dumps(data_dict, cls=MontyEncoder)) - fermi_energy = work_chain_node.outputs.nscf__output_parameters.get_dict()[ - "fermi_energy" - ] - ( - xlabel, - energy_dos, - energy_units, - ) = work_chain_node.outputs.dos__output_dos.get_x() - tdos_values = { - f"{n} | {u}": v - for n, v, u in work_chain_node.outputs.dos__output_dos.get_y() - } - def export_data(work_chain_node): return dict( diff --git a/src/aiidalab_qe_workchain/__init__.py b/src/aiidalab_qe_workchain/__init__.py index 5a33f96f4..fa669668e 100644 --- a/src/aiidalab_qe_workchain/__init__.py +++ b/src/aiidalab_qe_workchain/__init__.py @@ -86,6 +86,8 @@ def define(cls, spec): spec.output('nscf_parameters', valid_type=Dict, required=False) spec.output('dos', valid_type=XyData, required=False) spec.output('projections', valid_type=Orbital, required=False) + spec.output('projections_up', valid_type=Orbital, required=False) + spec.output('projections_down', valid_type=Orbital, required=False) # yapf: enable @classmethod @@ -347,15 +349,26 @@ def results(self): "band_parameters", self.ctx.workchain_bands.outputs.band_parameters ) self.out("band_structure", self.ctx.workchain_bands.outputs.band_structure) + if "workchain_pdos" in self.ctx: self.out( "nscf_parameters", - self.ctx.workchain_pdos.outputs.nscf__output_parameters, - ) - self.out("dos", self.ctx.workchain_pdos.outputs.dos__output_dos) - self.out( - "projections", self.ctx.workchain_pdos.outputs.projwfc__projections + self.ctx.workchain_pdos.outputs.nscf.output_parameters, ) + self.out("dos", self.ctx.workchain_pdos.outputs.dos.output_dos) + if "projections_up" in self.ctx.workchain_pdos.outputs.projwfc: + self.out( + "projections_up", + self.ctx.workchain_pdos.outputs.projwfc.projections_up, + ) + self.out( + "projections_down", + self.ctx.workchain_pdos.outputs.projwfc.projections_down, + ) + else: + self.out( + "projections", self.ctx.workchain_pdos.outputs.projwfc.projections + ) def on_terminated(self): """Clean the working directories of all child calculations if `clean_workdir=True` in the inputs."""