Skip to content

Commit

Permalink
create input sockets and links for items inside a dynamic socket
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 5, 2024
1 parent 02e1c48 commit 49a3c26
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 38 deletions.
5 changes: 1 addition & 4 deletions src/aiida_workgraph/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,11 @@ def new(
if isinstance(identifier, str) and identifier.upper() == "PYTHONJOB":
identifier, _ = build_pythonjob_task(kwargs.pop("function"))
if isinstance(identifier, str) and identifier.upper() == "SHELLJOB":
identifier, _, links = build_shelljob_task(
nodes=kwargs.get("nodes", {}),
identifier, _ = build_shelljob_task(
outputs=kwargs.get("outputs", None),
parser_outputs=kwargs.pop("parser_outputs", None),
)
task = super().new(identifier, name, uuid, **kwargs)
# make links between the tasks
task.set(links)
return task
if isinstance(identifier, str) and identifier.upper() == "WHILE":
task = super().new("workgraph.while", name, uuid, **kwargs)
Expand Down
29 changes: 2 additions & 27 deletions src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,41 +337,16 @@ def build_pythonjob_task(func: Callable) -> Task:
return task, tdata


def build_shelljob_task(
nodes: dict = None, outputs: list = None, parser_outputs: list = None
) -> Task:
def build_shelljob_task(outputs: list = None, parser_outputs: list = None) -> Task:
"""Build ShellJob with custom inputs and outputs."""
from aiida_shell.calculations.shell import ShellJob
from aiida_shell.parsers.shell import ShellParser
from node_graph.socket import NodeSocket

tdata = {
"metadata": {"task_type": "SHELLJOB"},
"executor": ShellJob,
}
_, tdata = build_task_from_AiiDA(tdata)
# create input sockets for the nodes, if it is linked other sockets
links = {}
inputs = []
nodes = {} if nodes is None else nodes
keys = list(nodes.keys())
for key in keys:
inputs.append(
{
"identifier": "workgraph.any",
"name": f"nodes.{key}",
"metadata": {"required": True},
}
)
# input is a output of another task, we make a link
if isinstance(nodes[key], NodeSocket):
links[f"nodes.{key}"] = nodes[key]
# Output socket itself is not a value, so we remove the key from the nodes
nodes.pop(key)
for input in inputs:
if input["name"] not in tdata["inputs"]:
input["list_index"] = len(tdata["inputs"]) + 1
tdata["inputs"][input["name"]] = input
# Extend the outputs
for output in [
{"identifier": "workgraph.any", "name": "stdout"},
Expand Down Expand Up @@ -404,7 +379,7 @@ def build_shelljob_task(
tdata["metadata"]["task_type"] = "SHELLJOB"
task = create_task(tdata)
task.is_aiida_component = True
return task, tdata, links
return task, tdata


def build_task_from_workgraph(wg: any) -> Task:
Expand Down
22 changes: 22 additions & 0 deletions src/aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,28 @@ def set_context(self, context: Dict[str, Any]) -> None:
raise ValueError(msg)
self.context_mapping.update(context)

def set(self, data: Dict[str, Any]) -> None:
from node_graph.socket import NodeSocket

super().set(data)
# create input sockets and links for items inside a dynamic socket
# TODO the input value could be nested, but we only support one level for now
for key, value in data.items():
if self.inputs[key].metadata.get("dynamic", False):
if isinstance(value, dict):
keys = list(value.keys())
for sub_key in keys:
self.inputs.new(
"workgraph.any",
name=f"{key}.{sub_key}",
metadata={"required": True},
)
if isinstance(value[sub_key], NodeSocket):
self.parent.links.new(
value[sub_key], self.inputs[f"{key}.{sub_key}"]
)
self.inputs[key].value.pop(sub_key)

def set_from_builder(self, builder: Any) -> None:
"""Set the task inputs from a AiiDA ProcessBuilder."""
from aiida_workgraph.utils import get_dict_from_builder
Expand Down
18 changes: 11 additions & 7 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from aiida_workgraph import WorkGraph, task
from aiida_shell.launch import prepare_code
from aiida.orm import SinglefileData, load_computer
from aiida.orm import SinglefileData, load_computer, Int


def test_prepare_for_shell_task_nonexistent():
Expand Down Expand Up @@ -51,9 +51,9 @@ def test_shell_code():
assert job1.node.outputs.stdout.get_content() == "string astring b"


def test_shell_set():
def test_dynamic_port():
"""Set the nodes during/after the creation of the task."""
wg = WorkGraph(name="test_shell_set")
wg = WorkGraph(name="test_dynamic_port")
echo_task = wg.add_task(
"ShellJob",
name="echo",
Expand All @@ -68,11 +68,15 @@ def test_shell_set():
name="cat",
command="cat",
arguments=["{input}"],
nodes={"input": None},
nodes={"input1": None, "input2": Int(2), "input3": echo_task.outputs["_wait"]},
)
wg.add_link(echo_task.outputs["copied_file"], cat_task.inputs["nodes.input"])
wg.submit(wait=True)
assert cat_task.outputs["stdout"].value.get_content() == "1 5 1"
wg.add_link(echo_task.outputs["copied_file"], cat_task.inputs["nodes.input1"])
# task will create input for each item in the dynamic port (nodes)
assert "nodes.input1" in cat_task.inputs.keys()
assert "nodes.input2" in cat_task.inputs.keys()
# if the value of the item is a Socket, then it will create a link, and pop the item
assert "nodes.input3" in cat_task.inputs.keys()
assert cat_task.inputs["nodes"].value == {"input1": None, "input2": Int(2)}


@pytest.mark.usefixtures("started_daemon_client")
Expand Down
23 changes: 23 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from aiida_workgraph import WorkGraph, task
from typing import Callable
from aiida.cmdline.utils.common import get_workchain_report
from aiida import orm


def test_normal_task(decorated_add) -> None:
Expand Down Expand Up @@ -73,6 +74,28 @@ def test_task_wait(decorated_add: Callable) -> None:
assert "tasks ready to run: add1" in report


def test_set_dynamic_port_input(decorated_add) -> None:
from .utils.test_workchain import WorkChainWithDynamicNamespace

wg = WorkGraph(name="test_set_dynamic_port_input")
task1 = wg.add_task(decorated_add)
task2 = wg.add_task(
WorkChainWithDynamicNamespace,
dynamic_port={
"input1": None,
"input2": orm.Int(2),
"input3": task1.outputs["result"],
},
)
wg.add_link(task1.outputs["_wait"], task2.inputs["dynamic_port.input1"])
# task will create input for each item in the dynamic port (nodes)
assert "dynamic_port.input1" in task2.inputs.keys()
assert "dynamic_port.input2" in task2.inputs.keys()
# if the value of the item is a Socket, then it will create a link, and pop the item
assert "dynamic_port.input3" in task2.inputs.keys()
assert task2.inputs["dynamic_port"].value == {"input1": None, "input2": orm.Int(2)}


def test_set_inputs(decorated_add: Callable) -> None:
"""Test setting inputs of a task."""

Expand Down
10 changes: 10 additions & 0 deletions tests/utils/test_workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,13 @@ def validate_result(self):
def result(self):
"""Add the result to the outputs."""
self.out("result", self.ctx.addition.outputs.sum)


class WorkChainWithDynamicNamespace(WorkChain):
"""WorkChain with dynamic namespace."""

@classmethod
def define(cls, spec):
"""Specify inputs and outputs."""
super().define(spec)
spec.input_namespace("dynamic_port", dynamic=True)

0 comments on commit 49a3c26

Please sign in to comment.