From cae7fed4911d087abca8b90046cc4e65e792db79 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Thu, 5 Sep 2024 18:12:11 +0200 Subject: [PATCH] Avoids deepcopy the input data of the tasks, bump node-graph to 0.0.16(#291) Bump node-graph to 0.0.16. This avoids deepcopy the input data of the tasks --- aiida_workgraph/utils/analysis.py | 14 ++++++++++---- pyproject.toml | 2 +- tests/test_shell.py | 6 ++++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/aiida_workgraph/utils/analysis.py b/aiida_workgraph/utils/analysis.py index 7b17f0e6..0ce58d46 100644 --- a/aiida_workgraph/utils/analysis.py +++ b/aiida_workgraph/utils/analysis.py @@ -214,7 +214,6 @@ def insert_workgraph_to_db(self) -> None: prop["value"] = PickledLocalFunction(prop["value"]).store() self.wgdata["tasks"][name] = serialize(task) # nodes is a copy of tasks, so we need to pop it out - self.wgdata.pop("nodes") self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"]) self.wgdata["context"] = serialize(self.wgdata["context"]) self.process.base.extras.set("_workgraph", self.wgdata) @@ -277,8 +276,6 @@ def get_wgdata_from_db( return for name, task in wgdata["tasks"].items(): wgdata["tasks"][name] = deserialize_unsafe(task) - # also make a alias for nodes - wgdata["nodes"] = wgdata["tasks"] wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"]) return wgdata @@ -294,12 +291,18 @@ def check_diff( from node_graph.analysis import DifferenceAnalysis wg1 = self.get_wgdata_from_db(restart_process) + # change tasks to nodes for DifferenceAnalysis + wg1["nodes"] = wg1.pop("tasks") + self.wgdata["nodes"] = self.wgdata.pop("tasks") dc = DifferenceAnalysis(nt1=wg1, nt2=self.wgdata) ( new_tasks, modified_tasks, update_metadata, ) = dc.build_difference() + # change nodes back to tasks + wg1["tasks"] = wg1.pop("nodes") + self.wgdata["tasks"] = self.wgdata.pop("nodes") return new_tasks, modified_tasks, update_metadata def exist_in_db(self) -> bool: @@ -316,7 +319,10 @@ def build_connectivity(self) -> None: """Analyze the connectivity of workgraph and save it into dict.""" from node_graph.analysis import ConnectivityAnalysis - self.wgdata["nodes"] = self.wgdata["tasks"] + # ConnectivityAnalysis use nodes instead of tasks + self.wgdata["nodes"] = self.wgdata.pop("tasks") nc = ConnectivityAnalysis(self.wgdata) self.wgdata["connectivity"] = nc.build_connectivity() self.wgdata["connectivity"]["zone"] = {} + # change nodes back to tasks + self.wgdata["tasks"] = self.wgdata.pop("nodes") diff --git a/pyproject.toml b/pyproject.toml index b67eee61..0f7df691 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "numpy~=1.21", "scipy", "ase", - "node-graph>=0.0.14", + "node-graph>=0.0.16", "aiida-core>=2.3", "cloudpickle", "aiida-shell", diff --git a/tests/test_shell.py b/tests/test_shell.py index 74e77281..81246d75 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -1,11 +1,11 @@ import pytest from aiida_workgraph import WorkGraph, task from aiida_shell.launch import prepare_code -from aiida.orm import SinglefileData +from aiida.orm import SinglefileData, load_computer @pytest.mark.usefixtures("started_daemon_client") -def test_shell_command(): +def test_shell_command(fixture_localhost): """Test the ShellJob with command as a string.""" wg = WorkGraph(name="test_shell_command") job1 = wg.add_task( @@ -18,6 +18,8 @@ def test_shell_command(): "file_b": SinglefileData.from_string("string b"), }, ) + # also check if we can set the computer explicitly + job1.set({"metadata.computer": load_computer("localhost")}) wg.submit(wait=True) assert job1.node.outputs.stdout.get_content() == "string astring b"