Skip to content

Commit

Permalink
Allow user to define the inputs manually for dynamic input
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Aug 27, 2024
1 parent 1e0c64f commit 9b23ecc
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
9 changes: 9 additions & 0 deletions aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def build_task_from_AiiDA(
spec = executor.spec()
args = []
kwargs = []
user_defined_input_names = [input["name"] for input in inputs]
for _key, port in spec.inputs.ports.items():
add_input_recursive(inputs, port, args, kwargs, required=port.required)
for _key, port in spec.outputs.ports.items():
Expand All @@ -230,6 +231,14 @@ def build_task_from_AiiDA(
"property": {"identifier": "workgraph.any", "default": {}},
}
)
# if user define input names does not included in the args and kwargs,
# this will be the case for dynamic inputs
for key in user_defined_input_names:
if key not in args and key not in kwargs:
if key == name:
continue
kwargs.append(key)

# TODO In order to reload the WorkGraph from process, "is_pickle" should be True
# so I pickled the function here, but this is not necessary
# we need to update the node_graph to support the path and name of the function
Expand Down
23 changes: 15 additions & 8 deletions tests/test_calcfunction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from aiida_workgraph import WorkGraph
from aiida_workgraph import WorkGraph, task
from aiida import orm


def test_run(wg_calcfunction: WorkGraph) -> None:
Expand All @@ -14,10 +15,16 @@ def test_run(wg_calcfunction: WorkGraph) -> None:


@pytest.mark.usefixtures("started_daemon_client")
def test_submit(wg_calcfunction: WorkGraph) -> None:
"""Submit simple calcfunction."""
wg = wg_calcfunction
wg.name = "test_submit_calcfunction"
wg.submit(wait=True)
# print("results: ", results[])
assert wg.tasks["sumdiff2"].outputs["sum"].value == 9
def test_dynamic_inputs() -> None:
"""Test dynamic inputs.
For dynamic inputs, we allow the user to define the inputs manually.
"""

@task.calcfunction(inputs=[{"name": "x"}, {"name": "y"}])
def add(**kwargs):
return kwargs["x"] + kwargs["y"]

wg = WorkGraph("test_dynamic_inputs")
wg.add_task(add, name="add1", x=orm.Int(1), y=orm.Int(2))
wg.run()
assert wg.tasks["add1"].outputs["result"].value == 3

0 comments on commit 9b23ecc

Please sign in to comment.