Skip to content

Commit

Permalink
WorkGraph waits for partical states of a task.
Browse files Browse the repository at this point in the history
This can be used in the unit test. Intead of sleep for a time manually.
  • Loading branch information
superstar54 committed Aug 8, 2024
1 parent 7c2337b commit 37dfef6
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 124 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ jobs:
env:
AIIDA_WARN_v3: 1
run: |
pytest -v tests --cov
pytest -v tests/test_workgraph.py --cov --durations=0
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4.0.1
Expand Down
26 changes: 17 additions & 9 deletions aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,28 +210,36 @@ def to_dict(self, store_nodes=False) -> Dict[str, Any]:

return wgdata

def wait(self, timeout: int = 50) -> None:
def wait(self, timeout: int = 50, tasks: dict = None) -> None:
"""
Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout.
Args:
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 50.
"""

start = time.time()
self.update()
while self.state not in (
terminating_states = (
"KILLED",
"PAUSED",
"FINISHED",
"FAILED",
"CANCELLED",
"EXCEPTED",
):
time.sleep(0.5)
)
start = time.time()
self.update()
finished = False
while not finished:
self.update()
if tasks is not None:
states = []
for name, value in tasks.items():
flag = self.tasks[name].state in value
states.append(flag)
finished = all(states)
else:
finished = self.state in terminating_states
time.sleep(0.5)
if time.time() - start > timeout:
return
break

