diff --git a/aiida_workgraph/engine/task_manager.py b/aiida_workgraph/engine/task_manager.py index 7997dbcc..c8f5153a 100644 --- a/aiida_workgraph/engine/task_manager.py +++ b/aiida_workgraph/engine/task_manager.py @@ -122,9 +122,9 @@ def task_set_context(self, name: str) -> None: from aiida_workgraph.utils import update_nested_dict items = self.ctx._tasks[name]["context_mapping"] - for key, value in items.items(): - result = self.ctx._tasks[name]["results"][key] - update_nested_dict(self.ctx, value, result) + for key, result_name in items.items(): + result = self.ctx._tasks[name]["results"][result_name] + update_nested_dict(self.ctx, key, result) def get_task_state_info(self, name: str, key: str) -> str: """Get task state info from ctx.""" @@ -213,11 +213,8 @@ def continue_workgraph(self) -> None: def run_tasks(self, names: List[str], continue_workgraph: bool = True) -> None: """Run tasks. Task type includes: Node, Data, CalcFunction, WorkFunction, CalcJob, WorkChain, GraphBuilder, - WorkGraph, PythonJob, ShellJob, While, If, Zone, FromContext, ToContext, Normal. + WorkGraph, PythonJob, ShellJob, While, If, Zone, GetContext, SetContext, Normal. - Here we use ToContext to pass the results of the run to the next step. - This will force the engine to wait for all the submitted processes to - finish before continuing to the next step. """ from aiida_workgraph.utils import ( get_executor, @@ -285,10 +282,10 @@ def run_tasks(self, names: List[str], continue_workgraph: bool = True) -> None: self.execute_if_task(task) elif task_type == "ZONE": self.execute_zone_task(task) - elif task_type == "FROM_CONTEXT": - self.execute_from_context_task(task, kwargs) - elif task_type == "TO_CONTEXT": - self.execute_to_context_task(task, kwargs) + elif task_type == "GET_CONTEXT": + self.execute_get_context_task(task, kwargs) + elif task_type == "SET_CONTEXT": + self.execute_set_context_task(task, kwargs) elif task_type == "AWAITABLE": self.execute_awaitable_task( task, executor, args, kwargs, var_args, var_kwargs @@ -499,7 +496,7 @@ def execute_zone_task(self, task): self.set_task_state_info(name, "state", "RUNNING") self.continue_workgraph() - def execute_from_context_task(self, task, kwargs): + def execute_get_context_task(self, task, kwargs): # get the results from the context name = task["name"] results = {"result": getattr(self.ctx, kwargs["key"])} @@ -508,7 +505,7 @@ def execute_from_context_task(self, task, kwargs): self.update_parent_task_state(name) self.continue_workgraph() - def execute_to_context_task(self, task, kwargs): + def execute_set_context_task(self, task, kwargs): name = task["name"] # get the results from the context setattr(self.ctx, kwargs["key"], kwargs["value"]) diff --git a/aiida_workgraph/task.py b/aiida_workgraph/task.py index d8af66eb..cc4e421b 100644 --- a/aiida_workgraph/task.py +++ b/aiida_workgraph/task.py @@ -74,9 +74,11 @@ def to_dict(self, short: bool = False) -> Dict[str, Any]: return tdata def set_context(self, context: Dict[str, Any]) -> None: - """Update the context mappings for this task.""" - # all keys should belong to the outputs.keys() - remain_keys = set(context.keys()).difference(self.outputs.keys()) + """Set the output of the task as a value in the context. + key is the context key, value is the output key. + """ + # all values should belong to the outputs.keys() + remain_keys = set(context.values()).difference(self.outputs.keys()) if remain_keys: msg = f"Keys {remain_keys} are not in the outputs of this task." raise ValueError(msg) diff --git a/aiida_workgraph/tasks/builtins.py b/aiida_workgraph/tasks/builtins.py index 1fe5cee8..4c563a82 100644 --- a/aiida_workgraph/tasks/builtins.py +++ b/aiida_workgraph/tasks/builtins.py @@ -95,12 +95,12 @@ def create_sockets(self) -> None: self.outputs.new("workgraph.any", "result") -class ToContext(Task): - """ToContext""" +class SetContext(Task): + """SetContext""" - identifier = "workgraph.to_context" - name = "ToContext" - node_type = "TO_CONTEXT" + identifier = "workgraph.set_context" + name = "SetContext" + node_type = "SET_CONTEXT" catalog = "Control" args = ["key", "value"] @@ -114,12 +114,12 @@ def create_sockets(self) -> None: self.outputs.new("workgraph.any", "_wait") -class FromContext(Task): - """FromContext""" +class GetContext(Task): + """GetContext""" - identifier = "workgraph.from_context" - name = "FromContext" - node_type = "FROM_CONTEXT" + identifier = "workgraph.get_context" + name = "GetContext" + node_type = "GET_CONTEXT" catalog = "Control" args = ["key"] diff --git a/docs/gallery/howto/autogen/aggregate.py b/docs/gallery/howto/autogen/aggregate.py index 96ef206f..fc3138fc 100644 --- a/docs/gallery/howto/autogen/aggregate.py +++ b/docs/gallery/howto/autogen/aggregate.py @@ -255,7 +255,7 @@ def generator_loop(nb_iterations: Int): wg = WorkGraph() for i in range(nb_iterations.value): # this can be chosen as wanted generator_task = wg.add_task(generator, name=f"generator{i}", seed=Int(i)) - generator_task.set_context({"result": f"generated.seed{i}"}) + generator_task.set_context({f"generated.seed{i}": "result"}) return wg diff --git a/docs/gallery/howto/autogen/graph_builder.py b/docs/gallery/howto/autogen/graph_builder.py index 89ec4cb3..f67d1a04 100644 --- a/docs/gallery/howto/autogen/graph_builder.py +++ b/docs/gallery/howto/autogen/graph_builder.py @@ -197,7 +197,7 @@ def for_loop(nb_iterations: Int): # of the graph builder decorator. # Put result of the task to the context under the name task_out - task.set_context({"result": "task_out"}) + task.set_context({"task_out": "result"}) # If want to know more about the usage of the context please refer to the # context howto in the documentation return wg @@ -244,7 +244,7 @@ def if_then_else(i: Int): task = wg.add_task(modulo_two, x=i) # same concept as before, please read the for loop example for explanation - task.set_context({"result": "task_out"}) + task.set_context({"task_out": "result"}) return wg diff --git a/docs/gallery/howto/autogen/parallel.py b/docs/gallery/howto/autogen/parallel.py index 97ce88a3..240c22b4 100644 --- a/docs/gallery/howto/autogen/parallel.py +++ b/docs/gallery/howto/autogen/parallel.py @@ -94,7 +94,7 @@ def multiply_parallel_gather(X, y): multiply1 = wg.add_task(multiply, x=value, y=y) # add result of multiply1 to `self.context.mul` # self.context.mul is a dict {"a": value1, "b": value2, "c": value3} - multiply1.set_context({"result": f"mul.{key}"}) + multiply1.set_context({f"mul.{key}": "result"}) return wg diff --git a/docs/gallery/tutorial/autogen/eos.py b/docs/gallery/tutorial/autogen/eos.py index 8d88a42a..ae457d6b 100644 --- a/docs/gallery/tutorial/autogen/eos.py +++ b/docs/gallery/tutorial/autogen/eos.py @@ -54,7 +54,7 @@ def all_scf(structures, scf_inputs): pw1 = wg.add_task(PwCalculation, name=f"pw1_{key}", structure=structure) pw1.set(scf_inputs) # save the output parameters to the context - pw1.set_context({"output_parameters": f"result.{key}"}) + pw1.set_context({f"result.{key}": "output_parameters"}) return wg diff --git a/docs/gallery/tutorial/autogen/zero_to_hero.py b/docs/gallery/tutorial/autogen/zero_to_hero.py index 664ca9a2..151e1981 100644 --- a/docs/gallery/tutorial/autogen/zero_to_hero.py +++ b/docs/gallery/tutorial/autogen/zero_to_hero.py @@ -426,7 +426,7 @@ def all_scf(structures, scf_inputs): pw1 = wg.add_task(PwCalculation, name=f"pw1_{key}", structure=structure) pw1.set(scf_inputs) # save the output parameters to the context - pw1.set_context({"output_parameters": f"result.{key}"}) + pw1.set_context({f"result.{key}": "output_parameters"}) return wg diff --git a/docs/source/built-in/pythonjob.ipynb b/docs/source/built-in/pythonjob.ipynb index 88ed8c4d..a7e286ad 100644 --- a/docs/source/built-in/pythonjob.ipynb +++ b/docs/source/built-in/pythonjob.ipynb @@ -1511,7 +1511,7 @@ " emt1 = wg.add_task(\"PythonJob\", function=emt, name=f\"emt1_{key}\", atoms=atoms)\n", " emt1.set({\"computer\": \"localhost\"})\n", " # save the output parameters to the context\n", - " emt1.set_context({\"result\": f\"results.{key}\"})\n", + " emt1.set_context({f\"results.{key}\": \"result\"})\n", " return wg\n", "\n", "\n", diff --git a/docs/source/howto/context.ipynb b/docs/source/howto/context.ipynb index 016e36b0..9ccbca59 100644 --- a/docs/source/howto/context.ipynb +++ b/docs/source/howto/context.ipynb @@ -14,11 +14,11 @@ "metadata": {}, "source": [ "## Introduction\n", - "In AiiDA workflow, the context is a internal data container that can hold and pass information between steps. It is very usefull for complex workflows.\n", + "In AiiDA workflow, the context is a internal container that can hold data that shared between different tasks. It is very usefull for complex workflows.\n", "\n", "## Pass data to context\n", "\n", - "There are three ways to pass data to context.\n", + "There are three ways to set data to context.\n", "\n", "- Initialize the context data when creating the WorkGraph.\n", " ```python\n", @@ -27,7 +27,7 @@ " wg.context = {\"x\": Int(2), \"data.y\": Int(3)}\n", " ```\n", "\n", - "- Export the task result to context.\n", + "- Set the task result to context when the task is done.\n", " ```python\n", " # define add task\n", " @task.calcfunction()\n", @@ -35,13 +35,13 @@ " return x + y\n", " add1 = wg.add_task(add, \"add1\", x=2, y=3)\n", " # set result of add1 to context.sum\n", - " add1.set_context({\"result\": \"sum\"})\n", + " add1.set_context({\"sum\": \"result\"})\n", " ```\n", "\n", - "- Use the `to_context` task to save the result to context.\n", + "- Use the `set_context` task to set either the task result or a constant value to the context.\n", "\n", " ```python\n", - " wg.add_task(\"workgraph.to_context\", name=\"to_ctx1\", key=\"sum\", value=add1.outputs[\"result\"])\n", + " wg.add_task(\"workgraph.set_context\", name=\"set_ctx1\", key=\"sum\", value=add1.outputs[\"result\"])\n", " ```\n", "\n", "\n", @@ -49,8 +49,8 @@ "To organize the context data (e.g. group data), The keys may contain dots `.`, which will creating dictionary in the context. Here is an example, to group the results of all add tasks to `context.sum`:\n", "\n", "```python\n", - "add1.set_context({\"result\": \"sum.add1\"})\n", - "add2.set_context({\"result\": \"sum.add2\"})\n", + "add1.set_context({\"sum.add1\": \"result\"})\n", + "add2.set_context({\"sum.add2\": \"result\"})\n", "```\n", "here, `context.sum` will be:\n", "```python\n", @@ -75,14 +75,14 @@ " nt = WorkGraph(\"while_workgraph\")\n", " add1 = wg.add_task(add, x=2, y=3)\n", " add2 = wg.add_task(add, x=2, y=3)\n", - " add1.set_context({\"result\": \"sum.add1\"})\n", - " add2.set_context({\"result\": \"sum.add2\"})\n", + " add1.set_context({\"sum.add1\": \"result\"})\n", + " add2.set_context({\"sum.add2\": \"result\"})\n", " ```\n", "\n", - "- One can use the `from_context` task to get the data from context. **This task will be shown in the GUI**\n", + "- One can use the `get_context` task to get the data from context. **This task will be shown in the GUI**\n", "\n", " ```python\n", - " wg.add_task(\"workgraph.from_context\", name=\"from_ctx1\", key=\"sum.add1\")\n", + " wg.add_task(\"workgraph.get_context\", name=\"get_ctx1\", key=\"sum.add1\")\n", " ```" ] }, @@ -136,10 +136,10 @@ "wg = WorkGraph(name=\"test_workgraph_ctx\")\n", "# Set the context of the workgraph\n", "wg.context = {\"x\": 2, \"data.y\": 3}\n", - "from_ctx1 = wg.add_task(\"workgraph.from_context\", name=\"from_ctx1\", key=\"x\")\n", - "add1 = wg.add_task(add, \"add1\", x=from_ctx1.outputs[\"result\"],\n", + "get_ctx1 = wg.add_task(\"workgraph.get_context\", name=\"get_ctx1\", key=\"x\")\n", + "add1 = wg.add_task(add, \"add1\", x=get_ctx1.outputs[\"result\"],\n", " y=\"{{data.y}}\")\n", - "to_ctx1 = wg.add_task(\"workgraph.to_context\", name=\"to_ctx1\", key=\"x\",\n", + "set_ctx1 = wg.add_task(\"workgraph.set_context\", name=\"set_ctx1\", key=\"x\",\n", " value=add1.outputs[\"result\"])\n", "wg.to_html()\n", "# wg" @@ -150,7 +150,7 @@ "id": "f6969061", "metadata": {}, "source": [ - "As shown in the GUI, the `from_context` task and `to_context` tasks are shown in the GUI. However, the context variable using the `set_context` method or `{{}}` is not shown in the GUI." + "As shown in the GUI, the `get_context` task and `to_context` tasks are shown in the GUI. However, the context variable using the `set_context` method or `{{}}` is not shown in the GUI." ] }, { diff --git a/docs/source/howto/for.ipynb b/docs/source/howto/for.ipynb index e3d4df4d..6fda2663 100644 --- a/docs/source/howto/for.ipynb +++ b/docs/source/howto/for.ipynb @@ -103,7 +103,7 @@ " multiply1 = wg.add_task(multiply, name=\"multiply1\", x=\"{{ i }}\", y=2)\n", " add1 = wg.add_task(add, name=\"add1\", x=\"{{ total }}\")\n", " # update the context variable\n", - " add1.set_context({\"result\": \"total\"})\n", + " add1.set_context({\"total\": \"result\"})\n", " wg.add_link(multiply1.outputs[\"result\"], add1.inputs[\"y\"])\n", " # don't forget to return the workgraph\n", " return wg" diff --git a/docs/source/howto/html/test_workgraph_ctx.html b/docs/source/howto/html/test_workgraph_ctx.html index 0c3ffd7f..08f8b000 100644 --- a/docs/source/howto/html/test_workgraph_ctx.html +++ b/docs/source/howto/html/test_workgraph_ctx.html @@ -61,7 +61,7 @@ const { RenderUtils } = ReteRenderUtils; const styled = window.styled; - const workgraphData = {"name": "test_workgraph_ctx", "uuid": "50bec156-5a60-11ef-888c-906584de3e5b", "state": "CREATED", "nodes": {"from_ctx1": {"label": "from_ctx1", "node_type": "FROM_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50c8610c-5a60-11ef-888c-906584de3e5b", "node_uuid": "50c85f5e-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "50d07e78-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "from_ctx1", "from_socket": "result", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "50d07edc-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "x"}], "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "to_ctx1": {"label": "to_ctx1", "node_type": "TO_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50d858be-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value", "identifier": "workgraph.any", "uuid": "50d8599a-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add1", "from_socket": "result", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value"}], "outputs": [], "position": [90, 90], "children": []}}, "links": [{"from_socket": "result", "from_node": "from_ctx1", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b", "to_socket": "x", "to_node": "add1", "state": false}, {"from_socket": "result", "from_node": "add1", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b", "to_socket": "value", "to_node": "to_ctx1", "state": false}]} + const workgraphData = {"name": "test_workgraph_ctx", "uuid": "50bec156-5a60-11ef-888c-906584de3e5b", "state": "CREATED", "nodes": {"get_ctx1": {"label": "get_ctx1", "node_type": "GET_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50c8610c-5a60-11ef-888c-906584de3e5b", "node_uuid": "50c85f5e-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "50d07e78-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "get_ctx1", "from_socket": "result", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "50d07edc-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "x"}], "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "set_ctx1": {"label": "set_ctx1", "node_type": "SET_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50d858be-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value", "identifier": "workgraph.any", "uuid": "50d8599a-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add1", "from_socket": "result", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value"}], "outputs": [], "position": [90, 90], "children": []}}, "links": [{"from_socket": "result", "from_node": "get_ctx1", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b", "to_socket": "x", "to_node": "add1", "state": false}, {"from_socket": "result", "from_node": "add1", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b", "to_socket": "value", "to_node": "set_ctx1", "state": false}]} // Define Schemes to use in vanilla JS const Schemes = { diff --git a/docs/source/howto/monitor.ipynb b/docs/source/howto/monitor.ipynb index 3104ea86..0a4d1361 100644 --- a/docs/source/howto/monitor.ipynb +++ b/docs/source/howto/monitor.ipynb @@ -111,12 +111,10 @@ "monitor2 = wg.add_task(\"workgraph.file_monitor\", filepath=\"/tmp/test.txt\")\n", "```\n", "\n", - "## Awaitable Task Decorator\n", + "### Awaitable Task Decorator\n", "\n", "The `awaitable` decorator allows for the integration of `asyncio` within tasks, letting users control asynchronous functions.\n", "\n", - "### General Awaitable Task\n", - "\n", "Define and use an awaitable task within the WorkGraph.\n", "\n" ] @@ -204,7 +202,7 @@ "id": "1ae83d3f", "metadata": {}, "source": [ - "## Kill the monitor task\n", + "### Kill the monitor task\n", "\n", "One can kill a running monitor task by using the following command:\n", "\n", @@ -220,7 +218,7 @@ "\n", "The awaitable task lets the WorkGraph enter a `Waiting` state, yielding control to the asyncio event loop. This enables other tasks to run concurrently, although long-running calculations may delay the execution of awaitable tasks.\n", "\n", - "## Conclusion\n", + "### Conclusion\n", "\n", "These enhancements provide powerful tools for managing dependencies and asynchronous operations within WorkGraph, offering greater flexibility and efficiency in task execution." ] diff --git a/docs/source/howto/parallel.ipynb b/docs/source/howto/parallel.ipynb index d99041ae..5cdad860 100644 --- a/docs/source/howto/parallel.ipynb +++ b/docs/source/howto/parallel.ipynb @@ -401,7 +401,7 @@ " multiply1 = wg.add_task(multiply, x=value, y=y)\n", " # add result of multiply1 to `self.context.mul`\n", " # self.context.mul is a dict {\"a\": value1, \"b\": value2, \"c\": value3}\n", - " multiply1.set_context({\"result\": f\"mul.{key}\"})\n", + " multiply1.set_context({f\"mul.{key}\": \"result\"})\n", " return wg\n", "\n", "@task.calcfunction()\n", diff --git a/docs/source/howto/waiting_on.ipynb b/docs/source/howto/waiting_on.ipynb index bd41045b..f9c2e2c7 100644 --- a/docs/source/howto/waiting_on.ipynb +++ b/docs/source/howto/waiting_on.ipynb @@ -105,9 +105,9 @@ "\n", "wg = WorkGraph(\"test_wait\")\n", "add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n", - "add1.set_context({\"result\": \"data.add1\"})\n", + "add1.set_context({\"data.add1\": \"result\"})\n", "add2 = wg.add_task(add, name=\"add2\", x=2, y=2)\n", - "add2.set_context({\"result\": \"data.add2\"})\n", + "add2.set_context({\"data.add2\": \"result\"})\n", "# let sum task wait for add1 and add2, and the `data` in the context is ready\n", "sum3 = wg.add_task(sum, name=\"sum1\", datas=\"{{data}}\")\n", "sum3.waiting_on.add([\"add1\", \"add2\"])\n", diff --git a/docs/source/howto/while.ipynb b/docs/source/howto/while.ipynb index 1330679d..af3eff2e 100644 --- a/docs/source/howto/while.ipynb +++ b/docs/source/howto/while.ipynb @@ -126,7 +126,7 @@ "# set a context variable before running.\n", "wg.context = {\"should_run\": True}\n", "add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n", - "add1.set_context({\"result\": \"n\"})\n", + "add1.set_context({\"n\": \"result\"})\n", "#---------------------------------------------------------------------\n", "# Create the while tasks\n", "compare1 = wg.add_task(compare, name=\"compare1\", x=\"{{n}}\", y=50)\n", @@ -138,7 +138,7 @@ " x=add2.outputs[\"result\"],\n", " y=2)\n", "# update the context variable\n", - "multiply1.set_context({\"result\": \"n\"})\n", + "multiply1.set_context({\"n\": \"result\"})\n", "while1.children.add([\"add2\", \"multiply1\"])\n", "#---------------------------------------------------------------------\n", "add3 = wg.add_task(add, name=\"add3\", x=1, y=1)\n", @@ -845,7 +845,7 @@ " multiply1 = wg.add_task(multiply, name=\"multiply1\", x=add1.outputs[\"result\"],\n", " y=2)\n", " # update the context variable\n", - " multiply1.set_context({\"result\": \"n\"})\n", + " multiply1.set_context({\"n\": \"result\"})\n", " return wg" ] }, diff --git a/docs/source/tutorial/eos.ipynb b/docs/source/tutorial/eos.ipynb index 7e71d11c..c4fe5a26 100644 --- a/docs/source/tutorial/eos.ipynb +++ b/docs/source/tutorial/eos.ipynb @@ -48,7 +48,7 @@ " pw1 = wg.add_task(PwCalculation, name=f\"pw1_{key}\", structure=structure)\n", " pw1.set(scf_inputs)\n", " # save the output parameters to the context\n", - " pw1.set_context({\"output_parameters\": f\"result.{key}\"})\n", + " pw1.set_context({f\"result.{key}\": \"output_parameters\"})\n", " return wg\n", "\n", "\n", diff --git a/docs/source/tutorial/zero_to_hero.ipynb b/docs/source/tutorial/zero_to_hero.ipynb index 0bebdf09..a643a822 100644 --- a/docs/source/tutorial/zero_to_hero.ipynb +++ b/docs/source/tutorial/zero_to_hero.ipynb @@ -1224,7 +1224,7 @@ " pw1 = wg.add_task(PwCalculation, name=f\"pw1_{key}\", structure=structure)\n", " pw1.set(scf_inputs)\n", " # save the output parameters to the context\n", - " pw1.set_context({\"output_parameters\": f\"result.{key}\"})\n", + " pw1.set_context({f\"result.{key}\": \"output_parameters\"})\n", " return wg\n", "\n", "\n", diff --git a/pyproject.toml b/pyproject.toml index 7930b305..18e33905 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,8 +94,8 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" "workgraph.if" = "aiida_workgraph.tasks.builtins:If" "workgraph.select" = "aiida_workgraph.tasks.builtins:Select" "workgraph.gather" = "aiida_workgraph.tasks.builtins:Gather" -"workgraph.to_context" = "aiida_workgraph.tasks.builtins:ToContext" -"workgraph.from_context" = "aiida_workgraph.tasks.builtins:FromContext" +"workgraph.set_context" = "aiida_workgraph.tasks.builtins:SetContext" +"workgraph.get_context" = "aiida_workgraph.tasks.builtins:GetContext" "workgraph.time_monitor" = "aiida_workgraph.tasks.monitors:TimeMonitor" "workgraph.file_monitor" = "aiida_workgraph.tasks.monitors:FileMonitor" "workgraph.task_monitor" = "aiida_workgraph.tasks.monitors:TaskMonitor" diff --git a/tests/test_ctx.py b/tests/test_ctx.py index 9ec1bee5..5f8b66bf 100644 --- a/tests/test_ctx.py +++ b/tests/test_ctx.py @@ -15,26 +15,26 @@ def test_workgraph_ctx(decorated_add: Callable) -> None: wg.context = {"x": Float(2), "data.y": Float(3), "array": array} add1 = wg.add_task(decorated_add, "add1", x="{{ x }}", y="{{ data.y }}") wg.add_task( - "workgraph.to_context", name="to_ctx1", key="x", value=add1.outputs["result"] + "workgraph.set_context", name="set_ctx1", key="x", value=add1.outputs["result"] ) - from_ctx1 = wg.add_task("workgraph.from_context", name="from_ctx1", key="x") + get_ctx1 = wg.add_task("workgraph.get_context", name="get_ctx1", key="x") # test the task can wait for another task - from_ctx1.waiting_on.add(add1) - add2 = wg.add_task(decorated_add, "add2", x=from_ctx1.outputs["result"], y=1) + get_ctx1.waiting_on.add(add1) + add2 = wg.add_task(decorated_add, "add2", x=get_ctx1.outputs["result"], y=1) wg.run() assert add2.outputs["result"].value == 6 -def test_node_to_ctx(decorated_add: Callable) -> None: +def test_task_set_ctx(decorated_add: Callable) -> None: """Set/get data to/from context.""" - wg = WorkGraph(name="test_node_to_ctx") + wg = WorkGraph(name="test_node_set_ctx") add1 = wg.add_task(decorated_add, "add1", x=Float(2).store(), y=Float(3).store()) try: - add1.set_context({"resul": "sum"}) + add1.set_context({"sum": "resul"}) except ValueError as e: assert str(e) == "Keys {'resul'} are not in the outputs of this task." - add1.set_context({"result": "sum"}) + add1.set_context({"sum": "result"}) add2 = wg.add_task(decorated_add, "add2", y="{{ sum }}") wg.add_link(add1.outputs[0], add2.inputs["x"]) wg.submit(wait=True) diff --git a/tests/test_for.py b/tests/test_for.py index 8f347ebc..fdd6fa47 100644 --- a/tests/test_for.py +++ b/tests/test_for.py @@ -21,7 +21,7 @@ def add_multiply_for(sequence): ) add1 = wg.add_task(decorated_add, name="add1", x="{{ total }}") # update the context variable - add1.set_context({"result": "total"}) + add1.set_context({"total": "result"}) wg.add_link(multiply1.outputs["result"], add1.inputs["y"]) # don't forget to return the workgraph return wg diff --git a/tests/test_while.py b/tests/test_while.py index 90f5280d..6c677e85 100644 --- a/tests/test_while.py +++ b/tests/test_while.py @@ -43,7 +43,7 @@ def raw_python_code(): "l": 1, } add1 = wg.add_task(decorated_add, name="add1", x=1, y=1) - add1.set_context({"result": "n"}) + add1.set_context({"n": "result"}) # --------------------------------------------------------------------- # the `result` of compare1 taskis used as condition compare1 = wg.add_task(decorated_compare, name="compare1", x="{{m}}", y=10) @@ -57,7 +57,7 @@ def raw_python_code(): ) add21.waiting_on.add("add1") add22 = wg.add_task(decorated_add, name="add22", x=add21.outputs["result"], y=1) - add22.set_context({"result": "n"}) + add22.set_context({"n": "result"}) while2.children.add(["add21", "add22"]) # --------------------------------------------------------------------- compare3 = wg.add_task(decorated_compare, name="compare3", x="{{l}}", y=5) @@ -67,13 +67,13 @@ def raw_python_code(): add31 = wg.add_task(decorated_add, name="add31", x="{{l}}", y=1) add31.waiting_on.add("add22") add32 = wg.add_task(decorated_add, name="add32", x=add31.outputs["result"], y=1) - add32.set_context({"result": "l"}) + add32.set_context({"l": "result"}) while3.children.add(["add31", "add32"]) # --------------------------------------------------------------------- add12 = wg.add_task( decorated_add, name="add12", x="{{m}}", y=add32.outputs["result"] ) - add12.set_context({"result": "m"}) + add12.set_context({"m": "result"}) while1.children.add(["add11", "while2", "while3", "add12", "compare2", "compare3"]) # --------------------------------------------------------------------- add2 = wg.add_task( @@ -101,7 +101,7 @@ def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare): decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2) ) add1 = wg.add_task(decorated_add, name="add1", y=3) - add1.set_context({"result": "n"}) + add1.set_context({"n": "result"}) wg.add_link(multiply1.outputs["result"], add1.inputs["x"]) wg.submit(wait=True, timeout=100) assert wg.execution_count == 4 @@ -125,7 +125,7 @@ def my_while(n=0, limit=100): decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2) ) add1 = wg.add_task(decorated_add, name="add1", y=3) - add1.set_context({"result": "n"}) + add1.set_context({"n": "result"}) wg.add_link(multiply1.outputs["result"], add1.inputs["x"]) return wg