Skip to content

Commit

Permalink
swap key and value in set_context (#371)
Browse files Browse the repository at this point in the history
* swap key and value in `set_context`

* fix monitor docs

* Rename ToContext to SetContext

* Renae FromContext to GetContext
  • Loading branch information
superstar54 authored Nov 30, 2024
1 parent d23c0dc commit 59a8d77
Show file tree
Hide file tree
Showing 22 changed files with 78 additions and 81 deletions.
23 changes: 10 additions & 13 deletions aiida_workgraph/engine/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ def task_set_context(self, name: str) -> None:
from aiida_workgraph.utils import update_nested_dict

items = self.ctx._tasks[name]["context_mapping"]
for key, value in items.items():
result = self.ctx._tasks[name]["results"][key]
update_nested_dict(self.ctx, value, result)
for key, result_name in items.items():
result = self.ctx._tasks[name]["results"][result_name]
update_nested_dict(self.ctx, key, result)

def get_task_state_info(self, name: str, key: str) -> str:
"""Get task state info from ctx."""
Expand Down Expand Up @@ -213,11 +213,8 @@ def continue_workgraph(self) -> None:
def run_tasks(self, names: List[str], continue_workgraph: bool = True) -> None:
"""Run tasks.
Task type includes: Node, Data, CalcFunction, WorkFunction, CalcJob, WorkChain, GraphBuilder,
WorkGraph, PythonJob, ShellJob, While, If, Zone, FromContext, ToContext, Normal.
WorkGraph, PythonJob, ShellJob, While, If, Zone, GetContext, SetContext, Normal.
Here we use ToContext to pass the results of the run to the next step.
This will force the engine to wait for all the submitted processes to
finish before continuing to the next step.
"""
from aiida_workgraph.utils import (
get_executor,
Expand Down Expand Up @@ -285,10 +282,10 @@ def run_tasks(self, names: List[str], continue_workgraph: bool = True) -> None:
self.execute_if_task(task)
elif task_type == "ZONE":
self.execute_zone_task(task)
elif task_type == "FROM_CONTEXT":
self.execute_from_context_task(task, kwargs)
elif task_type == "TO_CONTEXT":
self.execute_to_context_task(task, kwargs)
elif task_type == "GET_CONTEXT":
self.execute_get_context_task(task, kwargs)
elif task_type == "SET_CONTEXT":
self.execute_set_context_task(task, kwargs)
elif task_type == "AWAITABLE":
self.execute_awaitable_task(
task, executor, args, kwargs, var_args, var_kwargs
Expand Down Expand Up @@ -499,7 +496,7 @@ def execute_zone_task(self, task):
self.set_task_state_info(name, "state", "RUNNING")
self.continue_workgraph()

def execute_from_context_task(self, task, kwargs):
def execute_get_context_task(self, task, kwargs):
# get the results from the context
name = task["name"]
results = {"result": getattr(self.ctx, kwargs["key"])}
Expand All @@ -508,7 +505,7 @@ def execute_from_context_task(self, task, kwargs):
self.update_parent_task_state(name)
self.continue_workgraph()

def execute_to_context_task(self, task, kwargs):
def execute_set_context_task(self, task, kwargs):
name = task["name"]
# get the results from the context
setattr(self.ctx, kwargs["key"], kwargs["value"])
Expand Down
8 changes: 5 additions & 3 deletions aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ def to_dict(self, short: bool = False) -> Dict[str, Any]:
return tdata

def set_context(self, context: Dict[str, Any]) -> None:
"""Update the context mappings for this task."""
# all keys should belong to the outputs.keys()
remain_keys = set(context.keys()).difference(self.outputs.keys())
"""Set the output of the task as a value in the context.
key is the context key, value is the output key.
"""
# all values should belong to the outputs.keys()
remain_keys = set(context.values()).difference(self.outputs.keys())
if remain_keys:
msg = f"Keys {remain_keys} are not in the outputs of this task."
raise ValueError(msg)
Expand Down
20 changes: 10 additions & 10 deletions aiida_workgraph/tasks/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ def create_sockets(self) -> None:
self.outputs.new("workgraph.any", "result")


class ToContext(Task):
"""ToContext"""
class SetContext(Task):
"""SetContext"""

identifier = "workgraph.to_context"
name = "ToContext"
node_type = "TO_CONTEXT"
identifier = "workgraph.set_context"
name = "SetContext"
node_type = "SET_CONTEXT"
catalog = "Control"
args = ["key", "value"]

Expand All @@ -114,12 +114,12 @@ def create_sockets(self) -> None:
self.outputs.new("workgraph.any", "_wait")


class FromContext(Task):
"""FromContext"""
class GetContext(Task):
"""GetContext"""

identifier = "workgraph.from_context"
name = "FromContext"
node_type = "FROM_CONTEXT"
identifier = "workgraph.get_context"
name = "GetContext"
node_type = "GET_CONTEXT"
catalog = "Control"
args = ["key"]

Expand Down
2 changes: 1 addition & 1 deletion docs/gallery/howto/autogen/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def generator_loop(nb_iterations: Int):
wg = WorkGraph()
for i in range(nb_iterations.value): # this can be chosen as wanted
generator_task = wg.add_task(generator, name=f"generator{i}", seed=Int(i))
generator_task.set_context({"result": f"generated.seed{i}"})
generator_task.set_context({f"generated.seed{i}": "result"})
return wg


Expand Down
4 changes: 2 additions & 2 deletions docs/gallery/howto/autogen/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def for_loop(nb_iterations: Int):
# of the graph builder decorator.

# Put result of the task to the context under the name task_out
task.set_context({"result": "task_out"})
task.set_context({"task_out": "result"})
# If want to know more about the usage of the context please refer to the
# context howto in the documentation
return wg
Expand Down Expand Up @@ -244,7 +244,7 @@ def if_then_else(i: Int):
task = wg.add_task(modulo_two, x=i)

# same concept as before, please read the for loop example for explanation
task.set_context({"result": "task_out"})
task.set_context({"task_out": "result"})
return wg


Expand Down
2 changes: 1 addition & 1 deletion docs/gallery/howto/autogen/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def multiply_parallel_gather(X, y):
multiply1 = wg.add_task(multiply, x=value, y=y)
# add result of multiply1 to `self.context.mul`
# self.context.mul is a dict {"a": value1, "b": value2, "c": value3}
multiply1.set_context({"result": f"mul.{key}"})
multiply1.set_context({f"mul.{key}": "result"})
return wg


Expand Down
2 changes: 1 addition & 1 deletion docs/gallery/tutorial/autogen/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def all_scf(structures, scf_inputs):
pw1 = wg.add_task(PwCalculation, name=f"pw1_{key}", structure=structure)
pw1.set(scf_inputs)
# save the output parameters to the context
pw1.set_context({"output_parameters": f"result.{key}"})
pw1.set_context({f"result.{key}": "output_parameters"})
return wg


Expand Down
2 changes: 1 addition & 1 deletion docs/gallery/tutorial/autogen/zero_to_hero.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def all_scf(structures, scf_inputs):
pw1 = wg.add_task(PwCalculation, name=f"pw1_{key}", structure=structure)
pw1.set(scf_inputs)
# save the output parameters to the context
pw1.set_context({"output_parameters": f"result.{key}"})
pw1.set_context({f"result.{key}": "output_parameters"})
return wg


Expand Down
2 changes: 1 addition & 1 deletion docs/source/built-in/pythonjob.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1511,7 +1511,7 @@
" emt1 = wg.add_task(\"PythonJob\", function=emt, name=f\"emt1_{key}\", atoms=atoms)\n",
" emt1.set({\"computer\": \"localhost\"})\n",
" # save the output parameters to the context\n",
" emt1.set_context({\"result\": f\"results.{key}\"})\n",
" emt1.set_context({f\"results.{key}\": \"result\"})\n",
" return wg\n",
"\n",
"\n",
Expand Down
32 changes: 16 additions & 16 deletions docs/source/howto/context.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
"metadata": {},
"source": [
"## Introduction\n",
"In AiiDA workflow, the context is a internal data container that can hold and pass information between steps. It is very usefull for complex workflows.\n",
"In AiiDA workflow, the context is a internal container that can hold data that shared between different tasks. It is very usefull for complex workflows.\n",
"\n",
"## Pass data to context\n",
"\n",
"There are three ways to pass data to context.\n",
"There are three ways to set data to context.\n",
"\n",
"- Initialize the context data when creating the WorkGraph.\n",
" ```python\n",
Expand All @@ -27,30 +27,30 @@
" wg.context = {\"x\": Int(2), \"data.y\": Int(3)}\n",
" ```\n",
"\n",
"- Export the task result to context.\n",
"- Set the task result to context when the task is done.\n",
" ```python\n",
" # define add task\n",
" @task.calcfunction()\n",
" def add(x, y):\n",
" return x + y\n",
" add1 = wg.add_task(add, \"add1\", x=2, y=3)\n",
" # set result of add1 to context.sum\n",
" add1.set_context({\"result\": \"sum\"})\n",
" add1.set_context({\"sum\": \"result\"})\n",
" ```\n",
"\n",
"- Use the `to_context` task to save the result to context.\n",
"- Use the `set_context` task to set either the task result or a constant value to the context.\n",
"\n",
" ```python\n",
" wg.add_task(\"workgraph.to_context\", name=\"to_ctx1\", key=\"sum\", value=add1.outputs[\"result\"])\n",
" wg.add_task(\"workgraph.set_context\", name=\"set_ctx1\", key=\"sum\", value=add1.outputs[\"result\"])\n",
" ```\n",
"\n",
"\n",
"### Nested context keys\n",
"To organize the context data (e.g. group data), The keys may contain dots `.`, which will creating dictionary in the context. Here is an example, to group the results of all add tasks to `context.sum`:\n",
"\n",
"```python\n",
"add1.set_context({\"result\": \"sum.add1\"})\n",
"add2.set_context({\"result\": \"sum.add2\"})\n",
"add1.set_context({\"sum.add1\": \"result\"})\n",
"add2.set_context({\"sum.add2\": \"result\"})\n",
"```\n",
"here, `context.sum` will be:\n",
"```python\n",
Expand All @@ -75,14 +75,14 @@
" nt = WorkGraph(\"while_workgraph\")\n",
" add1 = wg.add_task(add, x=2, y=3)\n",
" add2 = wg.add_task(add, x=2, y=3)\n",
" add1.set_context({\"result\": \"sum.add1\"})\n",
" add2.set_context({\"result\": \"sum.add2\"})\n",
" add1.set_context({\"sum.add1\": \"result\"})\n",
" add2.set_context({\"sum.add2\": \"result\"})\n",
" ```\n",
"\n",
"- One can use the `from_context` task to get the data from context. **This task will be shown in the GUI**\n",
"- One can use the `get_context` task to get the data from context. **This task will be shown in the GUI**\n",
"\n",
" ```python\n",
" wg.add_task(\"workgraph.from_context\", name=\"from_ctx1\", key=\"sum.add1\")\n",
" wg.add_task(\"workgraph.get_context\", name=\"get_ctx1\", key=\"sum.add1\")\n",
" ```"
]
},
Expand Down Expand Up @@ -136,10 +136,10 @@
"wg = WorkGraph(name=\"test_workgraph_ctx\")\n",
"# Set the context of the workgraph\n",
"wg.context = {\"x\": 2, \"data.y\": 3}\n",
"from_ctx1 = wg.add_task(\"workgraph.from_context\", name=\"from_ctx1\", key=\"x\")\n",
"add1 = wg.add_task(add, \"add1\", x=from_ctx1.outputs[\"result\"],\n",
"get_ctx1 = wg.add_task(\"workgraph.get_context\", name=\"get_ctx1\", key=\"x\")\n",
"add1 = wg.add_task(add, \"add1\", x=get_ctx1.outputs[\"result\"],\n",
" y=\"{{data.y}}\")\n",
"to_ctx1 = wg.add_task(\"workgraph.to_context\", name=\"to_ctx1\", key=\"x\",\n",
"set_ctx1 = wg.add_task(\"workgraph.set_context\", name=\"set_ctx1\", key=\"x\",\n",
" value=add1.outputs[\"result\"])\n",
"wg.to_html()\n",
"# wg"
Expand All @@ -150,7 +150,7 @@
"id": "f6969061",
"metadata": {},
"source": [
"As shown in the GUI, the `from_context` task and `to_context` tasks are shown in the GUI. However, the context variable using the `set_context` method or `{{}}` is not shown in the GUI."
"As shown in the GUI, the `get_context` task and `to_context` tasks are shown in the GUI. However, the context variable using the `set_context` method or `{{}}` is not shown in the GUI."
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/source/howto/for.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
" multiply1 = wg.add_task(multiply, name=\"multiply1\", x=\"{{ i }}\", y=2)\n",
" add1 = wg.add_task(add, name=\"add1\", x=\"{{ total }}\")\n",
" # update the context variable\n",
" add1.set_context({\"result\": \"total\"})\n",
" add1.set_context({\"total\": \"result\"})\n",
" wg.add_link(multiply1.outputs[\"result\"], add1.inputs[\"y\"])\n",
" # don't forget to return the workgraph\n",
" return wg"
Expand Down
2 changes: 1 addition & 1 deletion docs/source/howto/html/test_workgraph_ctx.html
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
const { RenderUtils } = ReteRenderUtils;
const styled = window.styled;

const workgraphData = {"name": "test_workgraph_ctx", "uuid": "50bec156-5a60-11ef-888c-906584de3e5b", "state": "CREATED", "nodes": {"from_ctx1": {"label": "from_ctx1", "node_type": "FROM_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50c8610c-5a60-11ef-888c-906584de3e5b", "node_uuid": "50c85f5e-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "50d07e78-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "from_ctx1", "from_socket": "result", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "50d07edc-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "x"}], "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "to_ctx1": {"label": "to_ctx1", "node_type": "TO_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50d858be-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value", "identifier": "workgraph.any", "uuid": "50d8599a-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add1", "from_socket": "result", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value"}], "outputs": [], "position": [90, 90], "children": []}}, "links": [{"from_socket": "result", "from_node": "from_ctx1", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b", "to_socket": "x", "to_node": "add1", "state": false}, {"from_socket": "result", "from_node": "add1", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b", "to_socket": "value", "to_node": "to_ctx1", "state": false}]}
const workgraphData = {"name": "test_workgraph_ctx", "uuid": "50bec156-5a60-11ef-888c-906584de3e5b", "state": "CREATED", "nodes": {"get_ctx1": {"label": "get_ctx1", "node_type": "GET_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50c8610c-5a60-11ef-888c-906584de3e5b", "node_uuid": "50c85f5e-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "50d07e78-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "get_ctx1", "from_socket": "result", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "50d07edc-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "x"}], "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "set_ctx1": {"label": "set_ctx1", "node_type": "SET_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50d858be-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value", "identifier": "workgraph.any", "uuid": "50d8599a-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add1", "from_socket": "result", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value"}], "outputs": [], "position": [90, 90], "children": []}}, "links": [{"from_socket": "result", "from_node": "get_ctx1", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b", "to_socket": "x", "to_node": "add1", "state": false}, {"from_socket": "result", "from_node": "add1", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b", "to_socket": "value", "to_node": "set_ctx1", "state": false}]}

// Define Schemes to use in vanilla JS
const Schemes = {
Expand Down
8 changes: 3 additions & 5 deletions docs/source/howto/monitor.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,10 @@
"monitor2 = wg.add_task(\"workgraph.file_monitor\", filepath=\"/tmp/test.txt\")\n",
"```\n",
"\n",
"## Awaitable Task Decorator\n",
"### Awaitable Task Decorator\n",
"\n",
"The `awaitable` decorator allows for the integration of `asyncio` within tasks, letting users control asynchronous functions.\n",
"\n",
"### General Awaitable Task\n",
"\n",
"Define and use an awaitable task within the WorkGraph.\n",
"\n"
]
Expand Down Expand Up @@ -204,7 +202,7 @@
"id": "1ae83d3f",
"metadata": {},
"source": [
"## Kill the monitor task\n",
"### Kill the monitor task\n",
"\n",
"One can kill a running monitor task by using the following command:\n",
"\n",
Expand All @@ -220,7 +218,7 @@
"\n",
"The awaitable task lets the WorkGraph enter a `Waiting` state, yielding control to the asyncio event loop. This enables other tasks to run concurrently, although long-running calculations may delay the execution of awaitable tasks.\n",
"\n",
"## Conclusion\n",
"### Conclusion\n",
"\n",
"These enhancements provide powerful tools for managing dependencies and asynchronous operations within WorkGraph, offering greater flexibility and efficiency in task execution."
]
Expand Down
2 changes: 1 addition & 1 deletion docs/source/howto/parallel.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@
" multiply1 = wg.add_task(multiply, x=value, y=y)\n",
" # add result of multiply1 to `self.context.mul`\n",
" # self.context.mul is a dict {\"a\": value1, \"b\": value2, \"c\": value3}\n",
" multiply1.set_context({\"result\": f\"mul.{key}\"})\n",
" multiply1.set_context({f\"mul.{key}\": \"result\"})\n",
" return wg\n",
"\n",
"@task.calcfunction()\n",
Expand Down
Loading

0 comments on commit 59a8d77

Please sign in to comment.