diff --git a/src/aiida_workgraph/config.py b/src/aiida_workgraph/config.py index b97fb8cd..519747ec 100644 --- a/src/aiida_workgraph/config.py +++ b/src/aiida_workgraph/config.py @@ -1,6 +1,8 @@ import json from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER +WORKGRAPH_EXTRA_KEY = "_workgraph" + def load_config() -> dict: """Load the configuration from the config file.""" diff --git a/src/aiida_workgraph/engine/workgraph.py b/src/aiida_workgraph/engine/workgraph.py index 5b79c27a..6cfe0635 100644 --- a/src/aiida_workgraph/engine/workgraph.py +++ b/src/aiida_workgraph/engine/workgraph.py @@ -308,8 +308,9 @@ def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None: def read_wgdata_from_base(self) -> t.Dict[str, t.Any]: """Read workgraph data from base.extras.""" from aiida_workgraph.orm.function_data import PickledLocalFunction + from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY - wgdata = self.node.base.extras.get("_workgraph") + wgdata = self.node.base.extras.get(WORKGRAPH_EXTRA_KEY) for name, task in wgdata["tasks"].items(): wgdata["tasks"][name] = deserialize_unsafe(task) for _, input in wgdata["tasks"][name]["inputs"].items(): diff --git a/src/aiida_workgraph/utils/__init__.py b/src/aiida_workgraph/utils/__init__.py index 55d2fb2e..9b6e91d6 100644 --- a/src/aiida_workgraph/utils/__init__.py +++ b/src/aiida_workgraph/utils/__init__.py @@ -322,10 +322,11 @@ def get_workgraph_data(process: Union[int, orm.Node]) -> Optional[Dict[str, Any] """Get the workgraph data from the process node.""" from aiida.orm.utils.serialize import deserialize_unsafe from aiida.orm import load_node + from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY if isinstance(process, int): process = load_node(process) - wgdata = process.base.extras.get("_workgraph", None) + wgdata = process.base.extras.get(WORKGRAPH_EXTRA_KEY, None) if wgdata is None: return for name, task in wgdata["tasks"].items(): diff --git a/src/aiida_workgraph/utils/analysis.py b/src/aiida_workgraph/utils/analysis.py index 477a2f27..f9d43716 100644 --- a/src/aiida_workgraph/utils/analysis.py +++ b/src/aiida_workgraph/utils/analysis.py @@ -3,6 +3,7 @@ # import datetime from aiida.orm import ProcessNode from aiida.orm.utils.serialize import serialize, deserialize_unsafe +from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY class WorkGraphSaver: @@ -223,7 +224,7 @@ def insert_workgraph_to_db(self) -> None: # nodes is a copy of tasks, so we need to pop it out self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"]) self.wgdata["context"] = serialize(self.wgdata["context"]) - self.process.base.extras.set("_workgraph", self.wgdata) + self.process.base.extras.set(WORKGRAPH_EXTRA_KEY, self.wgdata) def save_task_states(self) -> Dict: """Get task states.""" @@ -277,7 +278,7 @@ def get_wgdata_from_db( ) -> Optional[Dict]: process = self.process if process is None else process - wgdata = process.base.extras.get("_workgraph", None) + wgdata = process.base.extras.get(WORKGRAPH_EXTRA_KEY, None) if wgdata is None: print("No workgraph data found in the process node.") return @@ -318,7 +319,7 @@ def exist_in_db(self) -> bool: Returns: bool: _description_ """ - if self.process.base.extras.get("_workgraph", None) is not None: + if self.process.base.extras.get(WORKGRAPH_EXTRA_KEY, None) is not None: return True return False diff --git a/src/aiida_workgraph/workgraph.py b/src/aiida_workgraph/workgraph.py index aad7b10e..dfd371bf 100644 --- a/src/aiida_workgraph/workgraph.py +++ b/src/aiida_workgraph/workgraph.py @@ -362,8 +362,7 @@ def load(cls, pk: int) -> Optional["WorkGraph"]: process = aiida.orm.load_node(pk) wgdata = get_workgraph_data(process) if wgdata is None: - print("No workgraph data found in the process node.") - return + raise ValueError(f"WorkGraph data not found for process {pk}.") wg = cls.from_dict(wgdata) wg.process = process wg.update() diff --git a/tests/test_error_handler.py b/tests/test_error_handler.py index a7e7d89d..d1eff81f 100644 --- a/tests/test_error_handler.py +++ b/tests/test_error_handler.py @@ -38,6 +38,7 @@ def handle_negative_sum(task: Task): } }, ) + assert len(wg.error_handlers) == 1 wg.submit( inputs={ "add1": {"code": add_code, "x": orm.Int(1), "y": orm.Int(-2)}, diff --git a/tests/test_workgraph.py b/tests/test_workgraph.py index 1fa1366b..f65a7329 100644 --- a/tests/test_workgraph.py +++ b/tests/test_workgraph.py @@ -28,8 +28,29 @@ def test_add_task(): assert len(wg.links) == 1 +def test_show_state(wg_calcfunction): + from io import StringIO + import sys + + # Redirect stdout to capture prints + captured_output = StringIO() + sys.stdout = captured_output + # Call the method + wg_calcfunction.name = "test_show_state" + wg_calcfunction.show() + # Reset stdout + sys.stdout = sys.__stdout__ + # Check the output + output = captured_output.getvalue() + assert "WorkGraph: test_show_state, PK: None, State: CREATED" in output + assert "sumdiff1" in output + assert "PLANNED" in output + + def test_save_load(wg_calcfunction): """Save the workgraph""" + from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY + wg = wg_calcfunction wg.name = "test_save_load" wg.save() @@ -38,6 +59,12 @@ def test_save_load(wg_calcfunction): assert wg.process.label == "test_save_load" wg2 = WorkGraph.load(wg.process.pk) assert len(wg.tasks) == len(wg2.tasks) + # remove the extra + wg.process.base.extras.delete(WORKGRAPH_EXTRA_KEY) + with pytest.raises( + ValueError, match=f"WorkGraph data not found for process {wg.process.pk}." + ): + WorkGraph.load(wg.process.pk) def test_organize_nested_inputs(): @@ -86,7 +113,7 @@ def test_reset_message(wg_calcjob): assert "Action: reset. {'add2'}" in report -def test_restart(wg_calcfunction): +def test_restart_and_reset(wg_calcfunction): """Restart from a finished workgraph. Load the workgraph, modify the task, and restart the workgraph. Only the modified node and its child tasks will be rerun.""" @@ -109,6 +136,10 @@ def test_restart(wg_calcfunction): assert wg1.tasks["sumdiff2"].node.pk != wg.tasks["sumdiff2"].pk assert wg1.tasks["sumdiff3"].node.pk != wg.tasks["sumdiff3"].pk assert wg1.tasks["sumdiff3"].node.outputs.sum == 19 + wg1.reset() + assert wg1.process is None + assert wg1.tasks["sumdiff3"].process is None + assert wg1.tasks["sumdiff3"].state == "PLANNED" def test_extend_workgraph(decorated_add_multiply_group): diff --git a/tests/widget/test_widget.py b/tests/widget/test_widget.py index 3c934a5b..33c2ad7e 100644 --- a/tests/widget/test_widget.py +++ b/tests/widget/test_widget.py @@ -14,6 +14,8 @@ def test_workgraph_widget(wg_calcfunction): # to_html data = wg.to_html() assert isinstance(data, IFrame) + # check _repr_mimebundle_ is working + data = wg._repr_mimebundle_() def test_workgraph_task(wg_calcfunction): @@ -26,3 +28,5 @@ def test_workgraph_task(wg_calcfunction): # to html data = wg.tasks["sumdiff2"].to_html() assert isinstance(data, IFrame) + # check _repr_mimebundle_ is working + data = wg._repr_mimebundle_()