From c08af7ac6902b634e638cd1c5e5d9392df4f7976 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Thu, 5 Dec 2024 13:44:58 +0100 Subject: [PATCH] Create input sockets and links for items inside a dynamic socket (#381) In AiiDA, one can define a dynamic namespace for the process, which allows the user to pass any nested dictionary with AiiDA data nodes as values. However, in the `WorkGraph`, we need to define the input and output sockets explicitly, so that one can make a link between tasks. To address this discrepancy, and still allow user to pass any nested dictionary with AiiDA data nodes, as well as the output sockets of other tasks, we automatically create the input for each item in the dictionary if the input is not defined. Besides, if the value of the item is a socket, we will link the socket to the task, and remove the item from the dictionary. --- .../howto/html/test_dynamic_namespace.html | 290 ++++++++++++++++++ docs/source/howto/html/test_use_calcjob.html | 290 ++++++++++++++++++ docs/source/howto/index.rst | 1 + docs/source/howto/use_calcjob_workchain.ipynb | 216 +++++++++++++ pyproject.toml | 2 +- src/aiida_workgraph/collection.py | 5 +- src/aiida_workgraph/decorator.py | 29 +- src/aiida_workgraph/task.py | 24 ++ src/aiida_workgraph/workgraph.py | 4 +- tests/test_ctx.py | 2 + tests/test_shell.py | 18 +- tests/test_tasks.py | 23 ++ tests/utils/test_workchain.py | 10 + 13 files changed, 873 insertions(+), 41 deletions(-) create mode 100644 docs/source/howto/html/test_dynamic_namespace.html create mode 100644 docs/source/howto/html/test_use_calcjob.html create mode 100644 docs/source/howto/use_calcjob_workchain.ipynb diff --git a/docs/source/howto/html/test_dynamic_namespace.html b/docs/source/howto/html/test_dynamic_namespace.html new file mode 100644 index 00000000..9d91f8c4 --- /dev/null +++ b/docs/source/howto/html/test_dynamic_namespace.html @@ -0,0 +1,290 @@ + + + + + + + Rete.js with React in Vanilla JS + + + + + + + + + + + + + + + + + + + + + +
+ + + diff --git a/docs/source/howto/html/test_use_calcjob.html b/docs/source/howto/html/test_use_calcjob.html new file mode 100644 index 00000000..7f8fea45 --- /dev/null +++ b/docs/source/howto/html/test_use_calcjob.html @@ -0,0 +1,290 @@ + + + + + + + Rete.js with React in Vanilla JS + + + + + + + + + + + + + + + + + + + + + +
+ + + diff --git a/docs/source/howto/index.rst b/docs/source/howto/index.rst index fd9ef106..a8dd8aab 100644 --- a/docs/source/howto/index.rst +++ b/docs/source/howto/index.rst @@ -8,6 +8,7 @@ This section contains a collection of HowTos for various topics. :maxdepth: 1 :caption: Contents: + use_calcjob_workchain autogen/graph_builder autogen/parallel autogen/aggregate diff --git a/docs/source/howto/use_calcjob_workchain.ipynb b/docs/source/howto/use_calcjob_workchain.ipynb new file mode 100644 index 00000000..f4c8fe57 --- /dev/null +++ b/docs/source/howto/use_calcjob_workchain.ipynb @@ -0,0 +1,216 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "22d177dc-6cfb-4de2-9509-f1eb45e10cf2", + "metadata": {}, + "source": [ + "# Use `CalcJob` and `WorkChain` inside WorkGraph\n", + "One can use `CalcJob`, `WorkChain` and other AiiDA components directly in the WorkGraph. The inputs and outputs of the task are automatically generated based on the input and output ports of the AiiDA component." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a6e0038f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "from aiida_workgraph import WorkGraph\n", + "from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n", + "\n", + "wg = WorkGraph(\"test_use_calcjob\")\n", + "task1 = wg.add_task(ArithmeticAddCalculation, name=\"add1\")\n", + "task2 = wg.add_task(ArithmeticAddCalculation, name=\"add2\", x=wg.tasks[\"add1\"].outputs[\"sum\"])\n", + "wg.to_html()" + ] + }, + { + "cell_type": "markdown", + "id": "1781a459", + "metadata": {}, + "source": [ + "## Set inputs\n", + "One can set the inputs when adding the task, or using the `set` method of the `Task` object." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "288327e4", + "metadata": {}, + "outputs": [], + "source": [ + "from aiida import load_profile\n", + "from aiida.orm import Int\n", + "\n", + "load_profile()\n", + "\n", + "# use set method\n", + "task1.set({\"x\": Int(1), \"y\": Int(2)})\n", + "# set the inputs when adding the task\n", + "task3 = wg.add_task(ArithmeticAddCalculation, name=\"add3\", x=Int(3), y=Int(4))\n" + ] + }, + { + "cell_type": "markdown", + "id": "ef4ba444", + "metadata": {}, + "source": [ + "### Use process builder\n", + "One can also set the inputs of the task using the process builder." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "53e31346", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n", + "from aiida.orm import Int, load_code\n", + "\n", + "\n", + "code = load_code(\"add@localhost\")\n", + "builder = ArithmeticAddCalculation.get_builder()\n", + "builder.code = code\n", + "builder.x = Int(2)\n", + "builder.y = Int(3)\n", + "\n", + "wg = WorkGraph(\"test_set_inputs_from_builder\")\n", + "add1 = wg.add_task(ArithmeticAddCalculation, name=\"add1\")\n", + "add1.set_from_builder(builder)" + ] + }, + { + "cell_type": "markdown", + "id": "a0a1dbf0", + "metadata": {}, + "source": [ + "\n", + "## Dynamic namespace\n", + "In AiiDA, one can define a dynamic namespace for the process, which allows the user to pass any nested dictionary with AiiDA data nodes as values. However, in the `WorkGraph`, we need to define the input and output sockets explicitly, so that one can make a link between tasks. To address this discrepancy, and still allow user to pass any nested dictionary with AiiDA data nodes, as well as the output sockets of other tasks, we automatically create the input for each item in the dictionary if the input is not defined. 
Besides, if the value of the item is a socket, we will link the socket to the task, and remove the item from the dictionary.\n", + "\n", + "For example, the `WorkChainWithDynamicNamespace` has a dynamic namespace `dynamic_port`, and the user can pass any nested dictionary as the input.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4a81efa9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Failed to inspect function WorkChainWithDynamicNamespace: source code not available\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from aiida.engine import WorkChain\n", + "\n", + "class WorkChainWithDynamicNamespace(WorkChain):\n", + " \"\"\"WorkChain with dynamic namespace.\"\"\"\n", + "\n", + " @classmethod\n", + " def define(cls, spec):\n", + " \"\"\"Specify inputs and outputs.\"\"\"\n", + " super().define(spec)\n", + " spec.input_namespace(\"dynamic_port\", dynamic=True)\n", + "\n", + "wg = WorkGraph(\"test_dynamic_namespace\")\n", + "task1 = wg.add_task(ArithmeticAddCalculation, name=\"add1\")\n", + "task2 = wg.add_task(\n", + " WorkChainWithDynamicNamespace,\n", + " dynamic_port={\n", + " \"input1\": None,\n", + " \"input2\": Int(2),\n", + " \"input3\": task1.outputs[\"sum\"],\n", + " },\n", + " )\n", + "wg.to_html()" + ] + }, + { + "cell_type": "markdown", + "id": "015f91d7", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "In this example, we showed how to use `CalcJob` and `WorkChain` inside the WorkGraph. One can also use `WorkGraph` inside a `WorkChain`; please refer to [Calling WorkGraph within a WorkChain](workchain_call_workgraph.ipynb) for more details." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aiida", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index cd11f0e6..0e60a37f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "numpy~=1.21", "scipy", "ase", - "node-graph==0.1.5", + "node-graph==0.1.6", "node-graph-widget", "aiida-core>=2.3", "cloudpickle", diff --git a/src/aiida_workgraph/collection.py b/src/aiida_workgraph/collection.py index 7af02e10..ac8f3337 100644 --- a/src/aiida_workgraph/collection.py +++ b/src/aiida_workgraph/collection.py @@ -29,14 +29,11 @@ def new( if isinstance(identifier, str) and identifier.upper() == "PYTHONJOB": identifier, _ = build_pythonjob_task(kwargs.pop("function")) if isinstance(identifier, str) and identifier.upper() == "SHELLJOB": - identifier, _, links = build_shelljob_task( - nodes=kwargs.get("nodes", {}), + identifier, _ = build_shelljob_task( outputs=kwargs.get("outputs", None), parser_outputs=kwargs.pop("parser_outputs", None), ) task = super().new(identifier, name, uuid, **kwargs) - # make links between the tasks - task.set(links) return task if isinstance(identifier, str) and identifier.upper() == "WHILE": task = super().new("workgraph.while", name, uuid, **kwargs) diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index ea7dda73..3f091088 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -337,41 +337,16 @@ def build_pythonjob_task(func: Callable) -> Task: return task, tdata -def build_shelljob_task( - nodes: dict = None, outputs: list = None, parser_outputs: list = None -) -> Task: +def 
build_shelljob_task(outputs: list = None, parser_outputs: list = None) -> Task: """Build ShellJob with custom inputs and outputs.""" from aiida_shell.calculations.shell import ShellJob from aiida_shell.parsers.shell import ShellParser - from node_graph.socket import NodeSocket tdata = { "metadata": {"task_type": "SHELLJOB"}, "executor": ShellJob, } _, tdata = build_task_from_AiiDA(tdata) - # create input sockets for the nodes, if it is linked other sockets - links = {} - inputs = [] - nodes = {} if nodes is None else nodes - keys = list(nodes.keys()) - for key in keys: - inputs.append( - { - "identifier": "workgraph.any", - "name": f"nodes.{key}", - "metadata": {"required": True}, - } - ) - # input is a output of another task, we make a link - if isinstance(nodes[key], NodeSocket): - links[f"nodes.{key}"] = nodes[key] - # Output socket itself is not a value, so we remove the key from the nodes - nodes.pop(key) - for input in inputs: - if input["name"] not in tdata["inputs"]: - input["list_index"] = len(tdata["inputs"]) + 1 - tdata["inputs"][input["name"]] = input # Extend the outputs for output in [ {"identifier": "workgraph.any", "name": "stdout"}, @@ -404,7 +379,7 @@ def build_shelljob_task( tdata["metadata"]["task_type"] = "SHELLJOB" task = create_task(tdata) task.is_aiida_component = True - return task, tdata, links + return task, tdata def build_task_from_workgraph(wg: any) -> Task: diff --git a/src/aiida_workgraph/task.py b/src/aiida_workgraph/task.py index 085c778e..f43a68be 100644 --- a/src/aiida_workgraph/task.py +++ b/src/aiida_workgraph/task.py @@ -81,6 +81,30 @@ def set_context(self, context: Dict[str, Any]) -> None: raise ValueError(msg) self.context_mapping.update(context) + def set(self, data: Dict[str, Any]) -> None: + from node_graph.socket import NodeSocket + + super().set(data) + # create input sockets and links for items inside a dynamic socket + # TODO the input value could be nested, but we only support one level for now + for key, value in 
data.items(): + if self.inputs[key].metadata.get("dynamic", False): + if isinstance(value, dict): + keys = list(value.keys()) + for sub_key in keys: + # create a new input socket if it does not exist + if f"{key}.{sub_key}" not in self.inputs.keys(): + self.inputs.new( + "workgraph.any", + name=f"{key}.{sub_key}", + metadata={"required": True}, + ) + if isinstance(value[sub_key], NodeSocket): + self.parent.links.new( + value[sub_key], self.inputs[f"{key}.{sub_key}"] + ) + self.inputs[key].value.pop(sub_key) + def set_from_builder(self, builder: Any) -> None: """Set the task inputs from a AiiDA ProcessBuilder.""" from aiida_workgraph.utils import get_dict_from_builder diff --git a/src/aiida_workgraph/workgraph.py b/src/aiida_workgraph/workgraph.py index dfd371bf..4a3b8fde 100644 --- a/src/aiida_workgraph/workgraph.py +++ b/src/aiida_workgraph/workgraph.py @@ -113,7 +113,7 @@ def submit( inputs: Optional[Dict[str, Any]] = None, wait: bool = False, timeout: int = 60, - interval: int = 1, + interval: int = 5, metadata: Optional[Dict[str, Any]] = None, ) -> aiida.orm.ProcessNode: """Submit the AiiDA workgraph process and optionally wait for it to finish. @@ -230,7 +230,7 @@ def get_error_handlers(self) -> Dict[str, Any]: task["exit_codes"] = exit_codes return error_handlers - def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 1) -> None: + def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None: """ Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout. 
Args: diff --git a/tests/test_ctx.py b/tests/test_ctx.py index 5f8b66bf..bffbc872 100644 --- a/tests/test_ctx.py +++ b/tests/test_ctx.py @@ -2,6 +2,7 @@ from typing import Callable from aiida.orm import Float, ArrayData import numpy as np +import pytest def test_workgraph_ctx(decorated_add: Callable) -> None: @@ -25,6 +26,7 @@ def test_workgraph_ctx(decorated_add: Callable) -> None: assert add2.outputs["result"].value == 6 +@pytest.mark.usefixtures("started_daemon_client") def test_task_set_ctx(decorated_add: Callable) -> None: """Set/get data to/from context.""" diff --git a/tests/test_shell.py b/tests/test_shell.py index 3160e71f..88045b42 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -1,7 +1,7 @@ import pytest from aiida_workgraph import WorkGraph, task from aiida_shell.launch import prepare_code -from aiida.orm import SinglefileData, load_computer +from aiida.orm import SinglefileData, load_computer, Int def test_prepare_for_shell_task_nonexistent(): @@ -51,9 +51,9 @@ def test_shell_code(): assert job1.node.outputs.stdout.get_content() == "string astring b" -def test_shell_set(): +def test_dynamic_port(): """Set the nodes during/after the creation of the task.""" - wg = WorkGraph(name="test_shell_set") + wg = WorkGraph(name="test_dynamic_port") echo_task = wg.add_task( "ShellJob", name="echo", @@ -68,11 +68,15 @@ def test_shell_set(): name="cat", command="cat", arguments=["{input}"], - nodes={"input": None}, + nodes={"input1": None, "input2": Int(2), "input3": echo_task.outputs["_wait"]}, ) - wg.add_link(echo_task.outputs["copied_file"], cat_task.inputs["nodes.input"]) - wg.submit(wait=True) - assert cat_task.outputs["stdout"].value.get_content() == "1 5 1" + wg.add_link(echo_task.outputs["copied_file"], cat_task.inputs["nodes.input1"]) + # task will create input for each item in the dynamic port (nodes) + assert "nodes.input1" in cat_task.inputs.keys() + assert "nodes.input2" in cat_task.inputs.keys() + # if the value of the item is a Socket, 
then it will create a link, and pop the item + assert "nodes.input3" in cat_task.inputs.keys() + assert cat_task.inputs["nodes"].value == {"input1": None, "input2": Int(2)} @pytest.mark.usefixtures("started_daemon_client") diff --git a/tests/test_tasks.py b/tests/test_tasks.py index e8bd8051..bfdffdfe 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -2,6 +2,7 @@ from aiida_workgraph import WorkGraph, task from typing import Callable from aiida.cmdline.utils.common import get_workchain_report +from aiida import orm def test_normal_task(decorated_add) -> None: @@ -73,6 +74,28 @@ def test_task_wait(decorated_add: Callable) -> None: assert "tasks ready to run: add1" in report +def test_set_dynamic_port_input(decorated_add) -> None: + from .utils.test_workchain import WorkChainWithDynamicNamespace + + wg = WorkGraph(name="test_set_dynamic_port_input") + task1 = wg.add_task(decorated_add) + task2 = wg.add_task( + WorkChainWithDynamicNamespace, + dynamic_port={ + "input1": None, + "input2": orm.Int(2), + "input3": task1.outputs["result"], + }, + ) + wg.add_link(task1.outputs["_wait"], task2.inputs["dynamic_port.input1"]) + # task will create input for each item in the dynamic port (nodes) + assert "dynamic_port.input1" in task2.inputs.keys() + assert "dynamic_port.input2" in task2.inputs.keys() + # if the value of the item is a Socket, then it will create a link, and pop the item + assert "dynamic_port.input3" in task2.inputs.keys() + assert task2.inputs["dynamic_port"].value == {"input1": None, "input2": orm.Int(2)} + + def test_set_inputs(decorated_add: Callable) -> None: """Test setting inputs of a task.""" diff --git a/tests/utils/test_workchain.py b/tests/utils/test_workchain.py index 80341bb2..af702642 100644 --- a/tests/utils/test_workchain.py +++ b/tests/utils/test_workchain.py @@ -53,3 +53,13 @@ def validate_result(self): def result(self): """Add the result to the outputs.""" self.out("result", self.ctx.addition.outputs.sum) + + +class 
WorkChainWithDynamicNamespace(WorkChain): + """WorkChain with dynamic namespace.""" + + @classmethod + def define(cls, spec): + """Specify inputs and outputs.""" + super().define(spec) + spec.input_namespace("dynamic_port", dynamic=True)