def update(self) -> None:
"""
Expand Down
26 changes: 7 additions & 19 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from aiida_workgraph import task, WorkGraph
from aiida.engine import calcfunction, workfunction
from aiida.orm import Float, Int, StructureData
from aiida.orm import Int, StructureData
from aiida.calculations.arithmetic.add import ArithmeticAddCalculation
from typing import Callable, Any, Union
import time
Expand Down Expand Up @@ -61,13 +61,9 @@ def wg_calcfunction() -> WorkGraph:
"""A workgraph with calcfunction."""

wg = WorkGraph(name="test_debug_math")
float1 = wg.add_task("AiiDANode", "float1", pk=Float(3.0).store().pk)
sumdiff1 = wg.add_task("AiiDASumDiff", "sumdiff1", x=2)
sumdiff1 = wg.add_task("AiiDASumDiff", "sumdiff1", x=2, y=3)
sumdiff2 = wg.add_task("AiiDASumDiff", "sumdiff2", x=4)
sumdiff3 = wg.add_task("AiiDASumDiff", "sumdiff3", x=6)
wg.add_link(float1.outputs[0], sumdiff1.inputs[1])
wg.add_link(sumdiff1.outputs[0], sumdiff2.inputs[1])
wg.add_link(sumdiff2.outputs[0], sumdiff3.inputs[1])
return wg


Expand All @@ -78,17 +74,9 @@ def wg_calcjob(add_code) -> WorkGraph:
print("add_code", add_code)

wg = WorkGraph(name="test_debug_math")
int1 = wg.add_task("AiiDANode", "int1", pk=Int(3).store().pk)
code1 = wg.add_task("AiiDACode", "code1", pk=add_code.pk)
add1 = wg.add_task(ArithmeticAddCalculation, "add1", x=Int(2).store())
add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=Int(4).store())
add3 = wg.add_task(ArithmeticAddCalculation, "add3", x=Int(4).store())
wg.add_link(code1.outputs[0], add1.inputs["code"])
wg.add_link(int1.outputs[0], add1.inputs["y"])
wg.add_link(code1.outputs[0], add2.inputs["code"])
add1 = wg.add_task(ArithmeticAddCalculation, "add1", x=2, y=3, code=add_code)
add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=4, code=add_code)
wg.add_link(add1.outputs["sum"], add2.inputs["y"])
wg.add_link(code1.outputs[0], add3.inputs["code"])
wg.add_link(add2.outputs["sum"], add3.inputs["y"])
return wg


Expand Down Expand Up @@ -247,11 +235,11 @@ def wg_engine(decorated_add, add_code) -> WorkGraph:
code = add_code
wg = WorkGraph(name="test_run_order")
add0 = wg.add_task(ArithmeticAddCalculation, "add0", x=2, y=0, code=code)
add1 = wg.add_task(decorated_add, "add1", x=2, y=1, t=1)
add1 = wg.add_task(decorated_add, "add1", x=2, y=1)
add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=2, y=2, code=code)
add3 = wg.add_task(decorated_add, "add3", x=2, y=3, t=1)
add3 = wg.add_task(decorated_add, "add3", x=2, y=3)
add4 = wg.add_task(ArithmeticAddCalculation, "add4", x=2, y=4, code=code)
add5 = wg.add_task(decorated_add, "add5", x=2, y=5, t=1)
add5 = wg.add_task(decorated_add, "add5", x=2, y=5)
wg.add_link(add0.outputs["sum"], add2.inputs["x"])
wg.add_link(add1.outputs[0], add3.inputs["x"])
wg.add_link(add3.outputs[0], add4.inputs["x"])
Expand Down
7 changes: 4 additions & 3 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@


@pytest.mark.usefixtures("started_daemon_client")
def test_run_order(wg_engine: WorkGraph) -> None:
def test_run_order(decorated_add) -> None:
"""Test the order.
Tasks should run in parallel and only depend on the input tasks."""
wg = wg_engine
wg = WorkGraph(name="test_run_order")
wg.add_task(decorated_add, "add0", x=2, y=0)
wg.add_task(decorated_add, "add1", x=2, y=1)
wg.submit(wait=True)
report = get_workchain_report(wg.process, "REPORT")
assert "tasks ready to run: add0,add1" in report
Expand All @@ -30,7 +32,6 @@ def test_reset_node(wg_engine: WorkGraph) -> None:
assert len(wg.process.base.extras.get("_workgraph_queue")) == 1


@pytest.mark.usefixtures("started_daemon_client")
def test_max_number_jobs(add_code) -> None:
from aiida_workgraph import WorkGraph
from aiida.orm import Int
Expand Down
26 changes: 2 additions & 24 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,28 +98,6 @@ def parser(self, dirpath):
{"identifier": "Any", "name": "result"}
], # add a "result" output socket from the parser
)
# echo result + y expression
job3 = wg.add_task(
"ShellJob",
name="job3",
command="echo",
arguments=["{result}", "*", "{z}"],
nodes={"result": job2.outputs["result"], "z": Int(4)},
)
# bc command to calculate the expression
job4 = wg.add_task(
"ShellJob",
name="job4",
command="bc",
arguments=["{expression}"],
nodes={"expression": job3.outputs["stdout"]},
parser=PickledData(parser),
parser_outputs=[
{"identifier": "Any", "name": "result"}
], # add a "result" output socket from the parser
)
# there is a bug in aiida-shell, the following line will raise an error
# https://github.com/sphuber/aiida-shell/issues/91
# wg.submit(wait=True, timeout=200)

wg.run()
assert job4.outputs["result"].value.value == 20
assert job2.outputs["result"].value.value == 5
8 changes: 4 additions & 4 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def test_build_task_from_workgraph(wg_calcfunction, decorated_add):
wg_task = wg.add_task(wg_calcfunction, name="wg_calcfunction")
wg.add_task(decorated_add, name="add2", y=3)
wg.add_link(add1_task.outputs["result"], wg_task.inputs["sumdiff1.x"])
wg.add_link(wg_task.outputs["sumdiff3.sum"], wg.tasks["add2"].inputs["x"])
assert len(wg_task.inputs) == 15
assert len(wg_task.outputs) == 13
wg.add_link(wg_task.outputs["sumdiff2.sum"], wg.tasks["add2"].inputs["x"])
assert len(wg_task.inputs) == 7
assert len(wg_task.outputs) == 8
wg.submit(wait=True)
assert wg.tasks["add2"].outputs["result"].value.value == 20
assert wg.tasks["add2"].outputs["result"].value.value == 14
53 changes: 12 additions & 41 deletions tests/test_while.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_while_task(decorated_add, decorated_multiply, decorated_compare):
# update the context variable
multiply1.set_context({"result": "n"})
compare1 = wg.add_task(
decorated_compare, name="compare1", x=multiply1.outputs["result"], y=50
decorated_compare, name="compare1", x=multiply1.outputs["result"], y=30
)
compare1.set_context({"result": "should_run"})
wg.add_task(
Expand All @@ -33,67 +33,38 @@ def test_while_task(decorated_add, decorated_multiply, decorated_compare):
add3 = wg.add_task(decorated_add, name="add3", x=1, y=1)
wg.add_link(multiply1.outputs["result"], add3.inputs["x"])
wg.submit(wait=True, timeout=100)
assert wg.tasks["add3"].outputs["result"].value == 63
assert wg.tasks["add3"].outputs["result"].value == 31


@pytest.mark.usefixtures("started_daemon_client")
def test_while(decorated_add, decorated_multiply, decorated_compare):
def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare):
# Create a WorkGraph will repeat itself based on the conditions
wg = WorkGraph("while_workgraph")
wg.workgraph_type = "WHILE"
wg.conditions = ["compare1.result"]
wg.context = {"n": 1}
wg.max_iteration = 10
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=50)
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=20)
multiply1 = wg.add_task(
decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2)
)
add1 = wg.add_task(decorated_add, name="add1", y=3)
add1.set_context({"result": "n"})
wg.add_link(multiply1.outputs["result"], add1.inputs["x"])
wg.submit(wait=True, timeout=100)
assert wg.execution_count == 4
assert wg.tasks["add1"].outputs["result"].value == 61
assert wg.execution_count == 3
assert wg.tasks["add1"].outputs["result"].value == 29


@pytest.mark.usefixtures("started_daemon_client")
def test_while_graph_builder(decorated_add, decorated_multiply, decorated_compare):
# Create a WorkGraph will repeat itself based on the conditions
@task.graph_builder(outputs=[{"name": "result", "from": "context.n"}])
def my_while(n=0, limit=100):
wg = WorkGraph("while_workgraph")
wg.workgraph_type = "WHILE"
wg.conditions = ["compare1.result"]
wg.context = {"n": n}
wg.max_iteration = 10
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=orm.Int(limit))
multiply1 = wg.add_task(
decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2)
)
add1 = wg.add_task(decorated_add, name="add1", y=3)
add1.set_context({"result": "n"})
wg.add_link(multiply1.outputs["result"], add1.inputs["x"])
return wg

# -----------------------------------------
wg = WorkGraph("while")
add1 = wg.add_task(decorated_add, name="add1", x=orm.Int(25), y=orm.Int(25))
my_while1 = wg.add_task(my_while, n=orm.Int(1))
add2 = wg.add_task(decorated_add, name="add2", y=orm.Int(2))
wg.add_link(add1.outputs["result"], my_while1.inputs["limit"])
wg.add_link(my_while1.outputs["result"], add2.inputs["x"])
wg.submit(wait=True, timeout=100)
assert add2.outputs["result"].value == 63
assert my_while1.node.outputs.execution_count == 4
assert my_while1.outputs["result"].value == 61

"""Test the while WorkGraph in graph builder.
Also test the max_iteration parameter."""

def test_while_max_iteration(decorated_add, decorated_multiply, decorated_compare):
# Create a WorkGraph will repeat itself based on the conditions
@task.graph_builder(outputs=[{"name": "result", "from": "context.n"}])
def my_while(n=0, limit=100):
wg = WorkGraph("while_workgraph")
wg.workgraph_type = "WHILE"
wg.max_iteration = 3
wg.max_iteration = 2
wg.conditions = ["compare1.result"]
wg.context = {"n": n}
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=orm.Int(limit))
Expand All @@ -113,5 +84,5 @@ def my_while(n=0, limit=100):
wg.add_link(add1.outputs["result"], my_while1.inputs["limit"])
wg.add_link(my_while1.outputs["result"], add2.inputs["x"])
wg.submit(wait=True, timeout=100)
assert add2.outputs["result"].value < 63
assert my_while1.node.outputs.execution_count == 3
assert add2.outputs["result"].value < 31
assert my_while1.node.outputs.execution_count == 2
52 changes: 29 additions & 23 deletions tests/test_workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
from aiida.calculations.arithmetic.add import ArithmeticAddCalculation


def test_to_dict(wg_calcjob):
def test_to_dict(wg_calcfunction):
"""Export NodeGraph to dict."""
wg = wg_calcjob
wg = wg_calcfunction
wgdata = wg.to_dict()
assert len(wgdata["tasks"]) == len(wg.tasks)
assert len(wgdata["links"]) == len(wg.links)


def test_from_dict(wg_calcjob):
def test_from_dict(wg_calcfunction):
"""Export NodeGraph to dict."""
wg = wg_calcjob
wg = wg_calcfunction
wgdata = wg.to_dict()
wg1 = WorkGraph.from_dict(wgdata)
assert len(wg.tasks) == len(wg1.tasks)
Expand All @@ -32,9 +32,9 @@ def test_add_task():
assert len(wg.links) == 1


def test_save_load(wg_calcjob):
def test_save_load(wg_calcfunction):
"""Save the workgraph"""
wg = wg_calcjob
wg = wg_calcfunction
wg.name = "test_save_load"
wg.save()
assert wg.process.process_state.value.upper() == "CREATED"
Expand Down Expand Up @@ -73,23 +73,24 @@ def test_reset_message(wg_calcjob):
assert "Task add2 action: RESET." in report


def test_restart(wg_calcjob):
def test_restart(wg_calcfunction):
"""Restart from a finished workgraph.
Load the workgraph, modify the task, and restart the workgraph.
Only the modified node and its child tasks will be rerun."""
wg = wg_calcjob
wg = wg_calcfunction
wg.add_task("AiiDASumDiff", "sumdiff3", x=4, y=wg.tasks["sumdiff2"].outputs["sum"])
wg.name = "test_restart_0"
wg.submit(wait=True)
wg1 = WorkGraph.load(wg.process.pk)
wg1.restart()
wg1.name = "test_restart_1"
wg1.tasks["add2"].set({"x": orm.Int(10).store()})
wg1.tasks["sumdiff2"].set({"x": orm.Int(10).store()})
# wg1.save()
wg1.submit(wait=True)
assert wg1.tasks["add1"].node.pk == wg.tasks["add1"].pk
assert wg1.tasks["add2"].node.pk != wg.tasks["add2"].pk
assert wg1.tasks["add3"].node.pk != wg.tasks["add3"].pk
assert wg1.tasks["add3"].node.outputs.sum == 19
assert wg1.tasks["sumdiff1"].node.pk == wg.tasks["sumdiff1"].pk
assert wg1.tasks["sumdiff2"].node.pk != wg.tasks["sumdiff2"].pk
assert wg1.tasks["sumdiff3"].node.pk != wg.tasks["sumdiff3"].pk
assert wg1.tasks["sumdiff3"].node.outputs.sum == 19


def test_extend_workgraph(decorated_add_multiply_group):
Expand All @@ -105,38 +106,43 @@ def test_extend_workgraph(decorated_add_multiply_group):
assert wg.tasks["group_multiply1"].node.outputs.result == 45


@pytest.mark.usefixtures("started_daemon_client")
def test_pause_task_before_submit(wg_calcjob):
wg = wg_calcjob
wg.name = "test_pause_task"
wg.pause_tasks(["add2"])
wg.submit()
time.sleep(20)
wg.update()
wg.wait(tasks={"add1": ["FINISHED"]}, timeout=20)
assert wg.tasks["add1"].node.process_state.value.upper() == "FINISHED"
# wait for the workgraph to launch add2
wg.wait(tasks={"add2": ["CREATED"]}, timeout=20)
assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED"
assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph"
wg.play_tasks(["add2"])
wg.wait()
wg.play_tasks(["add2"])
wg.wait(tasks={"add2": ["FINISHED"]})
assert wg.tasks["add2"].outputs["sum"].value == 9


def test_pause_task_after_submit(wg_calcjob):
wg = wg_calcjob
wg.tasks["add1"].set({"metadata.options.sleep": 3})
wg.name = "test_pause_task"
wg.submit()
# wait for the daemon to start the workgraph
time.sleep(3)
# wg.run()
# wait for the workgraph to launch add1
wg.wait(tasks={"add1": ["CREATED", "WAITING", "RUNNING", "FINISHED"]}, timeout=20)
wg.pause_tasks(["add2"])
time.sleep(20)
wg.update()
wg.wait(tasks={"add1": ["FINISHED"]}, timeout=20)
# wait for the workgraph to launch add2
wg.wait(tasks={"add2": ["CREATED"]}, timeout=20)
assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED"
assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph"
wg.play_tasks(["add2"])
wg.wait()
wg.play_tasks(["add2"])
wg.wait(tasks={"add2": ["FINISHED"]})
assert wg.tasks["add2"].outputs["sum"].value == 9


@pytest.mark.usefixtures("started_daemon_client")
def test_workgraph_group_outputs(decorated_add):
wg = WorkGraph("test_workgraph_group_outputs")
wg.add_task(decorated_add, "add1", x=2, y=3)
Expand Down

0 comments on commit 37dfef6

Please sign in to comment.