diff --git a/aiidalab_qe/app/__init__.py b/aiidalab_qe/app/__init__.py deleted file mode 100644 index 45353d73..00000000 --- a/aiidalab_qe/app/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Package for the AiiDAlab QE app.""" - -from .process import WorkChainSelector -from .steps import SubmitQeAppWorkChainStep, ViewQeAppWorkChainStatusAndResultsStep -from .structures import StructureSelectionStep - -__all__ = [ - "StructureSelectionStep", - "SubmitQeAppWorkChainStep", - "ViewQeAppWorkChainStatusAndResultsStep", - "WorkChainSelector", -] diff --git a/aiidalab_qe/app/steps.py b/aiidalab_qe/app/steps.py deleted file mode 100644 index 14e2cfcb..00000000 --- a/aiidalab_qe/app/steps.py +++ /dev/null @@ -1,1401 +0,0 @@ -# -*- coding: utf-8 -*- -"""Widgets for the submission of bands work chains. - -Authors: AiiDAlab team -""" -from __future__ import annotations - -import os -import typing as t -from dataclasses import dataclass - -import ipywidgets as ipw -import traitlets -from aiida.common import NotExistent -from aiida.engine import ProcessBuilderNamespace, ProcessState, submit -from aiida.orm import WorkChainNode, load_code, load_node -from aiida.plugins import DataFactory -from aiida_quantumespresso.common.types import ElectronicType, RelaxType, SpinType -from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain -from aiidalab_widgets_base import ( - AiidaNodeViewWidget, - ComputationalResourcesWidget, - ProcessMonitor, - ProcessNodesTreeWidget, - WizardAppWidgetStep, -) -from IPython.display import clear_output, display - -from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS -from aiidalab_qe.app.pseudos import PseudoFamilySelector -from aiidalab_qe.app.setup_codes import QESetupWidget -from aiidalab_qe.app.sssp import SSSPInstallWidget -from aiidalab_qe.app.widgets import ParallelizationSettings, ResourceSelectionWidget -from aiidalab_qe.workflows import QeAppWorkChain - -StructureData = DataFactory("core.structure") -Float = DataFactory("core.float") -Dict = DataFactory("core.dict") -Str = DataFactory("core.str") - -PROTOCOL_PSEUDO_MAP = { - "fast": "SSSP/1.2/PBE/efficiency", - "moderate": "SSSP/1.2/PBE/efficiency", - "precise": "SSSP/1.2/PBE/precision", -} - - -# The static input parameters for the QE App WorkChain -# The dataclass does not include codes and structure which will be set -# from widgets separately. -# Relax type, electronic type, spin type, are str because they are used also -# for serialized input of extras attributes of the workchain -@dataclass(frozen=True) -class QeWorkChainParameters: - protocol: str - relax_type: str - properties: t.List[str] - spin_type: str - electronic_type: str - overrides: t.Dict[str, t.Any] - initial_magnetic_moments: t.Dict[str, float] - - -class WorkChainSettings(ipw.VBox): - structure_title = ipw.HTML( - """
-

Structure

""" - ) - structure_help = ipw.HTML( - """
- You have three options:
- (1) Structure as is: perform a self consistent calculation using the structure provided as input.
- (2) Atomic positions: perform a full relaxation of the internal atomic coordinates.
- (3) Full geometry: perform a full relaxation for both the internal atomic coordinates and the cell vectors.
""" - ) - materials_help = ipw.HTML( - """
- Below you can indicate both if the material should be treated as an insulator - or a metal (if in doubt, choose "Metal"), - and if it should be studied with magnetization/spin polarization, - switch magnetism On or Off (On is at least twice more costly). -
""" - ) - - properties_title = ipw.HTML( - """
-

Properties

""" - ) - properties_help = ipw.HTML( - """
- The band structure workflow will - automatically detect the default path in reciprocal space using the - - SeeK-path tool.
""" - ) - - protocol_title = ipw.HTML( - """
-

Protocol

""" - ) - protocol_help = ipw.HTML( - """
- The "moderate" protocol represents a trade-off between - accuracy and speed. Choose the "fast" protocol for a faster calculation - with less precision and the "precise" protocol to aim at best accuracy (at the price of longer/costlier calculations).
""" - ) - - def __init__(self, **kwargs): - # RelaxType: degrees of freedom in geometry optimization - self.relax_type = ipw.ToggleButtons( - options=[ - ("Structure as is", "none"), - ("Atomic positions", "positions"), - ("Full geometry", "positions_cell"), - ], - value="positions_cell", - ) - - # SpinType: magnetic properties of material - self.spin_type = ipw.ToggleButtons( - options=[("Off", "none"), ("On", "collinear")], - value=DEFAULT_PARAMETERS["spin_type"], - style={"description_width": "initial"}, - ) - - # ElectronicType: electronic properties of material - self.electronic_type = ipw.ToggleButtons( - options=[("Metal", "metal"), ("Insulator", "insulator")], - value=DEFAULT_PARAMETERS["electronic_type"], - style={"description_width": "initial"}, - ) - - # Checkbox to see if the band structure should be calculated - self.bands_run = ipw.Checkbox( - description="", - indent=False, - value=True, - layout=ipw.Layout(max_width="10%"), - ) - - # Checkbox to see if the PDOS should be calculated - self.pdos_run = ipw.Checkbox( - description="", - indent=False, - value=True, - layout=ipw.Layout(max_width="10%"), - ) - - # Work chain protocol - self.workchain_protocol = ipw.ToggleButtons( - options=["fast", "moderate", "precise"], - value="moderate", - ) - super().__init__( - children=[ - self.structure_title, - self.structure_help, - self.relax_type, - self.materials_help, - ipw.HBox( - children=[ - ipw.Label( - "Electronic Type:", - layout=ipw.Layout( - justify_content="flex-start", width="120px" - ), - ), - self.electronic_type, - ] - ), - ipw.HBox( - children=[ - ipw.Label( - "Magnetism:", - layout=ipw.Layout( - justify_content="flex-start", width="120px" - ), - ), - self.spin_type, - ] - ), - self.properties_title, - ipw.HTML("Select which properties to calculate:"), - ipw.HBox(children=[ipw.HTML("Band structure"), self.bands_run]), - ipw.HBox( - children=[ - ipw.HTML("Projected density of states"), - self.pdos_run, - ] - ), - self.properties_help, - self.protocol_title, - ipw.HTML("Select the protocol:", layout=ipw.Layout(flex="1 1 auto")), - self.workchain_protocol, - self.protocol_help, - ], - **kwargs, - ) - - def _update_settings(self, **kwargs): - """Update the settings based on the given dict.""" - for key in [ - "relax_type", - "spin_type", - "electronic_type", - "bands_run", - "pdos_run", - "workchain_protocol", - ]: - if key in kwargs: - getattr(self, key).value = kwargs[key] - - -class AdvancedSettings(ipw.VBox): - title = ipw.HTML( - """
-

Advanced Settings

""" - ) - description = ipw.HTML("""Select the advanced settings for the pw.x code.""") - - def __init__(self, **kwargs): - self.override = ipw.Checkbox( - description="Override", - indent=False, - value=False, - ) - self.smearing = SmearingSettings() - self.kpoints = KpointSettings() - self.tot_charge = TotalCharge() - self.magnetization = MagnetizationSettings() - self.list_overrides = [ - self.smearing.override, - self.kpoints.override, - self.tot_charge.override, - self.magnetization.override, - ] - for override in self.list_overrides: - ipw.dlink( - (self.override, "value"), - (override, "disabled"), - lambda override: not override, - ) - self.override.observe(self.set_advanced_settings, "value") - super().__init__( - children=[ - self.title, - ipw.HBox( - [ - self.description, - self.override, - ], - ), - self.tot_charge, - self.magnetization, - self.smearing, - self.kpoints, - ], - layout=ipw.Layout(justify_content="space-between"), - **kwargs, - ) - - def set_advanced_settings(self, _=None): - self.smearing.reset() - self.kpoints.reset() - self.tot_charge.reset() - self.magnetization.reset() - - -class TotalCharge(ipw.VBox): - """Widget to define the total charge of the simulation""" - - tot_charge_default = traitlets.Float(default_value=0.0) - - def __init__(self, **kwargs): - self.override = ipw.Checkbox( - description="Override", - indent=False, - value=False, - ) - self.charge = ipw.BoundedFloatText( - value=0, - min=-3, - max=3, - step=0.01, - disabled=False, - description="Total charge:", - style={"description_width": "initial"}, - ) - ipw.dlink( - (self.override, "value"), - (self.charge, "disabled"), - lambda override: not override, - ) - super().__init__( - children=[ - ipw.HBox( - [ - self.override, - self.charge, - ], - ), - ], - layout=ipw.Layout(justify_content="space-between"), - **kwargs, - ) - self.charge.observe(self.set_tot_charge, "value") - self.override.observe(self.set_tot_charge, "value") - - def set_tot_charge(self, _=None): - self.charge.value = ( - self.charge.value if self.override.value else self.tot_charge_default - ) - - def _update_settings(self, **kwargs): - """Update the override and override_tot_charge and override_tot_charge values by the given keyword arguments - Therefore the override checkbox is not updated and defaults to True""" - self.override.value = True - with self.hold_trait_notifications(): - if "tot_charge" in kwargs: - self.charge.value = kwargs["tot_charge"] - - def reset(self): - with self.hold_trait_notifications(): - self.charge.value = self.tot_charge_default - self.override.value = False - - -class MagnetizationSettings(ipw.VBox): - """Widget to set the initial magnetic moments for each kind names defined in the StructureData (StructureDtaa.get_kind_names()) - Usually these are the names of the elements in the StructureData - (For example 'C' , 'N' , 'Fe' . However the StructureData can have defined kinds like 'Fe1' and 'Fe2') - - The widget generate a dictionary that can be used to set initial_magnetic_moments in the builder of PwBaseWorkChain - - Attributes: - input_structure(StructureData): trait that containes the input_strucgure (confirmed structure from previous step) - """ - - input_structure = traitlets.Instance(StructureData, allow_none=True) - - def __init__(self, **kwargs): - self.input_structure = StructureData() - self.input_structure_labels = [] - self.description = ipw.HTML( - "Define magnetization: Input structure not confirmed" - ) - self.kinds = self.create_kinds_widget() - self.kinds_widget_out = ipw.Output() - self.override = ipw.Checkbox( - description="Override", - indent=False, - value=False, - ) - super().__init__( - children=[ - ipw.HBox( - [ - self.override, - self.description, - self.kinds_widget_out, - ], - ), - ], - layout=ipw.Layout(justify_content="space-between"), - **kwargs, - ) - self.display_kinds() - self.override.observe(self._disable_kinds_widgets, "value") - - def _disable_kinds_widgets(self, _=None): - for i in range(len(self.kinds.children)): - self.kinds.children[i].disabled = not self.override.value - - def reset(self): - self.override.value = False - if hasattr(self.kinds, "children") and self.kinds.children: - for i in range(len(self.kinds.children)): - self.kinds.children[i].value = 0.0 - - def create_kinds_widget(self): - if self.input_structure_labels: - widgets_list = [] - for kind_label in self.input_structure_labels: - kind_widget = ipw.BoundedFloatText( - description=kind_label, - min=-1, - max=1, - step=0.1, - value=0.0, - disabled=True, - ) - widgets_list.append(kind_widget) - kinds_widget = ipw.VBox(widgets_list) - else: - kinds_widget = None - - return kinds_widget - - def update_kinds_widget(self): - self.input_structure_labels = self.input_structure.get_kind_names() - self.kinds = self.create_kinds_widget() - self.description.value = "Define magnetization: " - self.display_kinds() - - def display_kinds(self): - if "PYTEST_CURRENT_TEST" not in os.environ and self.kinds: - with self.kinds_widget_out: - clear_output() - display(self.kinds) - - def _update_widget(self, change): - self.input_structure = change["new"] - self.update_kinds_widget() - - def get_magnetization(self): - """Method to generate the dictionary with the initial magnetic moments""" - magnetization = {} - for i in range(len(self.kinds.children)): - magnetization[self.input_structure_labels[i]] = self.kinds.children[i].value - return magnetization - - def _set_magnetization_values(self, **kwargs): - """Update used for conftest setting all magnetization to a value""" - self.override.value = True - with self.hold_trait_notifications(): - if "initial_magnetic_moments" in kwargs: - for i in range(len(self.kinds.children)): - self.kinds.children[i].value = kwargs["initial_magnetic_moments"] - - -class SmearingSettings(ipw.VBox): - smearing_description = ipw.HTML( - """

- The smearing type and width is set by the chosen protocol. - Tick the box to override the default, not advised unless you've mastered smearing effects (click here for a discussion). -

""" - ) - - # The default of `smearing` and `degauss` the type and width - # must be linked to the `protocol` - degauss_default = traitlets.Float(default_value=0.01) - smearing_default = traitlets.Unicode(default_value="cold") - - def __init__(self, **kwargs): - self.override = ipw.Checkbox( - description="Override", - indent=False, - value=False, - ) - self.smearing = ipw.Dropdown( - options=["cold", "gaussian", "fermi-dirac", "methfessel-paxton"], - value=self.smearing_default, - description="Smearing type:", - disabled=False, - style={"description_width": "initial"}, - ) - self.degauss = ipw.FloatText( - value=self.degauss_default, - step=0.005, - description="Smearing width (Ry):", - disabled=False, - style={"description_width": "initial"}, - ) - ipw.dlink( - (self.override, "value"), - (self.degauss, "disabled"), - lambda override: not override, - ) - ipw.dlink( - (self.override, "value"), - (self.smearing, "disabled"), - lambda override: not override, - ) - self.degauss.observe(self.set_smearing, "value") - self.smearing.observe(self.set_smearing, "value") - self.override.observe(self.set_smearing, "value") - - super().__init__( - children=[ - self.smearing_description, - ipw.HBox([self.override, self.smearing, self.degauss]), - ], - layout=ipw.Layout(justify_content="space-between"), - **kwargs, - ) - - def set_smearing(self, _=None): - self.degauss.value = ( - self.degauss.value if self.override.value else self.degauss_default - ) - self.smearing.value = ( - self.smearing.value if self.override.value else self.smearing_default - ) - - def _update_settings(self, **kwargs): - """Update the smearing and degauss values by the given keyword arguments - This is the same as the `set_smearing` method but without the observer. - Therefore the override checkbox is not updated and defaults to True""" - self.override.value = True - - with self.hold_trait_notifications(): - if "smearing" in kwargs: - self.smearing.value = kwargs["smearing"] - - if "degauss" in kwargs: - self.degauss.value = kwargs["degauss"] - - def reset(self): - with self.hold_trait_notifications(): - self.degauss.value = self.degauss_default - self.smearing.value = self.smearing_default - self.override.value = False - - -class KpointSettings(ipw.VBox): - kpoints_distance_description = ipw.HTML( - """
- The k-points mesh density of the SCF calculation is set by the protocol. - The value below represents the maximum distance between the k-points in each direction of reciprocal space. - Tick the box to override the default, smaller is more accurate and costly.
""" - ) - - # The default of `kpoints_distance` must be linked to the `protocol` - kpoints_distance_default = traitlets.Float(default_value=0.15) - - def __init__(self, **kwargs): - self.override = ipw.Checkbox( - description="Override", - indent=False, - value=False, - ) - self.distance = ipw.FloatText( - value=self.kpoints_distance_default, - step=0.05, - description="K-points distance (1/Å):", - disabled=False, - style={"description_width": "initial"}, - ) - ipw.dlink( - (self.override, "value"), - (self.distance, "disabled"), - lambda override: not override, - ) - self.distance.observe(self.set_kpoints_distance, "value") - self.override.observe(self.set_kpoints_distance, "value") - self.observe(self.set_kpoints_distance, "kpoints_distance_default") - - super().__init__( - children=[ - self.kpoints_distance_description, - ipw.HBox([self.override, self.distance]), - ], - layout=ipw.Layout(justify_content="space-between"), - **kwargs, - ) - - def set_kpoints_distance(self, _=None): - self.distance.value = ( - self.distance.value - if self.override.value - else self.kpoints_distance_default - ) - - def _update_settings(self, **kwargs): - """Update the kpoints_distance value by the given keyword arguments. - This is the same as the `set_kpoints_distance` method but without the observer. - """ - self.override.value = True - if "kpoints_distance" in kwargs: - self.distance.value = kwargs["kpoints_distance"] - - def reset(self): - with self.hold_trait_notifications(): - self.distance.value = self.kpoints_distance_default - self.override.value = False - - -class ConfigureQeAppWorkChainStep(ipw.VBox, WizardAppWidgetStep): - confirmed = traitlets.Bool() - previous_step_state = traitlets.UseEnum(WizardAppWidgetStep.State) - workchain_settings = traitlets.Instance(WorkChainSettings, allow_none=True) - pseudo_family_selector = traitlets.Instance(PseudoFamilySelector, allow_none=True) - advanced_settings = traitlets.Instance(AdvancedSettings, allow_none=True) - input_structure = traitlets.Instance(StructureData, allow_none=True) - - def __init__(self, **kwargs): - self.workchain_settings = WorkChainSettings() - self.workchain_settings.relax_type.observe(self._update_state, "value") - self.workchain_settings.bands_run.observe(self._update_state, "value") - self.workchain_settings.pdos_run.observe(self._update_state, "value") - - self.pseudo_family_selector = PseudoFamilySelector() - self.advanced_settings = AdvancedSettings() - - ipw.dlink( - (self.workchain_settings.workchain_protocol, "value"), - (self.advanced_settings.kpoints, "kpoints_distance_default"), - lambda protocol: PwBaseWorkChain.get_protocol_inputs(protocol)[ - "kpoints_distance" - ], - ) - - ipw.dlink( - (self.workchain_settings.workchain_protocol, "value"), - (self.advanced_settings.smearing, "degauss_default"), - lambda protocol: PwBaseWorkChain.get_protocol_inputs(protocol)["pw"][ - "parameters" - ]["SYSTEM"]["degauss"], - ) - - ipw.dlink( - (self.workchain_settings.workchain_protocol, "value"), - (self.advanced_settings.smearing, "smearing_default"), - lambda protocol: PwBaseWorkChain.get_protocol_inputs(protocol)["pw"][ - "parameters" - ]["SYSTEM"]["smearing"], - ) - - self.tab = ipw.Tab( - children=[ - self.workchain_settings, - ipw.VBox( - children=[ - self.advanced_settings, - self.pseudo_family_selector, - ] - ), - ], - layout=ipw.Layout(min_height="250px"), - ) - - self.tab.set_title(0, "Workflow") - self.tab.set_title(1, "Advanced settings") - - self._submission_blocker_messages = ipw.HTML() - - self.confirm_button = ipw.Button( - description="Confirm", - tooltip="Confirm the currently selected settings and go to the next step.", - button_style="success", - icon="check-circle", - disabled=True, - layout=ipw.Layout(width="auto"), - ) - - self.confirm_button.on_click(self.confirm) - - super().__init__( - children=[ - self.tab, - self._submission_blocker_messages, - self.confirm_button, - ], - **kwargs, - ) - - @traitlets.observe("input_structure") - def _update_input_structure(self, change): - if self.input_structure is not None: - self.advanced_settings.magnetization._update_widget(change) - - @traitlets.observe("previous_step_state") - def _observe_previous_step_state(self, change): - self._update_state() - - def set_input_parameters(self, parameters): - """Set the inputs in the GUI based on a set of parameters.""" - - with self.hold_trait_notifications(): - # Work chain settings - self.workchain_settings.relax_type.value = parameters["relax_type"] - self.workchain_settings.spin_type.value = parameters["spin_type"] - self.workchain_settings.electronic_type.value = parameters[ - "electronic_type" - ] - self.workchain_settings.bands_run.value = parameters["run_bands"] - self.workchain_settings.pdos_run.value = parameters["run_pdos"] - self.workchain_settings.workchain_protocol.value = parameters["protocol"] - - # Advanced settings - self.pseudo_family_selector.value = parameters["pseudo_family"] - if parameters.get("kpoints_distance_override", None) is not None: - self.advanced_settings.kpoints.distance.value = parameters[ - "kpoints_distance_override" - ] - self.advanced_settings.kpoints.override.value = True - if parameters.get("degauss_override", None) is not None: - self.advanced_settings.smearing.degauss.value = parameters[ - "degauss_override" - ] - self.advanced_settings.smearing.override.value = True - if parameters.get("smearing_override", None) is not None: - self.advanced_settings.smearing.smearing.value = parameters[ - "smearing_override" - ] - self.advanced_settings.smearing.override.value = True - - def _update_state(self, _=None): - if self.previous_step_state == self.State.SUCCESS: - self.confirm_button.disabled = False - self._submission_blocker_messages.value = "" - self.state = self.State.CONFIGURED - elif self.previous_step_state == self.State.FAIL: - self.state = self.State.FAIL - else: - self.confirm_button.disabled = True - self.state = self.State.INIT - self.set_input_parameters(DEFAULT_PARAMETERS) - - def confirm(self, _=None): - self.confirm_button.disabled = False - self.state = self.State.SUCCESS - - @traitlets.default("state") - def _default_state(self): - return self.State.INIT - - def reset(self): - with self.hold_trait_notifications(): - self.set_input_parameters(DEFAULT_PARAMETERS) - - -class SubmitQeAppWorkChainStep(ipw.VBox, WizardAppWidgetStep): - """Step for submission of a bands workchain.""" - - codes_title = ipw.HTML( - """
-

