Skip to content

Commit

Permalink
Create input sockets and links for items inside a dynamic socket (#381)
Browse files Browse the repository at this point in the history
In AiiDA, one can define a dynamic namespace for the process, which allows the user to pass any nested dictionary with AiiDA data nodes as values. However, in the `WorkGraph`, we need to define the input and output sockets explicitly, so that one can make a link between tasks. To address this discrepancy, and still allow user to pass any nested dictionary with AiiDA data nodes, as well as the output sockets of other tasks, we automatically create the input for each item in the dictionary if the input is not defined. Besides, if the value of the item is a socket, we will link the socket to the task, and remove the item from the dictionary.
  • Loading branch information
superstar54 authored Dec 5, 2024
1 parent 02e1c48 commit c08af7a
Show file tree
Hide file tree
Showing 13 changed files with 873 additions and 41 deletions.
290 changes: 290 additions & 0 deletions docs/source/howto/html/test_dynamic_namespace.html

Large diffs are not rendered by default.

290 changes: 290 additions & 0 deletions docs/source/howto/html/test_use_calcjob.html

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/source/howto/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This section contains a collection of HowTos for various topics.
:maxdepth: 1
:caption: Contents:

use_calcjob_workchain
autogen/graph_builder
autogen/parallel
autogen/aggregate
Expand Down
216 changes: 216 additions & 0 deletions docs/source/howto/use_calcjob_workchain.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "22d177dc-6cfb-4de2-9509-f1eb45e10cf2",
"metadata": {},
"source": [
"# Use `CalcJob` and `WorkChain` insdie WorkGraph\n",
"One can use `CalcJob`, `WorkChain` and other AiiDA components direclty in the WorkGraph. The inputs and outputs of the task is automatically generated based on the input and output port of the AiiDA component."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a6e0038f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"100%\"\n",
" height=\"600px\"\n",
" src=\"html/test_use_calcjob.html\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" \n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x7329bd9dca90>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"from aiida_workgraph import WorkGraph\n",
"from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n",
"\n",
"wg = WorkGraph(\"test_use_calcjob\")\n",
"task1 = wg.add_task(ArithmeticAddCalculation, name=\"add1\")\n",
"task2 = wg.add_task(ArithmeticAddCalculation, name=\"add2\", x=wg.tasks[\"add1\"].outputs[\"sum\"])\n",
"wg.to_html()"
]
},
{
"cell_type": "markdown",
"id": "1781a459",
"metadata": {},
"source": [
"## Set inputs\n",
"One can set the inputs when adding the task, or using the `set` method of the `Task` object."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "288327e4",
"metadata": {},
"outputs": [],
"source": [
"from aiida import load_profile\n",
"from aiida.orm import Int\n",
"\n",
"load_profile()\n",
"\n",
"# use set method\n",
"task1.set({\"x\": Int(1), \"y\": Int(2)})\n",
"# set the inputs when adding the task\n",
"task3 = wg.add_task(ArithmeticAddCalculation, name=\"add3\", x=Int(3), y=Int(4))\n"
]
},
{
"cell_type": "markdown",
"id": "ef4ba444",
"metadata": {},
"source": [
"### Use process builder\n",
"One can also set the inputs of the task using the process builder."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "53e31346",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n",
"from aiida.orm import Int, load_code\n",
"\n",
"\n",
"code = load_code(\"add@localhost\")\n",
"builder = ArithmeticAddCalculation.get_builder()\n",
"builder.code = code\n",
"builder.x = Int(2)\n",
"builder.y = Int(3)\n",
"\n",
"wg = WorkGraph(\"test_set_inputs_from_builder\")\n",
"add1 = wg.add_task(ArithmeticAddCalculation, name=\"add1\")\n",
"add1.set_from_builder(builder)"
]
},
{
"cell_type": "markdown",
"id": "a0a1dbf0",
"metadata": {},
"source": [
"\n",
"## Dynamic namespace\n",
"In AiiDA, one can define a dynamic namespace for the process, which allows the user to pass any nested dictionary with AiiDA data nodes as values. However, in the `WorkGraph`, we need to define the input and output sockets explicitly, so that one can make a link between tasks. To address this discrepancy, and still allow user to pass any nested dictionary with AiiDA data nodes, as well as the output sockets of other tasks, we automatically create the input for each item in the dictionary if the input is not defined. Besides, if the value of the item is a socket, we will link the socket to the task, and remove the item from the dictionary.\n",
"\n",
"For example, the `WorkChainWithDynamicNamespace` has a dynamic namespace `dynamic_port`, and the user can pass any nested dictionary as the input.\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4a81efa9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Failed to inspect function WorkChainWithDynamicNamespace: source code not available\n"
]
},
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"100%\"\n",
" height=\"600px\"\n",
" src=\"html/test_dynamic_namespace.html\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" \n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x77b0e0b96f90>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from aiida.engine import WorkChain\n",
"\n",
"class WorkChainWithDynamicNamespace(WorkChain):\n",
" \"\"\"WorkChain with dynamic namespace.\"\"\"\n",
"\n",
" @classmethod\n",
" def define(cls, spec):\n",
" \"\"\"Specify inputs and outputs.\"\"\"\n",
" super().define(spec)\n",
" spec.input_namespace(\"dynamic_port\", dynamic=True)\n",
"\n",
"wg = WorkGraph(\"test_dynamic_namespace\")\n",
"task1 = wg.add_task(ArithmeticAddCalculation, name=\"add1\")\n",
"task2 = wg.add_task(\n",
" WorkChainWithDynamicNamespace,\n",
" dynamic_port={\n",
" \"input1\": None,\n",
" \"input2\": Int(2),\n",
" \"input3\": task1.outputs[\"sum\"],\n",
" },\n",
" )\n",
"wg.to_html()"
]
},
{
"cell_type": "markdown",
"id": "015f91d7",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"In this example, we will show how to use `CalcJob` and `WorkChain` inside the WorkGraph. One can also use `WorkGraph` inside a `WorkChain`, please refer to the [Calling WorkGraph within a WorkChain](workchain_call_workgraph.ipynb) for more details."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aiida",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"numpy~=1.21",
"scipy",
"ase",
"node-graph==0.1.5",
"node-graph==0.1.6",
"node-graph-widget",
"aiida-core>=2.3",
"cloudpickle",
Expand Down
5 changes: 1 addition & 4 deletions src/aiida_workgraph/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,11 @@ def new(
if isinstance(identifier, str) and identifier.upper() == "PYTHONJOB":
identifier, _ = build_pythonjob_task(kwargs.pop("function"))
if isinstance(identifier, str) and identifier.upper() == "SHELLJOB":
identifier, _, links = build_shelljob_task(
nodes=kwargs.get("nodes", {}),
identifier, _ = build_shelljob_task(
outputs=kwargs.get("outputs", None),
parser_outputs=kwargs.pop("parser_outputs", None),
)
task = super().new(identifier, name, uuid, **kwargs)
# make links between the tasks
task.set(links)
return task
if isinstance(identifier, str) and identifier.upper() == "WHILE":
task = super().new("workgraph.while", name, uuid, **kwargs)
Expand Down
29 changes: 2 additions & 27 deletions src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,41 +337,16 @@ def build_pythonjob_task(func: Callable) -> Task:
return task, tdata


def build_shelljob_task(
nodes: dict = None, outputs: list = None, parser_outputs: list = None
) -> Task:
def build_shelljob_task(outputs: list = None, parser_outputs: list = None) -> Task:
"""Build ShellJob with custom inputs and outputs."""
from aiida_shell.calculations.shell import ShellJob
from aiida_shell.parsers.shell import ShellParser
from node_graph.socket import NodeSocket

tdata = {
"metadata": {"task_type": "SHELLJOB"},
"executor": ShellJob,
}
_, tdata = build_task_from_AiiDA(tdata)
# create input sockets for the nodes, if it is linked other sockets
links = {}
inputs = []
nodes = {} if nodes is None else nodes
keys = list(nodes.keys())
for key in keys:
inputs.append(
{
"identifier": "workgraph.any",
"name": f"nodes.{key}",
"metadata": {"required": True},
}
)
# input is a output of another task, we make a link
if isinstance(nodes[key], NodeSocket):
links[f"nodes.{key}"] = nodes[key]
# Output socket itself is not a value, so we remove the key from the nodes
nodes.pop(key)
for input in inputs:
if input["name"] not in tdata["inputs"]:
input["list_index"] = len(tdata["inputs"]) + 1
tdata["inputs"][input["name"]] = input
# Extend the outputs
for output in [
{"identifier": "workgraph.any", "name": "stdout"},
Expand Down Expand Up @@ -404,7 +379,7 @@ def build_shelljob_task(
tdata["metadata"]["task_type"] = "SHELLJOB"
task = create_task(tdata)
task.is_aiida_component = True
return task, tdata, links
return task, tdata


def build_task_from_workgraph(wg: any) -> Task:
Expand Down
24 changes: 24 additions & 0 deletions src/aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,30 @@ def set_context(self, context: Dict[str, Any]) -> None:
raise ValueError(msg)
self.context_mapping.update(context)

def set(self, data: Dict[str, Any]) -> None:
from node_graph.socket import NodeSocket

super().set(data)
# create input sockets and links for items inside a dynamic socket
# TODO the input value could be nested, but we only support one level for now
for key, value in data.items():
if self.inputs[key].metadata.get("dynamic", False):
if isinstance(value, dict):
keys = list(value.keys())
for sub_key in keys:
# create a new input socket if it does not exist
if f"{key}.{sub_key}" not in self.inputs.keys():
self.inputs.new(
"workgraph.any",
name=f"{key}.{sub_key}",
metadata={"required": True},
)
if isinstance(value[sub_key], NodeSocket):
self.parent.links.new(
value[sub_key], self.inputs[f"{key}.{sub_key}"]
)
self.inputs[key].value.pop(sub_key)

def set_from_builder(self, builder: Any) -> None:
"""Set the task inputs from a AiiDA ProcessBuilder."""
from aiida_workgraph.utils import get_dict_from_builder
Expand Down
4 changes: 2 additions & 2 deletions src/aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def submit(
inputs: Optional[Dict[str, Any]] = None,
wait: bool = False,
timeout: int = 60,
interval: int = 1,
interval: int = 5,
metadata: Optional[Dict[str, Any]] = None,
) -> aiida.orm.ProcessNode:
"""Submit the AiiDA workgraph process and optionally wait for it to finish.
Expand Down Expand Up @@ -230,7 +230,7 @@ def get_error_handlers(self) -> Dict[str, Any]:
task["exit_codes"] = exit_codes
return error_handlers

def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 1) -> None:
def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None:
"""
Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout.
Args:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable
from aiida.orm import Float, ArrayData
import numpy as np
import pytest


def test_workgraph_ctx(decorated_add: Callable) -> None:
Expand All @@ -25,6 +26,7 @@ def test_workgraph_ctx(decorated_add: Callable) -> None:
assert add2.outputs["result"].value == 6


@pytest.mark.usefixtures("started_daemon_client")
def test_task_set_ctx(decorated_add: Callable) -> None:
"""Set/get data to/from context."""

Expand Down
Loading

0 comments on commit c08af7a

Please sign in to comment.