Skip to content

Commit

Permalink
Avoids deepcopy the input data of the tasks, bump node-graph to 0.0.16(
Browse files Browse the repository at this point in the history
…#291)

Bump node-graph to 0.0.16. This avoids deepcopy the input data of the tasks
  • Loading branch information
superstar54 authored Sep 5, 2024
1 parent d9de55a commit cae7fed
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
14 changes: 10 additions & 4 deletions aiida_workgraph/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def insert_workgraph_to_db(self) -> None:
prop["value"] = PickledLocalFunction(prop["value"]).store()
self.wgdata["tasks"][name] = serialize(task)
# nodes is a copy of tasks, so we need to pop it out
self.wgdata.pop("nodes")
self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"])
self.wgdata["context"] = serialize(self.wgdata["context"])
self.process.base.extras.set("_workgraph", self.wgdata)
Expand Down Expand Up @@ -277,8 +276,6 @@ def get_wgdata_from_db(
return
for name, task in wgdata["tasks"].items():
wgdata["tasks"][name] = deserialize_unsafe(task)
# also make a alias for nodes
wgdata["nodes"] = wgdata["tasks"]
wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"])
return wgdata

Expand All @@ -294,12 +291,18 @@ def check_diff(
from node_graph.analysis import DifferenceAnalysis

wg1 = self.get_wgdata_from_db(restart_process)
# change tasks to nodes for DifferenceAnalysis
wg1["nodes"] = wg1.pop("tasks")
self.wgdata["nodes"] = self.wgdata.pop("tasks")
dc = DifferenceAnalysis(nt1=wg1, nt2=self.wgdata)
(
new_tasks,
modified_tasks,
update_metadata,
) = dc.build_difference()
# change nodes back to tasks
wg1["tasks"] = wg1.pop("nodes")
self.wgdata["tasks"] = self.wgdata.pop("nodes")
return new_tasks, modified_tasks, update_metadata

def exist_in_db(self) -> bool:
Expand All @@ -316,7 +319,10 @@ def build_connectivity(self) -> None:
"""Analyze the connectivity of workgraph and save it into dict."""
from node_graph.analysis import ConnectivityAnalysis

self.wgdata["nodes"] = self.wgdata["tasks"]
# ConnectivityAnalysis use nodes instead of tasks
self.wgdata["nodes"] = self.wgdata.pop("tasks")
nc = ConnectivityAnalysis(self.wgdata)
self.wgdata["connectivity"] = nc.build_connectivity()
self.wgdata["connectivity"]["zone"] = {}
# change nodes back to tasks
self.wgdata["tasks"] = self.wgdata.pop("nodes")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"numpy~=1.21",
"scipy",
"ase",
"node-graph>=0.0.14",
"node-graph>=0.0.16",
"aiida-core>=2.3",
"cloudpickle",
"aiida-shell",
Expand Down
6 changes: 4 additions & 2 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest
from aiida_workgraph import WorkGraph, task
from aiida_shell.launch import prepare_code
from aiida.orm import SinglefileData
from aiida.orm import SinglefileData, load_computer


@pytest.mark.usefixtures("started_daemon_client")
def test_shell_command():
def test_shell_command(fixture_localhost):
"""Test the ShellJob with command as a string."""
wg = WorkGraph(name="test_shell_command")
job1 = wg.add_task(
Expand All @@ -18,6 +18,8 @@ def test_shell_command():
"file_b": SinglefileData.from_string("string b"),
},
)
# also check if we can set the computer explicitly
job1.set({"metadata.computer": load_computer("localhost")})
wg.submit(wait=True)
assert job1.node.outputs.stdout.get_content() == "string astring b"

Expand Down

0 comments on commit cae7fed

Please sign in to comment.