Codes

""" - ) - codes_help = ipw.HTML( - """
Select the code to use for running the calculations. The codes - on the local machine (localhost) are installed by default, but you can - configure new ones on potentially more powerful machines by clicking on - "Setup new code".
""" - ) - - # This number provides a rough estimate for how many MPI tasks are needed - # for a given structure. - NUM_SITES_PER_MPI_TASK_DEFAULT = 6 - - # Warn the user if they are trying to run calculations for a large - # structure on localhost. - RUN_ON_LOCALHOST_NUM_SITES_WARN_THRESHOLD = 10 - - # Put a limit on how many MPI tasks you want to run per k-pool by default - MAX_MPI_PER_POOL = 20 - - input_structure = traitlets.Instance(StructureData, allow_none=True) - process = traitlets.Instance(WorkChainNode, allow_none=True) - previous_step_state = traitlets.UseEnum(WizardAppWidgetStep.State) - workchain_settings = traitlets.Instance(WorkChainSettings, allow_none=True) - pseudo_family_selector = traitlets.Instance(PseudoFamilySelector, allow_none=True) - advanced_settings = traitlets.Instance(AdvancedSettings, allow_none=True) - _submission_blockers = traitlets.List(traitlets.Unicode()) - - def __init__(self, qe_auto_setup=True, **kwargs): - self.message_area = ipw.Output() - self._submission_blocker_messages = ipw.HTML() - - self.pw_code = ComputationalResourcesWidget( - description="pw.x:", default_calc_job_plugin="quantumespresso.pw" - ) - self.dos_code = ComputationalResourcesWidget( - description="dos.x:", - default_calc_job_plugin="quantumespresso.dos", - ) - self.projwfc_code = ComputationalResourcesWidget( - description="projwfc.x:", - default_calc_job_plugin="quantumespresso.projwfc", - ) - - self.resources_config = ResourceSelectionWidget() - self.parallelization = ParallelizationSettings() - - self.set_selected_codes(DEFAULT_PARAMETERS) - self.set_resource_defaults() - - self.pw_code.observe(self._update_state, "value") - self.pw_code.observe(self._update_resources, "value") - self.dos_code.observe(self._update_state, "value") - self.projwfc_code.observe(self._update_state, "value") - - self.submit_button = ipw.Button( - description="Submit", - tooltip="Submit the calculation with the selected parameters.", - icon="play", - button_style="success", - layout=ipw.Layout(width="auto", flex="1 1 auto"), - disabled=True, - ) - - self.submit_button.on_click(self._on_submit_button_clicked) - - # The SSSP installation status widget shows the installation status of - # the SSSP pseudo potentials and triggers the installation in case that - # they are not yet installed. The widget will remain in a "busy" state - # in case that the installation was already triggered elsewhere, e.g., - # by the start up scripts. The submission is blocked while the - # potentials are not yet installed. - self.sssp_installation_status = SSSPInstallWidget(auto_start=qe_auto_setup) - self.sssp_installation_status.observe(self._update_state, ["busy", "installed"]) - self.sssp_installation_status.observe(self._toggle_install_widgets, "installed") - - # The QE setup widget checks whether there are codes that match specific - # expected labels (e.g. "pw-7.2@localhost") and triggers both the - # installation of QE into a dedicated conda environment and the setup of - # the codes in case that they are not already configured. - self.qe_setup_status = QESetupWidget(auto_start=qe_auto_setup) - self.qe_setup_status.observe(self._update_state, "busy") - self.qe_setup_status.observe(self._toggle_install_widgets, "installed") - self.qe_setup_status.observe(self._auto_select_code, "installed") - - super().__init__( - children=[ - self.codes_title, - self.codes_help, - self.pw_code, - self.dos_code, - self.projwfc_code, - self.resources_config, - self.parallelization, - self.message_area, - self.sssp_installation_status, - self.qe_setup_status, - self._submission_blocker_messages, - self.submit_button, - ] - ) - - @traitlets.observe("_submission_blockers") - def _observe_submission_blockers(self, change): - if change["new"]: - fmt_list = "\n".join((f"
  • {item}
  • " for item in sorted(change["new"]))) - self._submission_blocker_messages.value = f""" -
    - The submission is blocked, due to the following reason(s): -
    """ - else: - self._submission_blocker_messages.value = "" - - def _identify_submission_blockers(self): - # Do not submit while any of the background setup processes are running. - if self.qe_setup_status.busy or self.sssp_installation_status.busy: - yield "Background setup processes must finish." - - # No code selected (this is ignored while the setup process is running). - if self.pw_code.value is None and not self.qe_setup_status.busy: - yield ("No pw code selected") - - # No code selected for pdos (this is ignored while the setup process is running). - if ( - self.workchain_settings.pdos_run.value - and (self.dos_code.value is None or self.projwfc_code.value is None) - and not self.qe_setup_status.busy - ): - yield "Calculating the PDOS requires both dos.x and projwfc.x to be set." - - # SSSP library not installed - if not self.sssp_installation_status.installed: - yield "The SSSP library is not installed." - - if ( - self.workchain_settings.pdos_run.value - and not any( - [ - self.pw_code.value is None, - self.dos_code.value is None, - self.projwfc_code.value is None, - ] - ) - and len( - set( - ( - load_code(self.pw_code.value).computer.pk, - load_code(self.dos_code.value).computer.pk, - load_code(self.projwfc_code.value).computer.pk, - ) - ) - ) - != 1 - ): - yield ( - "All selected codes must be installed on the same computer. This is because the " - "PDOS calculations rely on large files that are not retrieved by AiiDA." - ) - - def _update_state(self, _=None): - # If the previous step has failed, this should fail as well. - if self.previous_step_state is self.State.FAIL: - self.state = self.State.FAIL - return - # Do not interact with the user if they haven't successfully completed the previous step. - elif self.previous_step_state is not self.State.SUCCESS: - self.state = self.State.INIT - return - - # Process is already running. - if self.process is not None: - self.state = self.State.SUCCESS - return - - blockers = list(self._identify_submission_blockers()) - if any(blockers): - self._submission_blockers = blockers - self.state = self.State.READY - return - - self._submission_blockers = [] - self.state = self.state.CONFIGURED - - def _toggle_install_widgets(self, change): - if change["new"]: - self.children = [ - child for child in self.children if child is not change["owner"] - ] - - def _auto_select_code(self, change): - if change["new"] and not change["old"]: - for code in [ - "pw_code", - "dos_code", - "projwfc_code", - ]: - try: - code_widget = getattr(self, code) - code_widget.refresh() - code_widget.value = load_code(DEFAULT_PARAMETERS[code]).uuid - except NotExistent: - pass - - _ALERT_MESSAGE = """ -
    - × - × - {message} -
    """ - - def _show_alert_message(self, message, alert_class="info"): - with self.message_area: - display( - ipw.HTML( - self._ALERT_MESSAGE.format(alert_class=alert_class, message=message) - ) - ) - - def _update_resources(self, change): - if change["new"] and ( - change["old"] is None - or load_code(change["new"]).computer.pk - != load_code(change["old"]).computer.pk - ): - self.set_resource_defaults(load_code(change["new"]).computer) - - def set_resource_defaults(self, computer=None): - if computer is None or computer.hostname == "localhost": - self.resources_config.num_nodes.disabled = True - self.resources_config.num_nodes.value = 1 - self.resources_config.num_cpus.max = os.cpu_count() - self.resources_config.num_cpus.value = 1 - self.resources_config.num_cpus.description = "CPUs" - self.parallelization.npools.value = 1 - else: - default_mpiprocs = computer.get_default_mpiprocs_per_machine() - self.resources_config.num_nodes.disabled = False - self.resources_config.num_cpus.max = default_mpiprocs - self.resources_config.num_cpus.value = default_mpiprocs - self.resources_config.num_cpus.description = "CPUs/node" - self.parallelization.npools.value = self._get_default_parallelization() - - self._check_resources() - - def _get_default_parallelization(self): - """A _very_ rudimentary approach for obtaining a minimal npools setting.""" - num_mpiprocs = ( - self.resources_config.num_nodes.value * self.resources_config.num_cpus.value - ) - - for i in range(1, num_mpiprocs + 1): - if num_mpiprocs % i == 0 and num_mpiprocs // i < self.MAX_MPI_PER_POOL: - return i - - def _check_resources(self): - """Check whether the currently selected resources will be sufficient and warn if not.""" - if not self.pw_code.value: - return # No code selected, nothing to do. - - num_cpus = self.resources_config.num_cpus.value - on_localhost = load_node(self.pw_code.value).computer.hostname == "localhost" - if self.pw_code.value and on_localhost and num_cpus > 1: - self._show_alert_message( - "The selected code would be executed on the local host, but " - "the number of CPUs is larger than one. Please review " - "the configuration and consider to select a code that runs " - "on a larger system if necessary.", - alert_class="warning", - ) - elif ( - self.input_structure - and on_localhost - and len(self.input_structure.sites) - > self.RUN_ON_LOCALHOST_NUM_SITES_WARN_THRESHOLD - ): - self._show_alert_message( - "The selected code would be executed on the local host, but the " - "number of sites of the selected structure is relatively large. " - "Consider to select a code that runs on a larger system if " - "necessary.", - alert_class="warning", - ) - - @traitlets.observe("state") - def _observe_state(self, change): - with self.hold_trait_notifications(): - self.submit_button.disabled = change["new"] != self.State.CONFIGURED - - @traitlets.observe("previous_step_state") - def _observe_input_structure(self, _): - self._update_state() - self.set_pdos_status() - - @traitlets.observe("process") - def _observe_process(self, change): - with self.hold_trait_notifications(): - process_node = change["new"] - if process_node is not None: - self.input_structure = process_node.inputs.structure - builder_parameters = process_node.base.extras.get( - "builder_parameters", None - ) - if builder_parameters is not None: - self.set_selected_codes(builder_parameters) - self._update_state() - - def _on_submit_button_clicked(self, _): - self.submit_button.disabled = True - self.submit() - - def set_selected_codes(self, parameters): - """Set the inputs in the GUI based on a set of parameters.""" - - # Codes - def _get_code_uuid(code): - if code is not None: - try: - return load_code(code).uuid - except NotExistent: - return None - - with self.hold_trait_notifications(): - # Codes - self.pw_code.value = _get_code_uuid(parameters["pw_code"]) - self.dos_code.value = _get_code_uuid(parameters["dos_code"]) - self.projwfc_code.value = _get_code_uuid(parameters["projwfc_code"]) - - def set_pdos_status(self): - if self.workchain_settings.pdos_run.value: - self.dos_code.code_select_dropdown.disabled = False - self.projwfc_code.code_select_dropdown.disabled = False - else: - self.dos_code.code_select_dropdown.disabled = True - self.projwfc_code.code_select_dropdown.disabled = True - - def submit(self, _=None): - """Submit the work chain with the current inputs.""" - builder = self._create_builder() - extra_parameters = self._create_extra_report_parameters() - - with self.hold_trait_notifications(): - self.process = submit(builder) - - # Set the builder parameters on the work chain - builder_parameters = self._extract_report_parameters( - builder, extra_parameters - ) - self.process.base.extras.set("builder_parameters", builder_parameters) - - self._update_state() - - def _get_qe_workchain_parameters(self) -> QeWorkChainParameters: - """Get the parameters of the `QeWorkChain` from widgets.""" - # create the the initial_magnetic_moments as None (Default) - initial_magnetic_moments = None - # create the override parameters for sub PwBaseWorkChain - pw_overrides = {"base": {}, "scf": {}, "nscf": {}, "band": {}} - for key in ["base", "scf", "nscf", "band"]: - if self.pseudo_family_selector.override_protocol_pseudo_family.value: - pw_overrides[key]["pseudo_family"] = self.pseudo_family_selector.value - if self.advanced_settings.override.value: - pw_overrides[key]["pw"] = {"parameters": {"SYSTEM": {}}} - if self.advanced_settings.tot_charge.override.value: - pw_overrides[key]["pw"]["parameters"]["SYSTEM"][ - "tot_charge" - ] = self.advanced_settings.tot_charge.charge.value - if ( - self.advanced_settings.magnetization.override.value - and self.workchain_settings.spin_type.value == "collinear" - ): - initial_magnetic_moments = ( - self.advanced_settings.magnetization.get_magnetization() - ) - - if key in ["base", "scf"]: - if self.advanced_settings.kpoints.override.value: - pw_overrides[key][ - "kpoints_distance" - ] = self.advanced_settings.kpoints.distance.value - if ( - self.advanced_settings.smearing.override.value - and self.workchain_settings.electronic_type.value == "metal" - ): - # smearing type setting - pw_overrides[key]["pw"]["parameters"]["SYSTEM"][ - "smearing" - ] = self.advanced_settings.smearing.smearing.value - - # smearing degauss setting - pw_overrides[key]["pw"]["parameters"]["SYSTEM"][ - "degauss" - ] = self.advanced_settings.smearing.degauss.value - - overrides = { - "relax": { - "base": pw_overrides["base"], - }, - "bands": { - "scf": pw_overrides["scf"], - "bands": pw_overrides["band"], - }, - "pdos": { - "scf": pw_overrides["scf"], - "nscf": pw_overrides["nscf"], - }, - } - - # Work chain settings - relax_type = self.workchain_settings.relax_type.value - electronic_type = self.workchain_settings.electronic_type.value - spin_type = self.workchain_settings.spin_type.value - - run_bands = self.workchain_settings.bands_run.value - run_pdos = self.workchain_settings.pdos_run.value - protocol = self.workchain_settings.workchain_protocol.value - - properties = [] - - if run_bands: - properties.append("bands") - if run_pdos: - properties.append("pdos") - - if RelaxType(relax_type) is not RelaxType.NONE or not (run_bands or run_pdos): - properties.append("relax") - - return QeWorkChainParameters( - protocol=protocol, - relax_type=relax_type, - properties=properties, - spin_type=spin_type, - electronic_type=electronic_type, - overrides=overrides, - initial_magnetic_moments=initial_magnetic_moments, - ) - - def _create_builder(self) -> ProcessBuilderNamespace: - """Create the builder for the `QeAppWorkChain` submit.""" - pw_code = self.pw_code.value - dos_code = self.dos_code.value - projwfc_code = self.projwfc_code.value - - parameters = self._get_qe_workchain_parameters() - - builder = QeAppWorkChain.get_builder_from_protocol( - structure=self.input_structure, - pw_code=load_code(pw_code), - dos_code=load_code(dos_code), - projwfc_code=load_code(projwfc_code), - protocol=parameters.protocol, - relax_type=RelaxType(parameters.relax_type), - properties=parameters.properties, - spin_type=SpinType(parameters.spin_type), - electronic_type=ElectronicType(parameters.electronic_type), - overrides=parameters.overrides, - initial_magnetic_moments=parameters.initial_magnetic_moments, - ) - - resources = { - "num_machines": self.resources_config.num_nodes.value, - "num_mpiprocs_per_machine": self.resources_config.num_cpus.value, - } - - npool = self.parallelization.npools.value - self._update_builder(builder, resources, npool, self.MAX_MPI_PER_POOL) - - return builder - - def _update_builder(self, buildy, resources, npools, max_mpi_per_pool): - """Update the resources and parallelization of the ``QeAppWorkChain`` builder.""" - for k, v in buildy.items(): - if isinstance(v, (dict, ProcessBuilderNamespace)): - if k == "pw" and v["pseudos"]: - v["parallelization"] = Dict(dict={"npool": npools}) - if k == "projwfc": - v["settings"] = Dict(dict={"cmdline": ["-nk", str(npools)]}) - if k == "dos": - v["metadata"]["options"]["resources"] = { - "num_machines": 1, - "num_mpiprocs_per_machine": min( - max_mpi_per_pool, - resources["num_mpiprocs_per_machine"], - ), - } - # Continue to the next item to avoid overriding the resources in the - # recursive `update_builder` call. - continue - if k == "resources": - buildy["resources"] = resources - else: - self._update_builder(v, resources, npools, max_mpi_per_pool) - - def _create_extra_report_parameters(self) -> dict[str, t.Any]: - """This method will also create a dictionary of the parameters that were not - readably represented in the builder, which will be used to the report. - It is stored in the `extra_report_parameters`. - """ - qe_workchain_parameters = self._get_qe_workchain_parameters() - - # Construct the extra report parameters needed for the report - extra_report_parameters = { - "relax_type": qe_workchain_parameters.relax_type, - "electronic_type": qe_workchain_parameters.electronic_type, - "spin_type": qe_workchain_parameters.spin_type, - "protocol": qe_workchain_parameters.protocol, - "initial_magnetic_moments": qe_workchain_parameters.initial_magnetic_moments, - } - - # update pseudo family information to extra_report_parameters - if self.pseudo_family_selector.override_protocol_pseudo_family.value: - # If the pseudo family is overridden, use that - pseudo_family = self.pseudo_family_selector.value - else: - # otherwise extract the information from protocol - pseudo_family = PROTOCOL_PSEUDO_MAP[qe_workchain_parameters.protocol] - - pseudo_family_info = pseudo_family.split("/") - if pseudo_family_info[0] == "SSSP": - pseudo_protocol = pseudo_family_info[3] - elif pseudo_family_info[0] == "PseudoDojo": - pseudo_protocol = pseudo_family_info[4] - extra_report_parameters.update( - { - "pseudo_family": pseudo_family, - "pseudo_library": pseudo_family_info[0], - "pseudo_version": pseudo_family_info[1], - "functional": pseudo_family_info[2], - "pseudo_protocol": pseudo_protocol, - } - ) - - # store codes info into extra_report_parameters for loading the process - pw_code = self.pw_code.value - dos_code = self.dos_code.value - projwfc_code = self.projwfc_code.value - - extra_report_parameters.update( - { - "pw_code": pw_code, - "dos_code": dos_code, - "projwfc_code": projwfc_code, - } - ) - - return extra_report_parameters - - @staticmethod - def _extract_report_parameters( - builder, extra_report_parameters - ) -> dict[str, t.Any]: - """Extract (recover) the parameters for report from the builder. - - There are some parameters that are not stored in the builder, but can be extracted - directly from the widgets, such as the ``pseudo_family`` and ``relax_type``. - """ - parameters = { - "run_relax": "relax" in builder.properties, - "run_bands": "bands" in builder.properties, - "run_pdos": "pdos" in builder.properties, - } - - # Extract the pw calculation parameters from the builder - - # energy_cutoff is same for all pw calculations when pseudopotentials are fixed - # as well as the smearing settings (semaring and degauss) and scf kpoints distance - # read from the first pw calculation of relax workflow. - # It is safe then to extract these parameters from the first pw calculation, since the - # builder is anyway set with subworkchain inputs even it is not run which controlled by - # the properties inputs. - energy_cutoff_wfc = builder.relax.base["pw"]["parameters"]["SYSTEM"]["ecutwfc"] - energy_cutoff_rho = builder.relax.base["pw"]["parameters"]["SYSTEM"]["ecutrho"] - occupation = builder.relax.base["pw"]["parameters"]["SYSTEM"]["occupations"] - scf_kpoints_distance = builder.relax.base.kpoints_distance.value - - parameters.update( - { - "energy_cutoff_wfc": energy_cutoff_wfc, - "energy_cutoff_rho": energy_cutoff_rho, - "occupation": occupation, - "scf_kpoints_distance": scf_kpoints_distance, - } - ) - - if occupation == "smearing": - parameters["degauss"] = builder.relax.base["pw"]["parameters"]["SYSTEM"][ - "degauss" - ] - parameters["smearing"] = builder.relax.base["pw"]["parameters"]["SYSTEM"][ - "smearing" - ] - - parameters[ - "bands_kpoints_distance" - ] = builder.bands.bands_kpoints_distance.value - parameters["nscf_kpoints_distance"] = builder.pdos.nscf.kpoints_distance.value - - parameters["tot_charge"] = builder.relax.base["pw"]["parameters"]["SYSTEM"].get( - "tot_charge", 0.0 - ) - - # parameters from extra_report_parameters - for k, v in extra_report_parameters.items(): - parameters.update({k: v}) - - return parameters - - def reset(self): - with self.hold_trait_notifications(): - self.process = None - self.input_structure = None - - -class ViewQeAppWorkChainStatusAndResultsStep(ipw.VBox, WizardAppWidgetStep): - process = traitlets.Unicode(allow_none=True) - - def __init__(self, **kwargs): - self.process_tree = ProcessNodesTreeWidget() - ipw.dlink( - (self, "process"), - (self.process_tree, "value"), - ) - - self.node_view = AiidaNodeViewWidget(layout={"width": "auto", "height": "auto"}) - ipw.dlink( - (self.process_tree, "selected_nodes"), - (self.node_view, "node"), - transform=lambda nodes: nodes[0] if nodes else None, - ) - self.process_status = ipw.VBox(children=[self.process_tree, self.node_view]) - - # Setup process monitor - self.process_monitor = ProcessMonitor( - timeout=0.2, - callbacks=[ - self.process_tree.update, - self._update_state, - ], - ) - ipw.dlink((self, "process"), (self.process_monitor, "value")) - - super().__init__([self.process_status], **kwargs) - - def can_reset(self): - "Do not allow reset while process is running." - return self.state is not self.State.ACTIVE - - def reset(self): - self.process = None - - def _update_state(self): - if self.process is None: - self.state = self.State.INIT - else: - process = load_node(self.process) - process_state = process.process_state - if process_state in ( - ProcessState.CREATED, - ProcessState.RUNNING, - ProcessState.WAITING, - ): - self.state = self.State.ACTIVE - elif ( - process_state in (ProcessState.EXCEPTED, ProcessState.KILLED) - or process.is_failed - ): - self.state = self.State.FAIL - elif process.is_finished_ok: - self.state = self.State.SUCCESS - - @traitlets.observe("process") - def _observe_process(self, change): - self._update_state() diff --git a/qe.ipynb b/qe.ipynb index 128a688b..509708b7 100644 --- a/qe.ipynb +++ b/qe.ipynb @@ -54,14 +54,12 @@ "from jinja2 import Environment\n", "\n", "from aiidalab_qe.app import static\n", - "from aiidalab_qe.app.process import QeAppWorkChainSelector\n", - "from aiidalab_qe.app.steps import (\n", - " ConfigureQeAppWorkChainStep,\n", - " SubmitQeAppWorkChainStep,\n", - " ViewQeAppWorkChainStatusAndResultsStep,\n", - ")\n", - "from aiidalab_qe.app.structures import Examples, StructureSelectionStep\n", - "from aiidalab_qe.app.widgets import AddingTagsEditor\n", + "from aiidalab_qe.app.common.process import QeAppWorkChainSelector\n", + "from aiidalab_qe.app.common.widgets import AddingTagsEditor\n", + "from aiidalab_qe.app.configuration import ConfigureQeAppWorkChainStep\n", + "from aiidalab_qe.app.result import ViewQeAppWorkChainStatusAndResultsStep\n", + "from aiidalab_qe.app.structure import Examples, StructureSelectionStep\n", + "from aiidalab_qe.app.submission import SubmitQeAppWorkChainStep\n", "from aiidalab_qe.version import __version__\n", "\n", "OptimadeQueryWidget.title = \"OPTIMADE\" # monkeypatch\n", diff --git a/setup.cfg b/setup.cfg index c24ae311..ccc6cb66 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,8 @@ project_urls = Logo = https://raw.githubusercontent.com/aiidalab/aiidalab-qe/master/miscellaneous/logos/QE.jpg [options] +package_dir = + = src packages = find: install_requires = aiida-core~=2.2,<3 @@ -32,6 +34,9 @@ install_requires = pydantic~=1.10,>=1.10.8 python_requires = >=3.8 +[options.packages.find] +where = src + [options.extras_require] dev = bumpver~=2023.1124 @@ -46,7 +51,8 @@ dev = selenium~=4.7.0 [options.package_data] -aiidalab_qe.parameters = qeapp.yaml +aiidalab_qe.app.parameters = qeapp.yaml +aiidalab_qe.app.static = * [aiidalab] title = Quantum ESPRESSO @@ -69,7 +75,7 @@ tag = True push = True [bumpver:file_patterns] -aiidalab_qe/version.py = +src/aiidalab_qe/version.py = __version__ = "{version}" setup.cfg = current_version = "{version}" diff --git a/setup.py b/setup.py deleted file mode 100644 index b908cbe5..00000000 --- a/setup.py +++ /dev/null @@ -1,3 +0,0 @@ -import setuptools - -setuptools.setup() diff --git a/aiidalab_qe/__init__.py b/src/aiidalab_qe/__init__.py similarity index 100% rename from aiidalab_qe/__init__.py rename to src/aiidalab_qe/__init__.py diff --git a/aiidalab_qe/__main__.py b/src/aiidalab_qe/__main__.py similarity index 80% rename from aiidalab_qe/__main__.py rename to src/aiidalab_qe/__main__.py index 963fe26d..304c05a6 100644 --- a/aiidalab_qe/__main__.py +++ b/src/aiidalab_qe/__main__.py @@ -3,9 +3,9 @@ import click -from aiidalab_qe.app.setup_codes import codes_are_setup -from aiidalab_qe.app.setup_codes import install as install_qe_codes -from aiidalab_qe.app.sssp import install as setup_sssp +from aiidalab_qe.app.common.setup_codes import codes_are_setup +from aiidalab_qe.app.common.setup_codes import install as install_qe_codes +from aiidalab_qe.app.submission.sssp import install as setup_sssp @click.group() diff --git a/src/aiidalab_qe/app/__init__.py b/src/aiidalab_qe/app/__init__.py new file mode 100644 index 00000000..2f5c6b20 --- /dev/null +++ b/src/aiidalab_qe/app/__init__.py @@ -0,0 +1,15 @@ +"""Package for the AiiDAlab QE app.""" + +from .common import WorkChainSelector +from .configuration import ConfigureQeAppWorkChainStep +from .result import ViewQeAppWorkChainStatusAndResultsStep +from .structure import StructureSelectionStep +from .submission import SubmitQeAppWorkChainStep + +__all__ = [ + "StructureSelectionStep", + "ConfigureQeAppWorkChainStep", + "SubmitQeAppWorkChainStep", + "ViewQeAppWorkChainStatusAndResultsStep", + "WorkChainSelector", +] diff --git a/src/aiidalab_qe/app/common/__init__.py b/src/aiidalab_qe/app/common/__init__.py new file mode 100644 index 00000000..8dfca259 --- /dev/null +++ b/src/aiidalab_qe/app/common/__init__.py @@ -0,0 +1,7 @@ +# trigger registration of the viewer widget: +from .node_view import CalcJobNodeViewerWidget # noqa: F401 +from .process import WorkChainSelector + +__all__ = [ + "WorkChainSelector", +] diff --git a/src/aiidalab_qe/app/common/node_view.py b/src/aiidalab_qe/app/common/node_view.py new file mode 100644 index 00000000..8e7fcdd2 --- /dev/null +++ b/src/aiidalab_qe/app/common/node_view.py @@ -0,0 +1,112 @@ +"""Results view widgets (MOVE TO OTHER MODULE!) + +Authors: AiiDAlab team +""" + +import ipywidgets as ipw +import nglview +import traitlets as tl +from aiida import orm +from aiidalab_widgets_base import register_viewer_widget +from ase import Atoms + +from .widgets import CalcJobOutputFollower, LogOutputWidget + + +class MinimalStructureViewer(ipw.VBox): + structure = tl.Union([tl.Instance(Atoms), tl.Instance(orm.Node)], allow_none=True) + _displayed_structure = tl.Instance(Atoms, allow_none=True, read_only=True) + + background = tl.Unicode() + supercell = tl.List(tl.Int()) + + def __init__(self, structure, *args, **kwargs): + self._viewer = nglview.NGLWidget() + self._viewer.camera = "orthographic" + self._viewer.stage.set_parameters(mouse_preset="pymol") + ipw.link((self, "background"), (self._viewer, "background")) + + self.structure = structure + + super().__init__( + children=[ + self._viewer, + ], + *args, + **kwargs, + ) + + @tl.default("background") + def _default_background(self): + return "#FFFFFF" + + @tl.default("supercell") + def _default_supercell(self): + return [1, 1, 1] + + @tl.validate("structure") + def _valid_structure(self, change): # pylint: disable=no-self-use + """Update structure.""" + structure = change["value"] + + if structure is None: + return None # if no structure provided, the rest of the code can be skipped + + if isinstance(structure, Atoms): + return structure + if isinstance(structure, orm.Node): + return structure.get_ase() + raise ValueError( + "Unsupported type {}, structure must be one of the following types: " + "ASE Atoms object, AiiDA CifData or StructureData." + ) + + @tl.observe("structure") + def _update_displayed_structure(self, change): + """Update displayed_structure trait after the structure trait has been modified.""" + # Remove the current structure(s) from the viewer. + if change["new"] is not None: + self.set_trait("_displayed_structure", change["new"].repeat(self.supercell)) + else: + self.set_trait("_displayed_structure", None) + + @tl.observe("_displayed_structure") + def _update_structure_viewer(self, change): + """Update the view if displayed_structure trait was modified.""" + with self.hold_trait_notifications(): + for ( + comp_id + ) in self._viewer._ngl_component_ids: # pylint: disable=protected-access + self._viewer.remove_component(comp_id) + self.selection = list() + if change["new"] is not None: + self._viewer.add_component(nglview.ASEStructure(change["new"])) + self._viewer.clear() + self._viewer.stage.set_parameters(clipDist=0) + self._viewer.add_representation("unitcell", diffuse="#df0587") + self._viewer.add_representation("ball+stick", aspectRatio=3.5) + + +class VBoxWithCaption(ipw.VBox): + def __init__(self, caption, body, *args, **kwargs): + super().__init__(children=[ipw.HTML(caption), body], *args, **kwargs) + + +@register_viewer_widget("process.calculation.calcjob.CalcJobNode.") +class CalcJobNodeViewerWidget(ipw.VBox): + def __init__(self, calcjob, **kwargs): + self.calcjob = calcjob + self.output_follower = CalcJobOutputFollower() + self.log_output = LogOutputWidget() + + self.output_follower.calcjob_uuid = self.calcjob.uuid + self.output_follower.observe(self._observe_output_follower_lineno, ["lineno"]) + + super().__init__( + [ipw.HTML(f"CalcJob: {self.calcjob}"), self.log_output], **kwargs + ) + + def _observe_output_follower_lineno(self, _): + with self.hold_trait_notifications(): + self.log_output.filename = self.output_follower.filename + self.log_output.value = "\n".join(self.output_follower.output) diff --git a/aiidalab_qe/app/process.py b/src/aiidalab_qe/app/common/process.py similarity index 100% rename from aiidalab_qe/app/process.py rename to src/aiidalab_qe/app/common/process.py diff --git a/aiidalab_qe/app/setup_codes.py b/src/aiidalab_qe/app/common/setup_codes.py similarity index 99% rename from aiidalab_qe/app/setup_codes.py rename to src/aiidalab_qe/app/common/setup_codes.py index 15aadebf..27056251 100644 --- a/aiidalab_qe/app/setup_codes.py +++ b/src/aiidalab_qe/app/common/setup_codes.py @@ -10,7 +10,7 @@ from aiida.orm import load_code from filelock import FileLock, Timeout -from aiidalab_qe.app.widgets import ProgressBar +from aiidalab_qe.app.common.widgets import ProgressBar __all__ = [ "QESetupWidget", diff --git a/aiidalab_qe/app/widgets.py b/src/aiidalab_qe/app/common/widgets.py similarity index 85% rename from aiidalab_qe/app/widgets.py rename to src/aiidalab_qe/app/common/widgets.py index c498318a..778fd38a 100644 --- a/aiidalab_qe/app/widgets.py +++ b/src/aiidalab_qe/app/common/widgets.py @@ -16,7 +16,6 @@ import numpy as np import traitlets from aiida.orm import CalcJobNode, load_node -from aiidalab_widgets_base import register_viewer_widget from aiidalab_widgets_base.utils import ( StatusHTML, list_to_string_range, @@ -24,9 +23,6 @@ ) from IPython.display import HTML, Javascript, clear_output, display -# trigger registration of the viewer widget: -from aiidalab_qe.app import node_view # noqa: F401 - __all__ = [ "CalcJobOutputFollower", "LogOutputWidget", @@ -350,105 +346,6 @@ def _pull_output(self): self._output_queue.task_done() -@register_viewer_widget("process.calculation.calcjob.CalcJobNode.") -class CalcJobNodeViewerWidget(ipw.VBox): - def __init__(self, calcjob, **kwargs): - self.calcjob = calcjob - self.output_follower = CalcJobOutputFollower() - self.log_output = LogOutputWidget() - - self.output_follower.calcjob_uuid = self.calcjob.uuid - self.output_follower.observe(self._observe_output_follower_lineno, ["lineno"]) - - super().__init__( - [ipw.HTML(f"CalcJob: {self.calcjob}"), self.log_output], **kwargs - ) - - def _observe_output_follower_lineno(self, _): - with self.hold_trait_notifications(): - self.log_output.filename = self.output_follower.filename - self.log_output.value = "\n".join(self.output_follower.output) - - -class ResourceSelectionWidget(ipw.VBox): - """Widget for the selection of compute resources.""" - - title = ipw.HTML( - """
    -

    Resources

    -
    """ - ) - prompt = ipw.HTML( - """
    -

    - Specify the resources to use for the pw.x calculation. -

    """ - ) - - def __init__(self, **kwargs): - extra = { - "style": {"description_width": "150px"}, - "layout": {"min_width": "180px"}, - } - self.num_nodes = ipw.BoundedIntText( - value=1, step=1, min=1, max=1000, description="Nodes", **extra - ) - self.num_cpus = ipw.BoundedIntText( - value=1, step=1, min=1, description="CPUs", **extra - ) - - super().__init__( - children=[ - self.title, - ipw.HBox( - children=[self.prompt, self.num_nodes, self.num_cpus], - layout=ipw.Layout(justify_content="space-between"), - ), - ] - ) - - def reset(self): - self.num_nodes.value = 1 - self.num_cpus.value = 1 - - -class ParallelizationSettings(ipw.VBox): - """Widget for setting the parallelization settings.""" - - title = ipw.HTML( - """
    -

    Parallelization

    -
    """ - ) - prompt = ipw.HTML( - """
    -

    - Specify the number of k-points pools for the calculations. -

    """ - ) - - def __init__(self, **kwargs): - extra = { - "style": {"description_width": "150px"}, - "layout": {"min_width": "180px"}, - } - self.npools = ipw.BoundedIntText( - value=1, step=1, min=1, max=128, description="Number of k-pools", **extra - ) - super().__init__( - children=[ - self.title, - ipw.HBox( - children=[self.prompt, self.npools], - layout=ipw.Layout(justify_content="space-between"), - ), - ] - ) - - def reset(self): - self.npools.value = 1 - - class ProgressBar(ipw.HBox): class AnimationRate(float): pass diff --git a/src/aiidalab_qe/app/configuration/__init__.py b/src/aiidalab_qe/app/configuration/__init__.py new file mode 100644 index 00000000..00110ee0 --- /dev/null +++ b/src/aiidalab_qe/app/configuration/__init__.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +"""Widgets for the submission of bands work chains. + +Authors: AiiDAlab team +""" +from __future__ import annotations + +import ipywidgets as ipw +import traitlets as tl +from aiida import orm +from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain +from aiidalab_widgets_base import WizardAppWidgetStep + +from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS + +from .advanced import AdvancedSettings +from .pseudos import PseudoFamilySelector +from .workflow import WorkChainSettings + + +class ConfigureQeAppWorkChainStep(ipw.VBox, WizardAppWidgetStep): + confirmed = tl.Bool() + previous_step_state = tl.UseEnum(WizardAppWidgetStep.State) + workchain_settings = tl.Instance(WorkChainSettings, allow_none=True) + pseudo_family_selector = tl.Instance(PseudoFamilySelector, allow_none=True) + advanced_settings = tl.Instance(AdvancedSettings, allow_none=True) + input_structure = tl.Instance(orm.StructureData, allow_none=True) + + def __init__(self, **kwargs): + self.workchain_settings = WorkChainSettings() + self.workchain_settings.relax_type.observe(self._update_state, "value") + self.workchain_settings.bands_run.observe(self._update_state, "value") + self.workchain_settings.pdos_run.observe(self._update_state, "value") + + self.pseudo_family_selector = PseudoFamilySelector() + self.advanced_settings = AdvancedSettings() + + ipw.dlink( + (self.workchain_settings.workchain_protocol, "value"), + (self.advanced_settings.kpoints, "kpoints_distance_default"), + lambda protocol: PwBaseWorkChain.get_protocol_inputs(protocol)[ + "kpoints_distance" + ], + ) + + ipw.dlink( + (self.workchain_settings.workchain_protocol, "value"), + (self.advanced_settings.smearing, "degauss_default"), + lambda protocol: PwBaseWorkChain.get_protocol_inputs(protocol)["pw"][ + "parameters" + ]["SYSTEM"]["degauss"], + ) + + ipw.dlink( + (self.workchain_settings.workchain_protocol, "value"), + (self.advanced_settings.smearing, "smearing_default"), + lambda protocol: PwBaseWorkChain.get_protocol_inputs(protocol)["pw"][ + "parameters" + ]["SYSTEM"]["smearing"], + ) + + self.tab = ipw.Tab( + children=[ + self.workchain_settings, + ipw.VBox( + children=[ + self.advanced_settings, + self.pseudo_family_selector, + ] + ), + ], + layout=ipw.Layout(min_height="250px"), + ) + + self.tab.set_title(0, "Workflow") + self.tab.set_title(1, "Advanced settings") + + self._submission_blocker_messages = ipw.HTML() + + self.confirm_button = ipw.Button( + description="Confirm", + tooltip="Confirm the currently selected settings and go to the next step.", + button_style="success", + icon="check-circle", + disabled=True, + layout=ipw.Layout(width="auto"), + ) + + self.confirm_button.on_click(self.confirm) + + super().__init__( + children=[ + self.tab, + self._submission_blocker_messages, + self.confirm_button, + ], + **kwargs, + ) + + @tl.observe("input_structure") + def _update_input_structure(self, change): + if self.input_structure is not None: + self.advanced_settings.magnetization._update_widget(change) + + @tl.observe("previous_step_state") + def _observe_previous_step_state(self, change): + self._update_state() + + def set_input_parameters(self, parameters): + """Set the inputs in the GUI based on a set of parameters.""" + + with self.hold_trait_notifications(): + # Work chain settings + self.workchain_settings.relax_type.value = parameters["relax_type"] + self.workchain_settings.spin_type.value = parameters["spin_type"] + self.workchain_settings.electronic_type.value = parameters[ + "electronic_type" + ] + self.workchain_settings.bands_run.value = parameters["run_bands"] + self.workchain_settings.pdos_run.value = parameters["run_pdos"] + self.workchain_settings.workchain_protocol.value = parameters["protocol"] + + # Advanced settings + self.pseudo_family_selector.value = parameters["pseudo_family"] + if parameters.get("kpoints_distance_override", None) is not None: + self.advanced_settings.kpoints.distance.value = parameters[ + "kpoints_distance_override" + ] + self.advanced_settings.kpoints.override.value = True + if parameters.get("degauss_override", None) is not None: + self.advanced_settings.smearing.degauss.value = parameters[ + "degauss_override" + ] + self.advanced_settings.smearing.override.value = True + if parameters.get("smearing_override", None) is not None: + self.advanced_settings.smearing.smearing.value = parameters[ + "smearing_override" + ] + self.advanced_settings.smearing.override.value = True + + def _update_state(self, _=None): + if self.previous_step_state == self.State.SUCCESS: + self.confirm_button.disabled = False + self._submission_blocker_messages.value = "" + self.state = self.State.CONFIGURED + elif self.previous_step_state == self.State.FAIL: + self.state = self.State.FAIL + else: + self.confirm_button.disabled = True + self.state = self.State.INIT + self.set_input_parameters(DEFAULT_PARAMETERS) + + def confirm(self, _=None): + self.confirm_button.disabled = False + self.state = self.State.SUCCESS + + @tl.default("state") + def _default_state(self): + return self.State.INIT + + def reset(self): + with self.hold_trait_notifications(): + self.set_input_parameters(DEFAULT_PARAMETERS) diff --git a/src/aiidalab_qe/app/configuration/advanced.py b/src/aiidalab_qe/app/configuration/advanced.py new file mode 100644 index 00000000..fef74521 --- /dev/null +++ b/src/aiidalab_qe/app/configuration/advanced.py @@ -0,0 +1,376 @@ +# -*- coding: utf-8 -*- +"""Widgets for the submission of bands work chains. + +Authors: AiiDAlab team +""" +import os + +import ipywidgets as ipw +import traitlets as tl +from aiida import orm +from IPython.display import clear_output, display + + +class AdvancedSettings(ipw.VBox): + title = ipw.HTML( + """
    +

    Advanced Settings

    """ + ) + description = ipw.HTML("""Select the advanced settings for the pw.x code.""") + + def __init__(self, **kwargs): + self.override = ipw.Checkbox( + description="Override", + indent=False, + value=False, + ) + self.smearing = SmearingSettings() + self.kpoints = KpointSettings() + self.tot_charge = TotalCharge() + self.magnetization = MagnetizationSettings() + self.list_overrides = [ + self.smearing.override, + self.kpoints.override, + self.tot_charge.override, + self.magnetization.override, + ] + for override in self.list_overrides: + ipw.dlink( + (self.override, "value"), + (override, "disabled"), + lambda override: not override, + ) + self.override.observe(self.set_advanced_settings, "value") + super().__init__( + children=[ + self.title, + ipw.HBox( + [ + self.description, + self.override, + ], + ), + self.tot_charge, + self.magnetization, + self.smearing, + self.kpoints, + ], + layout=ipw.Layout(justify_content="space-between"), + **kwargs, + ) + + def set_advanced_settings(self, _=None): + self.smearing.reset() + self.kpoints.reset() + self.tot_charge.reset() + self.magnetization.reset() + + +class TotalCharge(ipw.VBox): + """Widget to define the total charge of the simulation""" + + tot_charge_default = tl.Float(default_value=0.0) + + def __init__(self, **kwargs): + self.override = ipw.Checkbox( + description="Override", + indent=False, + value=False, + ) + self.charge = ipw.BoundedFloatText( + value=0, + min=-3, + max=3, + step=0.01, + disabled=False, + description="Total charge:", + style={"description_width": "initial"}, + ) + ipw.dlink( + (self.override, "value"), + (self.charge, "disabled"), + lambda override: not override, + ) + super().__init__( + children=[ + ipw.HBox( + [ + self.override, + self.charge, + ], + ), + ], + layout=ipw.Layout(justify_content="space-between"), + **kwargs, + ) + self.charge.observe(self.set_tot_charge, "value") + self.override.observe(self.set_tot_charge, "value") + + def set_tot_charge(self, _=None): + self.charge.value = ( + self.charge.value if self.override.value else self.tot_charge_default + ) + + def _update_settings(self, **kwargs): + """Update the override and override_tot_charge and override_tot_charge values by the given keyword arguments + Therefore the override checkbox is not updated and defaults to True""" + self.override.value = True + with self.hold_trait_notifications(): + if "tot_charge" in kwargs: + self.charge.value = kwargs["tot_charge"] + + def reset(self): + with self.hold_trait_notifications(): + self.charge.value = self.tot_charge_default + self.override.value = False + + +class MagnetizationSettings(ipw.VBox): + """Widget to set the initial magnetic moments for each kind names defined in the StructureData (StructureDtaa.get_kind_names()) + Usually these are the names of the elements in the StructureData + (For example 'C' , 'N' , 'Fe' . However the StructureData can have defined kinds like 'Fe1' and 'Fe2') + + The widget generate a dictionary that can be used to set initial_magnetic_moments in the builder of PwBaseWorkChain + + Attributes: + input_structure(StructureData): trait that containes the input_strucgure (confirmed structure from previous step) + """ + + input_structure = tl.Instance(orm.StructureData, allow_none=True) + + def __init__(self, **kwargs): + self.input_structure = orm.StructureData() + self.input_structure_labels = [] + self.description = ipw.HTML( + "Define magnetization: Input structure not confirmed" + ) + self.kinds = self.create_kinds_widget() + self.kinds_widget_out = ipw.Output() + self.override = ipw.Checkbox( + description="Override", + indent=False, + value=False, + ) + super().__init__( + children=[ + ipw.HBox( + [ + self.override, + self.description, + self.kinds_widget_out, + ], + ), + ], + layout=ipw.Layout(justify_content="space-between"), + **kwargs, + ) + self.display_kinds() + self.override.observe(self._disable_kinds_widgets, "value") + + def _disable_kinds_widgets(self, _=None): + for i in range(len(self.kinds.children)): + self.kinds.children[i].disabled = not self.override.value + + def reset(self): + self.override.value = False + if hasattr(self.kinds, "children") and self.kinds.children: + for i in range(len(self.kinds.children)): + self.kinds.children[i].value = 0.0 + + def create_kinds_widget(self): + if self.input_structure_labels: + widgets_list = [] + for kind_label in self.input_structure_labels: + kind_widget = ipw.BoundedFloatText( + description=kind_label, + min=-1, + max=1, + step=0.1, + value=0.0, + disabled=True, + ) + widgets_list.append(kind_widget) + kinds_widget = ipw.VBox(widgets_list) + else: + kinds_widget = None + + return kinds_widget + + def update_kinds_widget(self): + self.input_structure_labels = self.input_structure.get_kind_names() + self.kinds = self.create_kinds_widget() + self.description.value = "Define magnetization: " + self.display_kinds() + + def display_kinds(self): + if "PYTEST_CURRENT_TEST" not in os.environ and self.kinds: + with self.kinds_widget_out: + clear_output() + display(self.kinds) + + def _update_widget(self, change): + self.input_structure = change["new"] + self.update_kinds_widget() + + def get_magnetization(self): + """Method to generate the dictionary with the initial magnetic moments""" + magnetization = {} + for i in range(len(self.kinds.children)): + magnetization[self.input_structure_labels[i]] = self.kinds.children[i].value + return magnetization + + def _set_magnetization_values(self, **kwargs): + """Update used for conftest setting all magnetization to a value""" + self.override.value = True + with self.hold_trait_notifications(): + if "initial_magnetic_moments" in kwargs: + for i in range(len(self.kinds.children)): + self.kinds.children[i].value = kwargs["initial_magnetic_moments"] + + +class SmearingSettings(ipw.VBox): + smearing_description = ipw.HTML( + """

    + The smearing type and width is set by the chosen protocol. + Tick the box to override the default, not advised unless you've mastered smearing effects (click here for a discussion). +

    """ + ) + + # The default of `smearing` and `degauss` the type and width + # must be linked to the `protocol` + degauss_default = tl.Float(default_value=0.01) + smearing_default = tl.Unicode(default_value="cold") + + def __init__(self, **kwargs): + self.override = ipw.Checkbox( + description="Override", + indent=False, + value=False, + ) + self.smearing = ipw.Dropdown( + options=["cold", "gaussian", "fermi-dirac", "methfessel-paxton"], + value=self.smearing_default, + description="Smearing type:", + disabled=False, + style={"description_width": "initial"}, + ) + self.degauss = ipw.FloatText( + value=self.degauss_default, + step=0.005, + description="Smearing width (Ry):", + disabled=False, + style={"description_width": "initial"}, + ) + ipw.dlink( + (self.override, "value"), + (self.degauss, "disabled"), + lambda override: not override, + ) + ipw.dlink( + (self.override, "value"), + (self.smearing, "disabled"), + lambda override: not override, + ) + self.degauss.observe(self.set_smearing, "value") + self.smearing.observe(self.set_smearing, "value") + self.override.observe(self.set_smearing, "value") + + super().__init__( + children=[ + self.smearing_description, + ipw.HBox([self.override, self.smearing, self.degauss]), + ], + layout=ipw.Layout(justify_content="space-between"), + **kwargs, + ) + + def set_smearing(self, _=None): + self.degauss.value = ( + self.degauss.value if self.override.value else self.degauss_default + ) + self.smearing.value = ( + self.smearing.value if self.override.value else self.smearing_default + ) + + def _update_settings(self, **kwargs): + """Update the smearing and degauss values by the given keyword arguments + This is the same as the `set_smearing` method but without the observer. + Therefore the override checkbox is not updated and defaults to True""" + self.override.value = True + + with self.hold_trait_notifications(): + if "smearing" in kwargs: + self.smearing.value = kwargs["smearing"] + + if "degauss" in kwargs: + self.degauss.value = kwargs["degauss"] + + def reset(self): + with self.hold_trait_notifications(): + self.degauss.value = self.degauss_default + self.smearing.value = self.smearing_default + self.override.value = False + + +class KpointSettings(ipw.VBox): + kpoints_distance_description = ipw.HTML( + """
    + The k-points mesh density of the SCF calculation is set by the protocol. + The value below represents the maximum distance between the k-points in each direction of reciprocal space. + Tick the box to override the default, smaller is more accurate and costly.
    """ + ) + + # The default of `kpoints_distance` must be linked to the `protocol` + kpoints_distance_default = tl.Float(default_value=0.15) + + def __init__(self, **kwargs): + self.override = ipw.Checkbox( + description="Override", + indent=False, + value=False, + ) + self.distance = ipw.FloatText( + value=self.kpoints_distance_default, + step=0.05, + description="K-points distance (1/Å):", + disabled=False, + style={"description_width": "initial"}, + ) + ipw.dlink( + (self.override, "value"), + (self.distance, "disabled"), + lambda override: not override, + ) + self.distance.observe(self.set_kpoints_distance, "value") + self.override.observe(self.set_kpoints_distance, "value") + self.observe(self.set_kpoints_distance, "kpoints_distance_default") + + super().__init__( + children=[ + self.kpoints_distance_description, + ipw.HBox([self.override, self.distance]), + ], + layout=ipw.Layout(justify_content="space-between"), + **kwargs, + ) + + def set_kpoints_distance(self, _=None): + self.distance.value = ( + self.distance.value + if self.override.value + else self.kpoints_distance_default + ) + + def _update_settings(self, **kwargs): + """Update the kpoints_distance value by the given keyword arguments. + This is the same as the `set_kpoints_distance` method but without the observer. + """ + self.override.value = True + if "kpoints_distance" in kwargs: + self.distance.value = kwargs["kpoints_distance"] + + def reset(self): + with self.hold_trait_notifications(): + self.distance.value = self.kpoints_distance_default + self.override.value = False diff --git a/aiidalab_qe/app/pseudos.py b/src/aiidalab_qe/app/configuration/pseudos.py similarity index 100% rename from aiidalab_qe/app/pseudos.py rename to src/aiidalab_qe/app/configuration/pseudos.py diff --git a/src/aiidalab_qe/app/configuration/workflow.py b/src/aiidalab_qe/app/configuration/workflow.py new file mode 100644 index 00000000..56c04133 --- /dev/null +++ b/src/aiidalab_qe/app/configuration/workflow.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +"""Widgets for the submission of bands work chains. + +Authors: AiiDAlab team +""" +import ipywidgets as ipw + +from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS + + +class WorkChainSettings(ipw.VBox): + structure_title = ipw.HTML( + """
    +

    Structure

    """ + ) + structure_help = ipw.HTML( + """
    + You have three options:
    + (1) Structure as is: perform a self consistent calculation using the structure provided as input.
    + (2) Atomic positions: perform a full relaxation of the internal atomic coordinates.
    + (3) Full geometry: perform a full relaxation for both the internal atomic coordinates and the cell vectors.
    """ + ) + materials_help = ipw.HTML( + """
    + Below you can indicate both if the material should be treated as an insulator + or a metal (if in doubt, choose "Metal"), + and if it should be studied with magnetization/spin polarization, + switch magnetism On or Off (On is at least twice more costly). +
    """ + ) + + properties_title = ipw.HTML( + """
    +

    Properties

    """ + ) + properties_help = ipw.HTML( + """
    + The band structure workflow will + automatically detect the default path in reciprocal space using the + + SeeK-path tool.
    """ + ) + + protocol_title = ipw.HTML( + """
    +

    Protocol

    """ + ) + protocol_help = ipw.HTML( + """
    + The "moderate" protocol represents a trade-off between + accuracy and speed. Choose the "fast" protocol for a faster calculation + with less precision and the "precise" protocol to aim at best accuracy (at the price of longer/costlier calculations).
    """ + ) + + def __init__(self, **kwargs): + # RelaxType: degrees of freedom in geometry optimization + self.relax_type = ipw.ToggleButtons( + options=[ + ("Structure as is", "none"), + ("Atomic positions", "positions"), + ("Full geometry", "positions_cell"), + ], + value="positions_cell", + ) + + # SpinType: magnetic properties of material + self.spin_type = ipw.ToggleButtons( + options=[("Off", "none"), ("On", "collinear")], + value=DEFAULT_PARAMETERS["spin_type"], + style={"description_width": "initial"}, + ) + + # ElectronicType: electronic properties of material + self.electronic_type = ipw.ToggleButtons( + options=[("Metal", "metal"), ("Insulator", "insulator")], + value=DEFAULT_PARAMETERS["electronic_type"], + style={"description_width": "initial"}, + ) + + # Checkbox to see if the band structure should be calculated + self.bands_run = ipw.Checkbox( + description="", + indent=False, + value=True, + layout=ipw.Layout(max_width="10%"), + ) + + # Checkbox to see if the PDOS should be calculated + self.pdos_run = ipw.Checkbox( + description="", + indent=False, + value=True, + layout=ipw.Layout(max_width="10%"), + ) + + # Work chain protocol + self.workchain_protocol = ipw.ToggleButtons( + options=["fast", "moderate", "precise"], + value="moderate", + ) + super().__init__( + children=[ + self.structure_title, + self.structure_help, + self.relax_type, + self.materials_help, + ipw.HBox( + children=[ + ipw.Label( + "Electronic Type:", + layout=ipw.Layout( + justify_content="flex-start", width="120px" + ), + ), + self.electronic_type, + ] + ), + ipw.HBox( + children=[ + ipw.Label( + "Magnetism:", + layout=ipw.Layout( + justify_content="flex-start", width="120px" + ), + ), + self.spin_type, + ] + ), + self.properties_title, + ipw.HTML("Select which properties to calculate:"), + ipw.HBox(children=[ipw.HTML("Band structure"), self.bands_run]), + ipw.HBox( + children=[ + ipw.HTML("Projected density of states"), + self.pdos_run, + ] + ), + self.properties_help, + self.protocol_title, + ipw.HTML("Select the protocol:", layout=ipw.Layout(flex="1 1 auto")), + self.workchain_protocol, + self.protocol_help, + ], + **kwargs, + ) + + def _update_settings(self, **kwargs): + """Update the settings based on the given dict.""" + for key in [ + "relax_type", + "spin_type", + "electronic_type", + "bands_run", + "pdos_run", + "workchain_protocol", + ]: + if key in kwargs: + getattr(self, key).value = kwargs[key] diff --git a/aiidalab_qe/app/parameters/__init__.py b/src/aiidalab_qe/app/parameters/__init__.py similarity index 100% rename from aiidalab_qe/app/parameters/__init__.py rename to src/aiidalab_qe/app/parameters/__init__.py diff --git a/aiidalab_qe/app/parameters/qeapp.yaml b/src/aiidalab_qe/app/parameters/qeapp.yaml similarity index 100% rename from aiidalab_qe/app/parameters/qeapp.yaml rename to src/aiidalab_qe/app/parameters/qeapp.yaml diff --git a/src/aiidalab_qe/app/result/__init__.py b/src/aiidalab_qe/app/result/__init__.py new file mode 100644 index 00000000..09d4597a --- /dev/null +++ b/src/aiidalab_qe/app/result/__init__.py @@ -0,0 +1,75 @@ +import ipywidgets as ipw +import traitlets as tl +from aiida import orm +from aiida.engine import ProcessState +from aiidalab_widgets_base import ( + AiidaNodeViewWidget, + ProcessMonitor, + ProcessNodesTreeWidget, + WizardAppWidgetStep, +) + +# trigger registration of the viewer widget: +from .workchain_viewer import WorkChainViewer # noqa: F401 + + +class ViewQeAppWorkChainStatusAndResultsStep(ipw.VBox, WizardAppWidgetStep): + process = tl.Unicode(allow_none=True) + + def __init__(self, **kwargs): + self.process_tree = ProcessNodesTreeWidget() + ipw.dlink( + (self, "process"), + (self.process_tree, "value"), + ) + + self.node_view = AiidaNodeViewWidget(layout={"width": "auto", "height": "auto"}) + ipw.dlink( + (self.process_tree, "selected_nodes"), + (self.node_view, "node"), + transform=lambda nodes: nodes[0] if nodes else None, + ) + self.process_status = ipw.VBox(children=[self.process_tree, self.node_view]) + + # Setup process monitor + self.process_monitor = ProcessMonitor( + timeout=0.2, + callbacks=[ + self.process_tree.update, + self._update_state, + ], + ) + ipw.dlink((self, "process"), (self.process_monitor, "value")) + + super().__init__([self.process_status], **kwargs) + + def can_reset(self): + "Do not allow reset while process is running." + return self.state is not self.State.ACTIVE + + def reset(self): + self.process = None + + def _update_state(self): + if self.process is None: + self.state = self.State.INIT + else: + process = orm.load_node(self.process) + process_state = process.process_state + if process_state in ( + ProcessState.CREATED, + ProcessState.RUNNING, + ProcessState.WAITING, + ): + self.state = self.State.ACTIVE + elif ( + process_state in (ProcessState.EXCEPTED, ProcessState.KILLED) + or process.is_failed + ): + self.state = self.State.FAIL + elif process.is_finished_ok: + self.state = self.State.SUCCESS + + @tl.observe("process") + def _observe_process(self, change): + self._update_state() diff --git a/src/aiidalab_qe/app/result/electronic_structure.py b/src/aiidalab_qe/app/result/electronic_structure.py new file mode 100644 index 00000000..c3cc846f --- /dev/null +++ b/src/aiidalab_qe/app/result/electronic_structure.py @@ -0,0 +1,179 @@ +import json +import random + +from aiida import orm +from monty.json import jsanitize + + +def export_data(work_chain_node, group_dos_by="atom"): + dos = export_pdos_data(work_chain_node, group_dos_by=group_dos_by) + fermi_energy = dos["fermi_energy"] if dos else None + + bands = export_bands_data(work_chain_node, fermi_energy) + + return dict( + bands=bands, + dos=dos, + ) + + +def export_pdos_data(work_chain_node, group_dos_by="atom"): + if "dos" in work_chain_node.outputs: + _, energy_dos, _ = work_chain_node.outputs.dos.get_x() + tdos_values = {f"{n}": v for n, v, _ in work_chain_node.outputs.dos.get_y()} + + dos = [] + + if "projections" in work_chain_node.outputs: + # The total dos parsed + tdos = { + "label": "Total DOS", + "x": energy_dos.tolist(), + "y": tdos_values.get("dos").tolist(), + "borderColor": "#8A8A8A", # dark gray + "backgroundColor": "#999999", # light gray + "backgroundAlpha": "40%", + "lineStyle": "solid", + } + dos.append(tdos) + + dos += _projections_curated( + work_chain_node.outputs.projections, + group_dos_by=group_dos_by, + spin_type="none", + ) + + else: + # The total dos parsed + tdos_up = { + "label": "Total DOS (↑)", + "x": energy_dos.tolist(), + "y": tdos_values.get("dos_spin_up").tolist(), + "borderColor": "#8A8A8A", # dark gray + "backgroundColor": "#999999", # light gray + "backgroundAlpha": "40%", + "lineStyle": "solid", + } + tdos_down = { + "label": "Total DOS (↓)", + "x": energy_dos.tolist(), + "y": (-tdos_values.get("dos_spin_down")).tolist(), # minus + "borderColor": "#8A8A8A", # dark gray + "backgroundColor": "#999999", # light gray + "backgroundAlpha": "40%", + "lineStyle": "dash", + } + dos += [tdos_up, tdos_down] + + # spin-up (↑) + dos += _projections_curated( + work_chain_node.outputs.projections_up, + group_dos_by=group_dos_by, + spin_type="up", + ) + + # spin-dn (↓) + dos += _projections_curated( + work_chain_node.outputs.projections_down, + group_dos_by=group_dos_by, + spin_type="down", + line_style="dash", + ) + + data_dict = { + "fermi_energy": work_chain_node.outputs.nscf_parameters["fermi_energy"], + "dos": dos, + } + + return json.loads(json.dumps(data_dict)) + + else: + return None + + +def export_bands_data(work_chain_node, fermi_energy=None): + if "band_structure" in work_chain_node.outputs: + data = json.loads( + work_chain_node.outputs.band_structure._exportcontent( + "json", comments=False + )[0] + ) + # The fermi energy from band calculation is not robust. + data["fermi_level"] = ( + fermi_energy or work_chain_node.outputs.band_parameters["fermi_energy"] + ) + return [ + jsanitize(data), + ] + else: + return None + + +def _projections_curated( + projections: orm.ProjectionData, + group_dos_by="atom", + spin_type="none", + line_style="solid", +): + """Collect the data from ProjectionData and parse it as dos list which can be + understand by bandsplot widget. `group_dos_by` is for which tag to be grouped, by atom or by orbital name. + The spin_type is used to invert all the y values of pdos to be shown as spin down pdos and to set label. + """ + _pdos = {} + + for orbital, pdos, energy in projections.get_pdos(): + orbital_data = orbital.get_orbital_dict() + kind_name = orbital_data["kind_name"] + atom_position = [round(i, 2) for i in orbital_data["position"]] + orbital_name = orbital.get_name_from_quantum_numbers( + orbital_data["angular_momentum"], orbital_data["magnetic_number"] + ).lower() + + if group_dos_by == "atom": + dos_group_name = atom_position + elif group_dos_by == "angular": + # by orbital label + dos_group_name = orbital_name[0] + elif group_dos_by == "angular_and_magnetic": + # by orbital label + dos_group_name = orbital_name + else: + raise Exception(f"Unknow dos type: {group_dos_by}!") + + key = f"{kind_name}-{dos_group_name}" + if key in _pdos: + _pdos[key][1] += pdos + else: + _pdos[key] = [energy, pdos] + + dos = [] + for label, (energy, pdos) in _pdos.items(): + if spin_type == "down": + # invert y-axis + pdos = -pdos + label = f"{label} (↓)" + + if spin_type == "up": + label = f"{label} (↑)" + + orbital_pdos = { + "label": label, + "x": energy.tolist(), + "y": pdos.tolist(), + "borderColor": cmap(label), + "lineStyle": line_style, + } + dos.append(orbital_pdos) + + return dos + + +def cmap(label: str) -> str: + """Return RGB string of color for given pseudo info + Hardcoded at the momment. + """ + # if a unknow type generate random color based on ascii sum + ascn = sum([ord(c) for c in label]) + random.seed(ascn) + + return "#%06x" % random.randint(0, 0xFFFFFF) diff --git a/aiidalab_qe/app/report.py b/src/aiidalab_qe/app/result/report.py similarity index 100% rename from aiidalab_qe/app/report.py rename to src/aiidalab_qe/app/result/report.py diff --git a/src/aiidalab_qe/app/result/summary_viewer.py b/src/aiidalab_qe/app/result/summary_viewer.py new file mode 100644 index 00000000..abc7739a --- /dev/null +++ b/src/aiidalab_qe/app/result/summary_viewer.py @@ -0,0 +1,14 @@ +import ipywidgets as ipw + +from .report import generate_report_html + + +class SummaryView(ipw.VBox): + def __init__(self, wc_node, **kwargs): + report_html = generate_report_html(wc_node) + + self.summary_view = ipw.HTML(report_html) + super().__init__( + children=[self.summary_view], + **kwargs, + ) diff --git a/aiidalab_qe/app/node_view.py b/src/aiidalab_qe/app/result/workchain_viewer.py similarity index 56% rename from aiidalab_qe/app/node_view.py rename to src/aiidalab_qe/app/result/workchain_viewer.py index 00e2a654..73d6d8f8 100644 --- a/aiidalab_qe/app/node_view.py +++ b/src/aiidalab_qe/app/result/workchain_viewer.py @@ -1,302 +1,187 @@ -"""Results view widgets (MOVE TO OTHER MODULE!) - -Authors: AiiDAlab team -""" - -import json -import random import shutil -import typing +import typing as t from importlib import resources from pathlib import Path from tempfile import TemporaryDirectory import ipywidgets as ipw -import nglview -import traitlets +import traitlets as tl +from aiida import orm from aiida.cmdline.utils.common import get_workchain_report from aiida.common import LinkType -from aiida.orm import CalcJobNode, Node, ProjectionData, WorkChainNode from aiidalab_widgets_base import ProcessMonitor, register_viewer_widget from aiidalab_widgets_base.viewers import StructureDataViewer -from ase import Atoms from filelock import FileLock, Timeout from IPython.display import HTML, display from jinja2 import Environment -from monty.json import jsanitize -from traitlets import Instance, Int, List, Unicode, Union, default, observe, validate from widget_bandsplot import BandsPlotWidget from aiidalab_qe.app import static -from aiidalab_qe.app.report import generate_report_html +from .electronic_structure import export_data +from .summary_viewer import SummaryView -class MinimalStructureViewer(ipw.VBox): - structure = Union([Instance(Atoms), Instance(Node)], allow_none=True) - _displayed_structure = Instance(Atoms, allow_none=True, read_only=True) - background = Unicode() - supercell = List(Int()) +@register_viewer_widget("process.workflow.workchain.WorkChainNode.") +class WorkChainViewer(ipw.VBox): + _results_shown = tl.Set() - def __init__(self, structure, *args, **kwargs): - self._viewer = nglview.NGLWidget() - self._viewer.camera = "orthographic" - self._viewer.stage.set_parameters(mouse_preset="pymol") - ipw.link((self, "background"), (self._viewer, "background")) + def __init__(self, node, **kwargs): + if node.process_label != "QeAppWorkChain": + super().__init__() + return - self.structure = structure + self.node = node - super().__init__( - children=[ - self._viewer, - ], - *args, - **kwargs, + self.title = ipw.HTML( + f""" +
    +

    QE App Workflow (pk: {self.node.pk}) — + {self.node.inputs.structure.get_formula()} +

    + """ ) + self.workflows_summary = SummaryView(self.node) - @default("background") - def _default_background(self): - return "#FFFFFF" - - @default("supercell") - def _default_supercell(self): - return [1, 1, 1] - - @validate("structure") - def _valid_structure(self, change): # pylint: disable=no-self-use - """Update structure.""" - structure = change["value"] - - if structure is None: - return None # if no structure provided, the rest of the code can be skipped - - if isinstance(structure, Atoms): - return structure - if isinstance(structure, Node): - return structure.get_ase() - raise ValueError( - "Unsupported type {}, structure must be one of the following types: " - "ASE Atoms object, AiiDA CifData or StructureData." + self.summary_tab = ipw.VBox(children=[self.workflows_summary]) + self.structure_tab = ipw.VBox( + [ipw.Label("Structure not available.")], + layout=ipw.Layout(min_height="380px"), ) - - @observe("structure") - def _update_displayed_structure(self, change): - """Update displayed_structure trait after the structure trait has been modified.""" - # Remove the current structure(s) from the viewer. - if change["new"] is not None: - self.set_trait("_displayed_structure", change["new"].repeat(self.supercell)) - else: - self.set_trait("_displayed_structure", None) - - @observe("_displayed_structure") - def _update_structure_viewer(self, change): - """Update the view if displayed_structure trait was modified.""" - with self.hold_trait_notifications(): - for ( - comp_id - ) in self._viewer._ngl_component_ids: # pylint: disable=protected-access - self._viewer.remove_component(comp_id) - self.selection = list() - if change["new"] is not None: - self._viewer.add_component(nglview.ASEStructure(change["new"])) - self._viewer.clear() - self._viewer.stage.set_parameters(clipDist=0) - self._viewer.add_representation("unitcell", diffuse="#df0587") - self._viewer.add_representation("ball+stick", aspectRatio=3.5) - - -def export_bands_data(work_chain_node, fermi_energy=None): - if "band_structure" in work_chain_node.outputs: - data = json.loads( - work_chain_node.outputs.band_structure._exportcontent( - "json", comments=False - )[0] + self.bands_tab = ipw.VBox( + [ipw.Label("Electronic Structure not available.")], + layout=ipw.Layout(min_height="380px"), ) - # The fermi energy from band calculation is not robust. - data["fermi_level"] = ( - fermi_energy or work_chain_node.outputs.band_parameters["fermi_energy"] + self.result_tabs = ipw.Tab( + children=[self.summary_tab, self.structure_tab, self.bands_tab] ) - return [ - jsanitize(data), - ] - else: - return None - - -def cmap(label: str) -> str: - """Return RGB string of color for given pseudo info - Hardcoded at the momment. - """ - # if a unknow type generate random color based on ascii sum - ascn = sum([ord(c) for c in label]) - random.seed(ascn) - - return "#%06x" % random.randint(0, 0xFFFFFF) - - -def _projections_curated( - projections: ProjectionData, - group_dos_by="atom", - spin_type="none", - line_style="solid", -): - """Collect the data from ProjectionData and parse it as dos list which can be - understand by bandsplot widget. `group_dos_by` is for which tag to be grouped, by atom or by orbital name. - The spin_type is used to invert all the y values of pdos to be shown as spin down pdos and to set label. - """ - _pdos = {} - - for orbital, pdos, energy in projections.get_pdos(): - orbital_data = orbital.get_orbital_dict() - kind_name = orbital_data["kind_name"] - atom_position = [round(i, 2) for i in orbital_data["position"]] - orbital_name = orbital.get_name_from_quantum_numbers( - orbital_data["angular_momentum"], orbital_data["magnetic_number"] - ).lower() - - if group_dos_by == "atom": - dos_group_name = atom_position - elif group_dos_by == "angular": - # by orbital label - dos_group_name = orbital_name[0] - elif group_dos_by == "angular_and_magnetic": - # by orbital label - dos_group_name = orbital_name - else: - raise Exception(f"Unknow dos type: {group_dos_by}!") - - key = f"{kind_name}-{dos_group_name}" - if key in _pdos: - _pdos[key][1] += pdos - else: - _pdos[key] = [energy, pdos] - - dos = [] - for label, (energy, pdos) in _pdos.items(): - if spin_type == "down": - # invert y-axis - pdos = -pdos - label = f"{label} (↓)" - - if spin_type == "up": - label = f"{label} (↑)" - - orbital_pdos = { - "label": label, - "x": energy.tolist(), - "y": pdos.tolist(), - "borderColor": cmap(label), - "lineStyle": line_style, - } - dos.append(orbital_pdos) - - return dos - - -def export_pdos_data(work_chain_node, group_dos_by="atom"): - if "dos" in work_chain_node.outputs: - _, energy_dos, _ = work_chain_node.outputs.dos.get_x() - tdos_values = {f"{n}": v for n, v, _ in work_chain_node.outputs.dos.get_y()} - - dos = [] - - if "projections" in work_chain_node.outputs: - # The total dos parsed - tdos = { - "label": "Total DOS", - "x": energy_dos.tolist(), - "y": tdos_values.get("dos").tolist(), - "borderColor": "#8A8A8A", # dark gray - "backgroundColor": "#999999", # light gray - "backgroundAlpha": "40%", - "lineStyle": "solid", - } - dos.append(tdos) - - dos += _projections_curated( - work_chain_node.outputs.projections, - group_dos_by=group_dos_by, - spin_type="none", - ) - else: - # The total dos parsed - tdos_up = { - "label": "Total DOS (↑)", - "x": energy_dos.tolist(), - "y": tdos_values.get("dos_spin_up").tolist(), - "borderColor": "#8A8A8A", # dark gray - "backgroundColor": "#999999", # light gray - "backgroundAlpha": "40%", - "lineStyle": "solid", - } - tdos_down = { - "label": "Total DOS (↓)", - "x": energy_dos.tolist(), - "y": (-tdos_values.get("dos_spin_down")).tolist(), # minus - "borderColor": "#8A8A8A", # dark gray - "backgroundColor": "#999999", # light gray - "backgroundAlpha": "40%", - "lineStyle": "dash", - } - dos += [tdos_up, tdos_down] - - # spin-up (↑) - dos += _projections_curated( - work_chain_node.outputs.projections_up, - group_dos_by=group_dos_by, - spin_type="up", - ) + self.result_tabs.set_title(0, "Workflow Summary") + self.result_tabs.set_title(1, "Final Geometry (n/a)") + self.result_tabs.set_title(2, "Electronic Structure (n/a)") - # spin-dn (↓) - dos += _projections_curated( - work_chain_node.outputs.projections_down, - group_dos_by=group_dos_by, - spin_type="down", - line_style="dash", - ) + # An ugly fix to the structure appearance problem + # https://github.com/aiidalab/aiidalab-qe/issues/69 + def on_selected_index_change(change): + index = change["new"] + # Accessing the viewer only if the corresponding tab is present. + if self.result_tabs._titles[str(index)] == "Final Geometry": + self._structure_view._viewer.handle_resize() - data_dict = { - "fermi_energy": work_chain_node.outputs.nscf_parameters["fermi_energy"], - "dos": dos, - } + def toggle_camera(): + """Toggle camera between perspective and orthographic.""" + self._structure_view._viewer.camera = ( + "perspective" + if self._structure_view._viewer.camera == "orthographic" + else "orthographic" + ) - return json.loads(json.dumps(data_dict)) + toggle_camera() + toggle_camera() - else: - return None + self.result_tabs.observe(on_selected_index_change, "selected_index") + self._update_view() + super().__init__( + children=[self.title, self.result_tabs], + **kwargs, + ) + self._process_monitor = ProcessMonitor( + process=self.node, + callbacks=[ + self._update_view, + ], + ) -def export_data(work_chain_node, group_dos_by="atom"): - dos = export_pdos_data(work_chain_node, group_dos_by=group_dos_by) - fermi_energy = dos["fermi_energy"] if dos else None + def _update_view(self): + with self.hold_trait_notifications(): + if self.node.is_finished: + self._show_workflow_output() + if ( + "structure" not in self._results_shown + and "structure" in self.node.outputs + ): + self._show_structure() + self._results_shown.add("structure") - bands = export_bands_data(work_chain_node, fermi_energy) + if "electronic_structure" not in self._results_shown and ( + "band_structure" in self.node.outputs or "dos" in self.node.outputs + ): + self._show_electronic_structure() + self._results_shown.add("electronic_structure") - return dict( - bands=bands, - dos=dos, - ) + def _show_structure(self): + self._structure_view = StructureDataViewer( + structure=self.node.outputs.structure + ) + self.result_tabs.children[1].children = [self._structure_view] + self.result_tabs.set_title(1, "Final Geometry") + def _show_electronic_structure(self): + group_dos_by = ipw.ToggleButtons( + options=[ + ("Atom", "atom"), + ("Orbital", "angular"), + ], + value="atom", + ) + settings = ipw.VBox( + children=[ + ipw.HBox( + children=[ + ipw.Label( + "DOS grouped by:", + layout=ipw.Layout( + justify_content="flex-start", width="120px" + ), + ), + group_dos_by, + ] + ), + ], + layout={"margin": "0 0 30px 30px"}, + ) + # + data = export_data(self.node, group_dos_by=group_dos_by.value) + bands_data = data.get("bands", None) + dos_data = data.get("dos", None) + _bands_plot_view = BandsPlotWidget( + bands=bands_data, + dos=dos_data, + ) -class VBoxWithCaption(ipw.VBox): - def __init__(self, caption, body, *args, **kwargs): - super().__init__(children=[ipw.HTML(caption), body], *args, **kwargs) + def response(change): + data = export_data(self.node, group_dos_by=group_dos_by.value) + bands_data = data.get("bands", None) + dos_data = data.get("dos", None) + _bands_plot_view = BandsPlotWidget( + bands=bands_data, + dos=dos_data, + ) + self.result_tabs.children[2].children = [ + settings, + _bands_plot_view, + ] + group_dos_by.observe(response, names="value") + # update the electronic structure tab + self.result_tabs.children[2].children = [ + settings, + _bands_plot_view, + ] + self.result_tabs.set_title(2, "Electronic Structure") -class SummaryView(ipw.VBox): - def __init__(self, wc_node, **kwargs): - report_html = generate_report_html(wc_node) + def _show_workflow_output(self): + self.workflows_output = WorkChainOutputs(self.node) - self.summary_view = ipw.HTML(report_html) - super().__init__( - children=[self.summary_view], - **kwargs, - ) + self.result_tabs.children[0].children = [ + self.workflows_summary, + self.workflows_output, + ] class WorkChainOutputs(ipw.VBox): - _busy = traitlets.Bool(read_only=True) + _busy = tl.Bool(read_only=True) def __init__(self, node, export_dir=None, **kwargs): self.export_dir = Path.cwd().joinpath("exports") @@ -349,11 +234,11 @@ def __init__(self, node, export_dir=None, **kwargs): **kwargs, ) - @traitlets.default("_busy") + @tl.default("_busy") def _default_busy(self): return False - @traitlets.observe("_busy") + @tl.observe("_busy") def _observe_busy(self, change): self._download_button_container.children = [ self._create_archive_indicator @@ -409,7 +294,7 @@ def _download_archive(self, _): ) @classmethod - def _prepare_calcjob_io(cls, node: WorkChainNode, root_folder: Path): + def _prepare_calcjob_io(cls, node: orm.WorkChainNode, root_folder: Path): """Prepare the calculation job input and output files. :param node: QeAppWorkChain node. @@ -440,7 +325,7 @@ def _prepare_calcjob_io(cls, node: WorkChainNode, root_folder: Path): counter += 1 @staticmethod - def _get_final_calcjob(node: WorkChainNode) -> typing.Union[None, CalcJobNode]: + def _get_final_calcjob(node: orm.WorkChainNode) -> t.Union[None, orm.CalcJobNode]: """Get the final calculation job node called by a work chain node. :param node: Work chain node. @@ -449,7 +334,7 @@ def _get_final_calcjob(node: WorkChainNode) -> typing.Union[None, CalcJobNode]: final_calcjob = [ process for process in node.called_descendants - if isinstance(process, CalcJobNode) and process.is_finished + if isinstance(process, orm.CalcJobNode) and process.is_finished ][-1] except IndexError: final_calcjob = None @@ -457,7 +342,7 @@ def _get_final_calcjob(node: WorkChainNode) -> typing.Union[None, CalcJobNode]: return final_calcjob @staticmethod - def _write_calcjob_io(calcjob: CalcJobNode, folder: Path) -> None: + def _write_calcjob_io(calcjob: orm.CalcJobNode, folder: Path) -> None: """Write the ``calcjob`` in and output files to ``folder``. :param calcjob: calculation job node for which to write the IO files. @@ -485,161 +370,3 @@ def _write_calcjob_io(calcjob: CalcJobNode, folder: Path) -> None: out_filepath = folder / filename with out_filepath.open("w") as handle: handle.write(retrieved.get_object_content(filename)) - - -@register_viewer_widget("process.workflow.workchain.WorkChainNode.") -class WorkChainViewer(ipw.VBox): - _results_shown = traitlets.Set() - - def __init__(self, node, **kwargs): - if node.process_label != "QeAppWorkChain": - super().__init__() - return - - self.node = node - - self.title = ipw.HTML( - f""" -
    -

    QE App Workflow (pk: {self.node.pk}) — - {self.node.inputs.structure.get_formula()} -

    - """ - ) - self.workflows_summary = SummaryView(self.node) - - self.summary_tab = ipw.VBox(children=[self.workflows_summary]) - self.structure_tab = ipw.VBox( - [ipw.Label("Structure not available.")], - layout=ipw.Layout(min_height="380px"), - ) - self.bands_tab = ipw.VBox( - [ipw.Label("Electronic Structure not available.")], - layout=ipw.Layout(min_height="380px"), - ) - self.result_tabs = ipw.Tab( - children=[self.summary_tab, self.structure_tab, self.bands_tab] - ) - - self.result_tabs.set_title(0, "Workflow Summary") - self.result_tabs.set_title(1, "Final Geometry (n/a)") - self.result_tabs.set_title(2, "Electronic Structure (n/a)") - - # An ugly fix to the structure appearance problem - # https://github.com/aiidalab/aiidalab-qe/issues/69 - def on_selected_index_change(change): - index = change["new"] - # Accessing the viewer only if the corresponding tab is present. - if self.result_tabs._titles[str(index)] == "Final Geometry": - self._structure_view._viewer.handle_resize() - - def toggle_camera(): - """Toggle camera between perspective and orthographic.""" - self._structure_view._viewer.camera = ( - "perspective" - if self._structure_view._viewer.camera == "orthographic" - else "orthographic" - ) - - toggle_camera() - toggle_camera() - - self.result_tabs.observe(on_selected_index_change, "selected_index") - self._update_view() - - super().__init__( - children=[self.title, self.result_tabs], - **kwargs, - ) - self._process_monitor = ProcessMonitor( - process=self.node, - callbacks=[ - self._update_view, - ], - ) - - def _update_view(self): - with self.hold_trait_notifications(): - if self.node.is_finished: - self._show_workflow_output() - if ( - "structure" not in self._results_shown - and "structure" in self.node.outputs - ): - self._show_structure() - self._results_shown.add("structure") - - if "electronic_structure" not in self._results_shown and ( - "band_structure" in self.node.outputs or "dos" in self.node.outputs - ): - self._show_electronic_structure() - self._results_shown.add("electronic_structure") - - def _show_structure(self): - self._structure_view = StructureDataViewer( - structure=self.node.outputs.structure - ) - self.result_tabs.children[1].children = [self._structure_view] - self.result_tabs.set_title(1, "Final Geometry") - - def _show_electronic_structure(self): - group_dos_by = ipw.ToggleButtons( - options=[ - ("Atom", "atom"), - ("Orbital", "angular"), - ], - value="atom", - ) - settings = ipw.VBox( - children=[ - ipw.HBox( - children=[ - ipw.Label( - "DOS grouped by:", - layout=ipw.Layout( - justify_content="flex-start", width="120px" - ), - ), - group_dos_by, - ] - ), - ], - layout={"margin": "0 0 30px 30px"}, - ) - # - data = export_data(self.node, group_dos_by=group_dos_by.value) - bands_data = data.get("bands", None) - dos_data = data.get("dos", None) - _bands_plot_view = BandsPlotWidget( - bands=bands_data, - dos=dos_data, - ) - - def response(change): - data = export_data(self.node, group_dos_by=group_dos_by.value) - bands_data = data.get("bands", None) - dos_data = data.get("dos", None) - _bands_plot_view = BandsPlotWidget( - bands=bands_data, - dos=dos_data, - ) - self.result_tabs.children[2].children = [ - settings, - _bands_plot_view, - ] - - group_dos_by.observe(response, names="value") - # update the electronic structure tab - self.result_tabs.children[2].children = [ - settings, - _bands_plot_view, - ] - self.result_tabs.set_title(2, "Electronic Structure") - - def _show_workflow_output(self): - self.workflows_output = WorkChainOutputs(self.node) - - self.result_tabs.children[0].children = [ - self.workflows_summary, - self.workflows_output, - ] diff --git a/aiidalab_qe/app/static/__init__.py b/src/aiidalab_qe/app/static/__init__.py similarity index 100% rename from aiidalab_qe/app/static/__init__.py rename to src/aiidalab_qe/app/static/__init__.py diff --git a/aiidalab_qe/app/static/style.css b/src/aiidalab_qe/app/static/style.css similarity index 100% rename from aiidalab_qe/app/static/style.css rename to src/aiidalab_qe/app/static/style.css diff --git a/aiidalab_qe/app/static/welcome.jinja b/src/aiidalab_qe/app/static/welcome.jinja similarity index 100% rename from aiidalab_qe/app/static/welcome.jinja rename to src/aiidalab_qe/app/static/welcome.jinja diff --git a/aiidalab_qe/app/static/workflow_failure.jinja b/src/aiidalab_qe/app/static/workflow_failure.jinja similarity index 100% rename from aiidalab_qe/app/static/workflow_failure.jinja rename to src/aiidalab_qe/app/static/workflow_failure.jinja diff --git a/aiidalab_qe/app/static/workflow_summary.jinja b/src/aiidalab_qe/app/static/workflow_summary.jinja similarity index 100% rename from aiidalab_qe/app/static/workflow_summary.jinja rename to src/aiidalab_qe/app/static/workflow_summary.jinja diff --git a/aiidalab_qe/app/structures.py b/src/aiidalab_qe/app/structure/__init__.py similarity index 100% rename from aiidalab_qe/app/structures.py rename to src/aiidalab_qe/app/structure/__init__.py diff --git a/src/aiidalab_qe/app/submission/__init__.py b/src/aiidalab_qe/app/submission/__init__.py new file mode 100644 index 00000000..7efb0473 --- /dev/null +++ b/src/aiidalab_qe/app/submission/__init__.py @@ -0,0 +1,672 @@ +# -*- coding: utf-8 -*- +"""Widgets for the submission of bands work chains. + +Authors: AiiDAlab team +""" +from __future__ import annotations + +import os +import typing as t +from dataclasses import dataclass + +import ipywidgets as ipw +import traitlets as tl +from aiida import orm +from aiida.common import NotExistent +from aiida.engine import ProcessBuilderNamespace, submit +from aiida_quantumespresso.common.types import ElectronicType, RelaxType, SpinType +from aiidalab_widgets_base import ComputationalResourcesWidget, WizardAppWidgetStep +from IPython.display import display + +from aiidalab_qe.app.common.setup_codes import QESetupWidget +from aiidalab_qe.app.configuration.advanced import AdvancedSettings +from aiidalab_qe.app.configuration.pseudos import PseudoFamilySelector +from aiidalab_qe.app.configuration.workflow import WorkChainSettings +from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS +from aiidalab_qe.workflows import QeAppWorkChain + +from .resource import ParallelizationSettings, ResourceSelectionWidget +from .sssp import SSSPInstallWidget + + +# The static input parameters for the QE App WorkChain +# The dataclass does not include codes and structure which will be set +# from widgets separately. +# Relax type, electronic type, spin type, are str because they are used also +# for serialized input of extras attributes of the workchain +@dataclass(frozen=True) +class QeWorkChainParameters: + protocol: str + relax_type: str + properties: t.List[str] + spin_type: str + electronic_type: str + overrides: t.Dict[str, t.Any] + initial_magnetic_moments: t.Dict[str, float] + + +PROTOCOL_PSEUDO_MAP = { + "fast": "SSSP/1.2/PBE/efficiency", + "moderate": "SSSP/1.2/PBE/efficiency", + "precise": "SSSP/1.2/PBE/precision", +} + + +class SubmitQeAppWorkChainStep(ipw.VBox, WizardAppWidgetStep): + """Step for submission of a bands workchain.""" + + codes_title = ipw.HTML( + """
    +

    Codes

    """ + ) + codes_help = ipw.HTML( + """
    Select the code to use for running the calculations. The codes + on the local machine (localhost) are installed by default, but you can + configure new ones on potentially more powerful machines by clicking on + "Setup new code".
    """ + ) + + # This number provides a rough estimate for how many MPI tasks are needed + # for a given structure. + NUM_SITES_PER_MPI_TASK_DEFAULT = 6 + + # Warn the user if they are trying to run calculations for a large + # structure on localhost. + RUN_ON_LOCALHOST_NUM_SITES_WARN_THRESHOLD = 10 + + # Put a limit on how many MPI tasks you want to run per k-pool by default + MAX_MPI_PER_POOL = 20 + + input_structure = tl.Instance(orm.StructureData, allow_none=True) + process = tl.Instance(orm.WorkChainNode, allow_none=True) + previous_step_state = tl.UseEnum(WizardAppWidgetStep.State) + workchain_settings = tl.Instance(WorkChainSettings, allow_none=True) + pseudo_family_selector = tl.Instance(PseudoFamilySelector, allow_none=True) + advanced_settings = tl.Instance(AdvancedSettings, allow_none=True) + _submission_blockers = tl.List(tl.Unicode()) + + def __init__(self, qe_auto_setup=True, **kwargs): + self.message_area = ipw.Output() + self._submission_blocker_messages = ipw.HTML() + + self.pw_code = ComputationalResourcesWidget( + description="pw.x:", default_calc_job_plugin="quantumespresso.pw" + ) + self.dos_code = ComputationalResourcesWidget( + description="dos.x:", + default_calc_job_plugin="quantumespresso.dos", + ) + self.projwfc_code = ComputationalResourcesWidget( + description="projwfc.x:", + default_calc_job_plugin="quantumespresso.projwfc", + ) + + self.resources_config = ResourceSelectionWidget() + self.parallelization = ParallelizationSettings() + + self.set_selected_codes(DEFAULT_PARAMETERS) + self.set_resource_defaults() + + self.pw_code.observe(self._update_state, "value") + self.pw_code.observe(self._update_resources, "value") + self.dos_code.observe(self._update_state, "value") + self.projwfc_code.observe(self._update_state, "value") + + self.submit_button = ipw.Button( + description="Submit", + tooltip="Submit the calculation with the selected parameters.", + icon="play", + button_style="success", + layout=ipw.Layout(width="auto", flex="1 1 auto"), + disabled=True, + ) + + self.submit_button.on_click(self._on_submit_button_clicked) + + # The SSSP installation status widget shows the installation status of + # the SSSP pseudo potentials and triggers the installation in case that + # they are not yet installed. The widget will remain in a "busy" state + # in case that the installation was already triggered elsewhere, e.g., + # by the start up scripts. The submission is blocked while the + # potentials are not yet installed. + self.sssp_installation_status = SSSPInstallWidget(auto_start=qe_auto_setup) + self.sssp_installation_status.observe(self._update_state, ["busy", "installed"]) + self.sssp_installation_status.observe(self._toggle_install_widgets, "installed") + + # The QE setup widget checks whether there are codes that match specific + # expected labels (e.g. "pw-7.2@localhost") and triggers both the + # installation of QE into a dedicated conda environment and the setup of + # the codes in case that they are not already configured. + self.qe_setup_status = QESetupWidget(auto_start=qe_auto_setup) + self.qe_setup_status.observe(self._update_state, "busy") + self.qe_setup_status.observe(self._toggle_install_widgets, "installed") + self.qe_setup_status.observe(self._auto_select_code, "installed") + + super().__init__( + children=[ + self.codes_title, + self.codes_help, + self.pw_code, + self.dos_code, + self.projwfc_code, + self.resources_config, + self.parallelization, + self.message_area, + self.sssp_installation_status, + self.qe_setup_status, + self._submission_blocker_messages, + self.submit_button, + ] + ) + + @tl.observe("_submission_blockers") + def _observe_submission_blockers(self, change): + if change["new"]: + fmt_list = "\n".join((f"
  • {item}
  • " for item in sorted(change["new"]))) + self._submission_blocker_messages.value = f""" +
    + The submission is blocked, due to the following reason(s): +
    """ + else: + self._submission_blocker_messages.value = "" + + def _identify_submission_blockers(self): + # Do not submit while any of the background setup processes are running. + if self.qe_setup_status.busy or self.sssp_installation_status.busy: + yield "Background setup processes must finish." + + # No code selected (this is ignored while the setup process is running). + if self.pw_code.value is None and not self.qe_setup_status.busy: + yield ("No pw code selected") + + # No code selected for pdos (this is ignored while the setup process is running). + if ( + self.workchain_settings.pdos_run.value + and (self.dos_code.value is None or self.projwfc_code.value is None) + and not self.qe_setup_status.busy + ): + yield "Calculating the PDOS requires both dos.x and projwfc.x to be set." + + # SSSP library not installed + if not self.sssp_installation_status.installed: + yield "The SSSP library is not installed." + + if ( + self.workchain_settings.pdos_run.value + and not any( + [ + self.pw_code.value is None, + self.dos_code.value is None, + self.projwfc_code.value is None, + ] + ) + and len( + set( + ( + orm.load_code(self.pw_code.value).computer.pk, + orm.load_code(self.dos_code.value).computer.pk, + orm.load_code(self.projwfc_code.value).computer.pk, + ) + ) + ) + != 1 + ): + yield ( + "All selected codes must be installed on the same computer. This is because the " + "PDOS calculations rely on large files that are not retrieved by AiiDA." + ) + + def _update_state(self, _=None): + # If the previous step has failed, this should fail as well. + if self.previous_step_state is self.State.FAIL: + self.state = self.State.FAIL + return + # Do not interact with the user if they haven't successfully completed the previous step. + elif self.previous_step_state is not self.State.SUCCESS: + self.state = self.State.INIT + return + + # Process is already running. + if self.process is not None: + self.state = self.State.SUCCESS + return + + blockers = list(self._identify_submission_blockers()) + if any(blockers): + self._submission_blockers = blockers + self.state = self.State.READY + return + + self._submission_blockers = [] + self.state = self.state.CONFIGURED + + def _toggle_install_widgets(self, change): + if change["new"]: + self.children = [ + child for child in self.children if child is not change["owner"] + ] + + def _auto_select_code(self, change): + if change["new"] and not change["old"]: + for code in [ + "pw_code", + "dos_code", + "projwfc_code", + ]: + try: + code_widget = getattr(self, code) + code_widget.refresh() + code_widget.value = orm.load_code(DEFAULT_PARAMETERS[code]).uuid + except NotExistent: + pass + + _ALERT_MESSAGE = """ +
    + × + × + {message} +
    """ + + def _show_alert_message(self, message, alert_class="info"): + with self.message_area: + display( + ipw.HTML( + self._ALERT_MESSAGE.format(alert_class=alert_class, message=message) + ) + ) + + def _update_resources(self, change): + if change["new"] and ( + change["old"] is None + or orm.load_code(change["new"]).computer.pk + != orm.load_code(change["old"]).computer.pk + ): + self.set_resource_defaults(orm.load_code(change["new"]).computer) + + def set_resource_defaults(self, computer=None): + if computer is None or computer.hostname == "localhost": + self.resources_config.num_nodes.disabled = True + self.resources_config.num_nodes.value = 1 + self.resources_config.num_cpus.max = os.cpu_count() + self.resources_config.num_cpus.value = 1 + self.resources_config.num_cpus.description = "CPUs" + self.parallelization.npools.value = 1 + else: + default_mpiprocs = computer.get_default_mpiprocs_per_machine() + self.resources_config.num_nodes.disabled = False + self.resources_config.num_cpus.max = default_mpiprocs + self.resources_config.num_cpus.value = default_mpiprocs + self.resources_config.num_cpus.description = "CPUs/node" + self.parallelization.npools.value = self._get_default_parallelization() + + self._check_resources() + + def _get_default_parallelization(self): + """A _very_ rudimentary approach for obtaining a minimal npools setting.""" + num_mpiprocs = ( + self.resources_config.num_nodes.value * self.resources_config.num_cpus.value + ) + + for i in range(1, num_mpiprocs + 1): + if num_mpiprocs % i == 0 and num_mpiprocs // i < self.MAX_MPI_PER_POOL: + return i + + def _check_resources(self): + """Check whether the currently selected resources will be sufficient and warn if not.""" + if not self.pw_code.value: + return # No code selected, nothing to do. + + num_cpus = self.resources_config.num_cpus.value + on_localhost = ( + orm.load_node(self.pw_code.value).computer.hostname == "localhost" + ) + if self.pw_code.value and on_localhost and num_cpus > 1: + self._show_alert_message( + "The selected code would be executed on the local host, but " + "the number of CPUs is larger than one. Please review " + "the configuration and consider to select a code that runs " + "on a larger system if necessary.", + alert_class="warning", + ) + elif ( + self.input_structure + and on_localhost + and len(self.input_structure.sites) + > self.RUN_ON_LOCALHOST_NUM_SITES_WARN_THRESHOLD + ): + self._show_alert_message( + "The selected code would be executed on the local host, but the " + "number of sites of the selected structure is relatively large. " + "Consider to select a code that runs on a larger system if " + "necessary.", + alert_class="warning", + ) + + @tl.observe("state") + def _observe_state(self, change): + with self.hold_trait_notifications(): + self.submit_button.disabled = change["new"] != self.State.CONFIGURED + + @tl.observe("previous_step_state") + def _observe_input_structure(self, _): + self._update_state() + self.set_pdos_status() + + @tl.observe("process") + def _observe_process(self, change): + with self.hold_trait_notifications(): + process_node = change["new"] + if process_node is not None: + self.input_structure = process_node.inputs.structure + builder_parameters = process_node.base.extras.get( + "builder_parameters", None + ) + if builder_parameters is not None: + self.set_selected_codes(builder_parameters) + self._update_state() + + def _on_submit_button_clicked(self, _): + self.submit_button.disabled = True + self.submit() + + def set_selected_codes(self, parameters): + """Set the inputs in the GUI based on a set of parameters.""" + + # Codes + def _get_code_uuid(code): + if code is not None: + try: + return orm.load_code(code).uuid + except NotExistent: + return None + + with self.hold_trait_notifications(): + # Codes + self.pw_code.value = _get_code_uuid(parameters["pw_code"]) + self.dos_code.value = _get_code_uuid(parameters["dos_code"]) + self.projwfc_code.value = _get_code_uuid(parameters["projwfc_code"]) + + def set_pdos_status(self): + if self.workchain_settings.pdos_run.value: + self.dos_code.code_select_dropdown.disabled = False + self.projwfc_code.code_select_dropdown.disabled = False + else: + self.dos_code.code_select_dropdown.disabled = True + self.projwfc_code.code_select_dropdown.disabled = True + + def submit(self, _=None): + """Submit the work chain with the current inputs.""" + builder = self._create_builder() + extra_parameters = self._create_extra_report_parameters() + + with self.hold_trait_notifications(): + self.process = submit(builder) + + # Set the builder parameters on the work chain + builder_parameters = self._extract_report_parameters( + builder, extra_parameters + ) + self.process.base.extras.set("builder_parameters", builder_parameters) + + self._update_state() + + def _get_qe_workchain_parameters(self) -> QeWorkChainParameters: + """Get the parameters of the `QeWorkChain` from widgets.""" + # create the the initial_magnetic_moments as None (Default) + initial_magnetic_moments = None + # create the override parameters for sub PwBaseWorkChain + pw_overrides = {"base": {}, "scf": {}, "nscf": {}, "band": {}} + for key in ["base", "scf", "nscf", "band"]: + if self.pseudo_family_selector.override_protocol_pseudo_family.value: + pw_overrides[key]["pseudo_family"] = self.pseudo_family_selector.value + if self.advanced_settings.override.value: + pw_overrides[key]["pw"] = {"parameters": {"SYSTEM": {}}} + if self.advanced_settings.tot_charge.override.value: + pw_overrides[key]["pw"]["parameters"]["SYSTEM"][ + "tot_charge" + ] = self.advanced_settings.tot_charge.charge.value + if ( + self.advanced_settings.magnetization.override.value + and self.workchain_settings.spin_type.value == "collinear" + ): + initial_magnetic_moments = ( + self.advanced_settings.magnetization.get_magnetization() + ) + + if key in ["base", "scf"]: + if self.advanced_settings.kpoints.override.value: + pw_overrides[key][ + "kpoints_distance" + ] = self.advanced_settings.kpoints.distance.value + if ( + self.advanced_settings.smearing.override.value + and self.workchain_settings.electronic_type.value == "metal" + ): + # smearing type setting + pw_overrides[key]["pw"]["parameters"]["SYSTEM"][ + "smearing" + ] = self.advanced_settings.smearing.smearing.value + + # smearing degauss setting + pw_overrides[key]["pw"]["parameters"]["SYSTEM"][ + "degauss" + ] = self.advanced_settings.smearing.degauss.value + + overrides = { + "relax": { + "base": pw_overrides["base"], + }, + "bands": { + "scf": pw_overrides["scf"], + "bands": pw_overrides["band"], + }, + "pdos": { + "scf": pw_overrides["scf"], + "nscf": pw_overrides["nscf"], + }, + } + + # Work chain settings + relax_type = self.workchain_settings.relax_type.value + electronic_type = self.workchain_settings.electronic_type.value + spin_type = self.workchain_settings.spin_type.value + + run_bands = self.workchain_settings.bands_run.value + run_pdos = self.workchain_settings.pdos_run.value + protocol = self.workchain_settings.workchain_protocol.value + + properties = [] + + if run_bands: + properties.append("bands") + if run_pdos: + properties.append("pdos") + + if RelaxType(relax_type) is not RelaxType.NONE or not (run_bands or run_pdos): + properties.append("relax") + + return QeWorkChainParameters( + protocol=protocol, + relax_type=relax_type, + properties=properties, + spin_type=spin_type, + electronic_type=electronic_type, + overrides=overrides, + initial_magnetic_moments=initial_magnetic_moments, + ) + + def _create_builder(self) -> ProcessBuilderNamespace: + """Create the builder for the `QeAppWorkChain` submit.""" + pw_code = self.pw_code.value + dos_code = self.dos_code.value + projwfc_code = self.projwfc_code.value + + parameters = self._get_qe_workchain_parameters() + + builder = QeAppWorkChain.get_builder_from_protocol( + structure=self.input_structure, + pw_code=orm.load_code(pw_code), + dos_code=orm.load_code(dos_code), + projwfc_code=orm.load_code(projwfc_code), + protocol=parameters.protocol, + relax_type=RelaxType(parameters.relax_type), + properties=parameters.properties, + spin_type=SpinType(parameters.spin_type), + electronic_type=ElectronicType(parameters.electronic_type), + overrides=parameters.overrides, + initial_magnetic_moments=parameters.initial_magnetic_moments, + ) + + resources = { + "num_machines": self.resources_config.num_nodes.value, + "num_mpiprocs_per_machine": self.resources_config.num_cpus.value, + } + + npool = self.parallelization.npools.value + self._update_builder(builder, resources, npool, self.MAX_MPI_PER_POOL) + + return builder + + def _update_builder(self, buildy, resources, npools, max_mpi_per_pool): + """Update the resources and parallelization of the ``QeAppWorkChain`` builder.""" + for k, v in buildy.items(): + if isinstance(v, (dict, ProcessBuilderNamespace)): + if k == "pw" and v["pseudos"]: + v["parallelization"] = orm.Dict(dict={"npool": npools}) + if k == "projwfc": + v["settings"] = orm.Dict(dict={"cmdline": ["-nk", str(npools)]}) + if k == "dos": + v["metadata"]["options"]["resources"] = { + "num_machines": 1, + "num_mpiprocs_per_machine": min( + max_mpi_per_pool, + resources["num_mpiprocs_per_machine"], + ), + } + # Continue to the next item to avoid overriding the resources in the + # recursive `update_builder` call. + continue + if k == "resources": + buildy["resources"] = resources + else: + self._update_builder(v, resources, npools, max_mpi_per_pool) + + def _create_extra_report_parameters(self) -> dict[str, t.Any]: + """This method will also create a dictionary of the parameters that were not + readably represented in the builder, which will be used to the report. + It is stored in the `extra_report_parameters`. + """ + qe_workchain_parameters = self._get_qe_workchain_parameters() + + # Construct the extra report parameters needed for the report + extra_report_parameters = { + "relax_type": qe_workchain_parameters.relax_type, + "electronic_type": qe_workchain_parameters.electronic_type, + "spin_type": qe_workchain_parameters.spin_type, + "protocol": qe_workchain_parameters.protocol, + "initial_magnetic_moments": qe_workchain_parameters.initial_magnetic_moments, + } + + # update pseudo family information to extra_report_parameters + if self.pseudo_family_selector.override_protocol_pseudo_family.value: + # If the pseudo family is overridden, use that + pseudo_family = self.pseudo_family_selector.value + else: + # otherwise extract the information from protocol + pseudo_family = PROTOCOL_PSEUDO_MAP[qe_workchain_parameters.protocol] + + pseudo_family_info = pseudo_family.split("/") + if pseudo_family_info[0] == "SSSP": + pseudo_protocol = pseudo_family_info[3] + elif pseudo_family_info[0] == "PseudoDojo": + pseudo_protocol = pseudo_family_info[4] + extra_report_parameters.update( + { + "pseudo_family": pseudo_family, + "pseudo_library": pseudo_family_info[0], + "pseudo_version": pseudo_family_info[1], + "functional": pseudo_family_info[2], + "pseudo_protocol": pseudo_protocol, + } + ) + + # store codes info into extra_report_parameters for loading the process + pw_code = self.pw_code.value + dos_code = self.dos_code.value + projwfc_code = self.projwfc_code.value + + extra_report_parameters.update( + { + "pw_code": pw_code, + "dos_code": dos_code, + "projwfc_code": projwfc_code, + } + ) + + return extra_report_parameters + + @staticmethod + def _extract_report_parameters( + builder, extra_report_parameters + ) -> dict[str, t.Any]: + """Extract (recover) the parameters for report from the builder. + + There are some parameters that are not stored in the builder, but can be extracted + directly from the widgets, such as the ``pseudo_family`` and ``relax_type``. + """ + parameters = { + "run_relax": "relax" in builder.properties, + "run_bands": "bands" in builder.properties, + "run_pdos": "pdos" in builder.properties, + } + + # Extract the pw calculation parameters from the builder + + # energy_cutoff is same for all pw calculations when pseudopotentials are fixed + # as well as the smearing settings (semaring and degauss) and scf kpoints distance + # read from the first pw calculation of relax workflow. + # It is safe then to extract these parameters from the first pw calculation, since the + # builder is anyway set with subworkchain inputs even it is not run which controlled by + # the properties inputs. + energy_cutoff_wfc = builder.relax.base["pw"]["parameters"]["SYSTEM"]["ecutwfc"] + energy_cutoff_rho = builder.relax.base["pw"]["parameters"]["SYSTEM"]["ecutrho"] + occupation = builder.relax.base["pw"]["parameters"]["SYSTEM"]["occupations"] + scf_kpoints_distance = builder.relax.base.kpoints_distance.value + + parameters.update( + { + "energy_cutoff_wfc": energy_cutoff_wfc, + "energy_cutoff_rho": energy_cutoff_rho, + "occupation": occupation, + "scf_kpoints_distance": scf_kpoints_distance, + } + ) + + if occupation == "smearing": + parameters["degauss"] = builder.relax.base["pw"]["parameters"]["SYSTEM"][ + "degauss" + ] + parameters["smearing"] = builder.relax.base["pw"]["parameters"]["SYSTEM"][ + "smearing" + ] + + parameters[ + "bands_kpoints_distance" + ] = builder.bands.bands_kpoints_distance.value + parameters["nscf_kpoints_distance"] = builder.pdos.nscf.kpoints_distance.value + + parameters["tot_charge"] = builder.relax.base["pw"]["parameters"]["SYSTEM"].get( + "tot_charge", 0.0 + ) + + # parameters from extra_report_parameters + for k, v in extra_report_parameters.items(): + parameters.update({k: v}) + + return parameters + + def reset(self): + with self.hold_trait_notifications(): + self.process = None + self.input_structure = None diff --git a/src/aiidalab_qe/app/submission/resource.py b/src/aiidalab_qe/app/submission/resource.py new file mode 100644 index 00000000..3d3e5fb9 --- /dev/null +++ b/src/aiidalab_qe/app/submission/resource.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +"""Widgets for the submission of bands work chains. + +Authors: AiiDAlab team +""" +import ipywidgets as ipw + + +class ResourceSelectionWidget(ipw.VBox): + """Widget for the selection of compute resources.""" + + title = ipw.HTML( + """
    +

    Resources

    +
    """ + ) + prompt = ipw.HTML( + """
    +

    + Specify the resources to use for the pw.x calculation. +

    """ + ) + + def __init__(self, **kwargs): + extra = { + "style": {"description_width": "150px"}, + "layout": {"min_width": "180px"}, + } + self.num_nodes = ipw.BoundedIntText( + value=1, step=1, min=1, max=1000, description="Nodes", **extra + ) + self.num_cpus = ipw.BoundedIntText( + value=1, step=1, min=1, description="CPUs", **extra + ) + + super().__init__( + children=[ + self.title, + ipw.HBox( + children=[self.prompt, self.num_nodes, self.num_cpus], + layout=ipw.Layout(justify_content="space-between"), + ), + ] + ) + + def reset(self): + self.num_nodes.value = 1 + self.num_cpus.value = 1 + + +class ParallelizationSettings(ipw.VBox): + """Widget for setting the parallelization settings.""" + + title = ipw.HTML( + """
    +

    Parallelization

    +
    """ + ) + prompt = ipw.HTML( + """
    +

    + Specify the number of k-points pools for the calculations. +

    """ + ) + + def __init__(self, **kwargs): + extra = { + "style": {"description_width": "150px"}, + "layout": {"min_width": "180px"}, + } + self.npools = ipw.BoundedIntText( + value=1, step=1, min=1, max=128, description="Number of k-pools", **extra + ) + super().__init__( + children=[ + self.title, + ipw.HBox( + children=[self.prompt, self.npools], + layout=ipw.Layout(justify_content="space-between"), + ), + ] + ) + + def reset(self): + self.npools.value = 1 diff --git a/aiidalab_qe/app/sssp.py b/src/aiidalab_qe/app/submission/sssp.py similarity index 99% rename from aiidalab_qe/app/sssp.py rename to src/aiidalab_qe/app/submission/sssp.py index ea4636a7..eecd55cc 100644 --- a/aiidalab_qe/app/sssp.py +++ b/src/aiidalab_qe/app/submission/sssp.py @@ -10,7 +10,7 @@ from aiida_pseudo.groups.family import PseudoPotentialFamily from filelock import FileLock, Timeout -from aiidalab_qe.app.widgets import ProgressBar +from aiidalab_qe.app.common.widgets import ProgressBar EXPECTED_PSEUDOS = { "SSSP/1.2/PBE/efficiency", diff --git a/aiidalab_qe/version.py b/src/aiidalab_qe/version.py similarity index 100% rename from aiidalab_qe/version.py rename to src/aiidalab_qe/version.py diff --git a/aiidalab_qe/workflows/__init__.py b/src/aiidalab_qe/workflows/__init__.py similarity index 100% rename from aiidalab_qe/workflows/__init__.py rename to src/aiidalab_qe/workflows/__init__.py diff --git a/tests/conftest.py b/tests/conftest.py index d2845950..1d615542 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -137,7 +137,7 @@ def projwfc_code(aiida_local_code_factory): @pytest.fixture() def workchain_settings_generator(): """Return a function that generates a workchain settings dictionary.""" - from aiidalab_qe.app.steps import WorkChainSettings + from aiidalab_qe.app.configuration.workflow import WorkChainSettings def _workchain_settings_generator(**kwargs): workchain_settings = WorkChainSettings() @@ -150,7 +150,7 @@ def _workchain_settings_generator(**kwargs): @pytest.fixture() def initial_magnetic_moments_generator(structure_data_object): """Retturn a function that generatates a initial_magnetic_moments dictionary""" - from aiidalab_qe.app.steps import MagnetizationSettings + from aiidalab_qe.app.configuration.advanced import MagnetizationSettings def _initial_moments_generator(**kwargs): initial_magnetic_moments = MagnetizationSettings() @@ -165,7 +165,7 @@ def _initial_moments_generator(**kwargs): @pytest.fixture() def tot_charge_generator(): """Return a function that generates a tot_charge dictionary.""" - from aiidalab_qe.app.steps import TotalCharge + from aiidalab_qe.app.configuration.advanced import TotalCharge def _tot_charge_generator(**kwargs): tot_charge = TotalCharge() @@ -178,7 +178,7 @@ def _tot_charge_generator(**kwargs): @pytest.fixture() def smearing_settings_generator(): """Return a function that generates a smearing settings dictionary.""" - from aiidalab_qe.app.steps import SmearingSettings + from aiidalab_qe.app.configuration.advanced import SmearingSettings def _smearing_settings_generator(**kwargs): smearing_settings = SmearingSettings() @@ -191,7 +191,7 @@ def _smearing_settings_generator(**kwargs): @pytest.fixture() def kpoints_settings_generator(): """Return a function that generates a kpoints settings dictionary.""" - from aiidalab_qe.app.steps import KpointSettings + from aiidalab_qe.app.configuration.advanced import KpointSettings def _kpoints_settings_generator(**kwargs): kpoints_settings = KpointSettings() @@ -215,8 +215,9 @@ def submit_step_widget_generator( initial_magnetic_moments_generator, ): """Return a function that generates a submit step widget.""" - from aiidalab_qe.app.pseudos import PseudoFamilySelector - from aiidalab_qe.app.steps import AdvancedSettings, SubmitQeAppWorkChainStep + from aiidalab_qe.app.configuration.advanced import AdvancedSettings + from aiidalab_qe.app.configuration.pseudos import PseudoFamilySelector + from aiidalab_qe.app.submission import SubmitQeAppWorkChainStep def _submit_step_widget_generator( relax_type="positions_cell", diff --git a/tests/test_pseudo.py b/tests/test_pseudo.py index b88b4bd6..1a09d9c2 100644 --- a/tests/test_pseudo.py +++ b/tests/test_pseudo.py @@ -1,6 +1,6 @@ def test_pseudos_family_selector_widget(): """Test the pseudos widget.""" - from aiidalab_qe.app.pseudos import PseudoFamilySelector + from aiidalab_qe.app.configuration.pseudos import PseudoFamilySelector wg = PseudoFamilySelector() wg.override_protocol_pseudo_family.value = True diff --git a/tests/test_submit_qe_workchain.py b/tests/test_submit_qe_workchain.py index 5c98a57b..0557054a 100644 --- a/tests/test_submit_qe_workchain.py +++ b/tests/test_submit_qe_workchain.py @@ -1,6 +1,6 @@ def test_reload_selected_code(submit_step_widget_generator): """Test set_selected_codes method.""" - from aiidalab_qe.app.steps import SubmitQeAppWorkChainStep + from aiidalab_qe.app.submission import SubmitQeAppWorkChainStep submit_step = submit_step_widget_generator() @@ -28,7 +28,7 @@ def test_create_builder_default( """ from bs4 import BeautifulSoup - from aiidalab_qe.app.report import _generate_report_html + from aiidalab_qe.app.result.report import _generate_report_html submit_step = submit_step_widget_generator() @@ -63,7 +63,7 @@ def test_create_builder_insulator( the occupation type is set to fixed, smearing and degauss should not be set""" from bs4 import BeautifulSoup - from aiidalab_qe.app.report import _generate_report_html + from aiidalab_qe.app.result.report import _generate_report_html submit_step = submit_step_widget_generator( electronic_type="insulator", @@ -101,7 +101,7 @@ def test_create_builder_advanced_settings( -tot_charge -initial_magnetic_moments """ - from aiidalab_qe.app.report import _generate_report_html + from aiidalab_qe.app.result.report import _generate_report_html submit_step = submit_step_widget_generator( electronic_type="metal",