Skip to content

Commit

Permalink
Serialize the WorkGraph's context (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 authored Aug 27, 2024
1 parent cff769e commit e000261
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
1 change: 1 addition & 0 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ def read_wgdata_from_base(self) -> t.Dict[str, t.Any]:
if isinstance(prop["value"], PickledLocalFunction):
prop["value"] = prop["value"].value
wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"])
wgdata["context"] = deserialize_unsafe(wgdata["context"])
return wgdata

def update_workgraph_from_base(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions aiida_workgraph/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def insert_workgraph_to_db(self) -> None:
# nodes is a copy of tasks, so we need to pop it out
self.wgdata.pop("nodes")
self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"])
self.wgdata["context"] = serialize(self.wgdata["context"])
self.process.base.extras.set("_workgraph", self.wgdata)

def save_task_states(self) -> Dict:
Expand Down
9 changes: 7 additions & 2 deletions tests/test_ctx.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from aiida_workgraph import WorkGraph
from typing import Callable
from aiida.orm import Float
from aiida.orm import Float, ArrayData
import numpy as np


def test_workgraph_ctx(decorated_add: Callable) -> None:
"""Set/get data to/from context."""

wg = WorkGraph(name="test_workgraph_ctx")
wg.context = {"x": Float(2), "data.y": Float(3)}
# create a array data object to test if it can be set to context
# the workgraph should be able to serialize it
array = ArrayData()
array.set_array("matrix", np.array([[1, 2], [3, 4]]))
wg.context = {"x": Float(2), "data.y": Float(3), "array": array}
add1 = wg.add_task(decorated_add, "add1", x="{{ x }}", y="{{ data.y }}")
wg.add_task(
"workgraph.to_context", name="to_ctx1", key="x", value=add1.outputs["result"]
Expand Down

0 comments on commit e000261

Please sign in to comment.