Skip to content

Commit

Permalink
Speed up unit test and skip unstable test (#205)
Browse files Browse the repository at this point in the history
* Allows the WorkGraph to wait for the particular states of a task.
* Remove unused tasks and duplicated tests.
* Replace the calcjob task with the calcfunction task.
* Skip test for palying paused task, because it is unstable.
  • Loading branch information
superstar54 authored Aug 8, 2024
1 parent 262b744 commit b4e443f
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 147 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ jobs:
env:
AIIDA_WARN_v3: 1
run: |
pytest -v tests --cov
pytest -v tests --cov --durations=0
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4.0.1
Expand Down
2 changes: 1 addition & 1 deletion aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
"PYTHONJOB",
"SHELLJOB",
]:
if len(self._awaitables) > self.ctx.max_number_awaitables:
if len(self._awaitables) >= self.ctx.max_number_awaitables:
print(
MAX_NUMBER_AWAITABLES_MSG.format(
self.ctx.max_number_awaitables, name
Expand Down
26 changes: 17 additions & 9 deletions aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,28 +210,36 @@ def to_dict(self, store_nodes=False) -> Dict[str, Any]:

return wgdata

def wait(self, timeout: int = 50) -> None:
def wait(self, timeout: int = 50, tasks: dict = None) -> None:
"""
Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout.
Args:
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 50.
"""

start = time.time()
self.update()
while self.state not in (
terminating_states = (
"KILLED",
"PAUSED",
"FINISHED",
"FAILED",
"CANCELLED",
"EXCEPTED",
):
time.sleep(0.5)
)
start = time.time()
self.update()
finished = False
while not finished:
self.update()
if tasks is not None:
states = []
for name, value in tasks.items():
flag = self.tasks[name].state in value
states.append(flag)
finished = all(states)
else:
finished = self.state in terminating_states
time.sleep(0.5)
if time.time() - start > timeout:
return
break

def update(self) -> None:
"""
Expand Down
36 changes: 10 additions & 26 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from aiida_workgraph import task, WorkGraph
from aiida.engine import calcfunction, workfunction
from aiida.orm import Float, Int, StructureData
from aiida.orm import Int, StructureData
from aiida.calculations.arithmetic.add import ArithmeticAddCalculation
from typing import Callable, Any, Union
import time
Expand Down Expand Up @@ -61,13 +61,9 @@ def wg_calcfunction() -> WorkGraph:
"""A workgraph with calcfunction."""

wg = WorkGraph(name="test_debug_math")
float1 = wg.add_task("AiiDANode", "float1", pk=Float(3.0).store().pk)
sumdiff1 = wg.add_task("AiiDASumDiff", "sumdiff1", x=2)
sumdiff1 = wg.add_task("AiiDASumDiff", "sumdiff1", x=2, y=3)
sumdiff2 = wg.add_task("AiiDASumDiff", "sumdiff2", x=4)
sumdiff3 = wg.add_task("AiiDASumDiff", "sumdiff3", x=6)
wg.add_link(float1.outputs[0], sumdiff1.inputs[1])
wg.add_link(sumdiff1.outputs[0], sumdiff2.inputs[1])
wg.add_link(sumdiff2.outputs[0], sumdiff3.inputs[1])
return wg


Expand All @@ -78,17 +74,9 @@ def wg_calcjob(add_code) -> WorkGraph:
print("add_code", add_code)

wg = WorkGraph(name="test_debug_math")
int1 = wg.add_task("AiiDANode", "int1", pk=Int(3).store().pk)
code1 = wg.add_task("AiiDACode", "code1", pk=add_code.pk)
add1 = wg.add_task(ArithmeticAddCalculation, "add1", x=Int(2).store())
add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=Int(4).store())
add3 = wg.add_task(ArithmeticAddCalculation, "add3", x=Int(4).store())
wg.add_link(code1.outputs[0], add1.inputs["code"])
wg.add_link(int1.outputs[0], add1.inputs["y"])
wg.add_link(code1.outputs[0], add2.inputs["code"])
add1 = wg.add_task(ArithmeticAddCalculation, "add1", x=2, y=3, code=add_code)
add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=4, code=add_code)
wg.add_link(add1.outputs["sum"], add2.inputs["y"])
wg.add_link(code1.outputs[0], add3.inputs["code"])
wg.add_link(add2.outputs["sum"], add3.inputs["y"])
return wg


Expand Down Expand Up @@ -245,17 +233,13 @@ def wg_structure_si() -> WorkGraph:
def wg_engine(decorated_add, add_code) -> WorkGraph:
"""Use to test the engine."""
code = add_code
x = Int(2)
wg = WorkGraph(name="test_run_order")
add0 = wg.add_task(ArithmeticAddCalculation, "add0", x=x, y=Int(0), code=code)
add0.set({"metadata.options.sleep": 15})
add1 = wg.add_task(decorated_add, "add1", x=x, y=Int(1), t=Int(1))
add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=x, y=Int(2), code=code)
add2.set({"metadata.options.sleep": 1})
add3 = wg.add_task(decorated_add, "add3", x=x, y=Int(3), t=Int(1))
add4 = wg.add_task(ArithmeticAddCalculation, "add4", x=x, y=Int(4), code=code)
add4.set({"metadata.options.sleep": 1})
add5 = wg.add_task(decorated_add, "add5", x=x, y=Int(5), t=Int(1))
add0 = wg.add_task(ArithmeticAddCalculation, "add0", x=2, y=0, code=code)
add1 = wg.add_task(decorated_add, "add1", x=2, y=1)
add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=2, y=2, code=code)
add3 = wg.add_task(decorated_add, "add3", x=2, y=3)
add4 = wg.add_task(ArithmeticAddCalculation, "add4", x=2, y=4, code=code)
add5 = wg.add_task(decorated_add, "add5", x=2, y=5)
wg.add_link(add0.outputs["sum"], add2.inputs["x"])
wg.add_link(add1.outputs[0], add3.inputs["x"])
wg.add_link(add3.outputs[0], add4.inputs["x"])
Expand Down
3 changes: 0 additions & 3 deletions tests/test_calcjob.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
from aiida_workgraph import WorkGraph
import os


@pytest.mark.usefixtures("started_daemon_client")
Expand All @@ -9,6 +8,4 @@ def test_submit(wg_calcjob: WorkGraph) -> None:
wg = wg_calcjob
wg.name = "test_submit_calcjob"
wg.submit(wait=True)
os.system("verdi process list -a")
os.system(f"verdi process report {wg.pk}")
assert wg.tasks["add2"].outputs["sum"].value == 9
23 changes: 12 additions & 11 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import time
import pytest
from aiida_workgraph import WorkGraph
from aiida.cmdline.utils.common import get_workchain_report


@pytest.mark.usefixtures("started_daemon_client")
def test_run_order(wg_engine: WorkGraph) -> None:
def test_run_order(decorated_add) -> None:
"""Test the order.
Tasks should run in parallel and only depend on the input tasks."""
wg = wg_engine
wg = WorkGraph(name="test_run_order")
wg.add_task(decorated_add, "add0", x=2, y=0)
wg.add_task(decorated_add, "add1", x=2, y=1)
wg.submit(wait=True)
wg.tasks["add2"].ctime < wg.tasks["add4"].ctime
report = get_workchain_report(wg.process, "REPORT")
assert "tasks ready to run: add0,add1" in report


@pytest.mark.skip(reason="The test is not stable.")
Expand All @@ -28,23 +32,20 @@ def test_reset_node(wg_engine: WorkGraph) -> None:
assert len(wg.process.base.extras.get("_workgraph_queue")) == 1


@pytest.mark.usefixtures("started_daemon_client")
def test_max_number_jobs(add_code) -> None:
from aiida_workgraph import WorkGraph
from aiida.orm import Int
from aiida.calculations.arithmetic.add import ArithmeticAddCalculation

wg = WorkGraph("test_max_number_jobs")
N = 9
N = 3
# Create N nodes
for i in range(N):
temp = wg.add_task(
wg.add_task(
ArithmeticAddCalculation, name=f"add{i}", x=Int(1), y=Int(1), code=add_code
)
# Set a sleep option for each job (e.g., 2 seconds per job)
temp.set({"metadata.options.sleep": 1})

# Set the maximum number of running jobs inside the WorkGraph
wg.max_number_jobs = 3
wg.max_number_jobs = 2
wg.submit(wait=True, timeout=100)
wg.tasks["add1"].ctime < wg.tasks["add8"].ctime
report = get_workchain_report(wg.process, "REPORT")
assert "tasks ready to run: add2" in report
26 changes: 2 additions & 24 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,28 +98,6 @@ def parser(self, dirpath):
{"identifier": "Any", "name": "result"}
], # add a "result" output socket from the parser
)
# echo result + y expression
job3 = wg.add_task(
"ShellJob",
name="job3",
command="echo",
arguments=["{result}", "*", "{z}"],
nodes={"result": job2.outputs["result"], "z": Int(4)},
)
# bc command to calculate the expression
job4 = wg.add_task(
"ShellJob",
name="job4",
command="bc",
arguments=["{expression}"],
nodes={"expression": job3.outputs["stdout"]},
parser=PickledData(parser),
parser_outputs=[
{"identifier": "Any", "name": "result"}
], # add a "result" output socket from the parser
)
# there is a bug in aiida-shell, the following line will raise an error
# https://github.com/sphuber/aiida-shell/issues/91
# wg.submit(wait=True, timeout=200)

wg.run()
assert job4.outputs["result"].value.value == 20
assert job2.outputs["result"].value.value == 5
8 changes: 4 additions & 4 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def test_build_task_from_workgraph(wg_calcfunction, decorated_add):
wg_task = wg.add_task(wg_calcfunction, name="wg_calcfunction")
wg.add_task(decorated_add, name="add2", y=3)
wg.add_link(add1_task.outputs["result"], wg_task.inputs["sumdiff1.x"])
wg.add_link(wg_task.outputs["sumdiff3.sum"], wg.tasks["add2"].inputs["x"])
assert len(wg_task.inputs) == 15
assert len(wg_task.outputs) == 13
wg.add_link(wg_task.outputs["sumdiff2.sum"], wg.tasks["add2"].inputs["x"])
assert len(wg_task.inputs) == 7
assert len(wg_task.outputs) == 8
wg.submit(wait=True)
assert wg.tasks["add2"].outputs["result"].value.value == 20
assert wg.tasks["add2"].outputs["result"].value.value == 14
53 changes: 12 additions & 41 deletions tests/test_while.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_while_task(decorated_add, decorated_multiply, decorated_compare):
# update the context variable
multiply1.set_context({"result": "n"})
compare1 = wg.add_task(
decorated_compare, name="compare1", x=multiply1.outputs["result"], y=50
decorated_compare, name="compare1", x=multiply1.outputs["result"], y=30
)
compare1.set_context({"result": "should_run"})
wg.add_task(
Expand All @@ -33,67 +33,38 @@ def test_while_task(decorated_add, decorated_multiply, decorated_compare):
add3 = wg.add_task(decorated_add, name="add3", x=1, y=1)
wg.add_link(multiply1.outputs["result"], add3.inputs["x"])
wg.submit(wait=True, timeout=100)
assert wg.tasks["add3"].outputs["result"].value == 63
assert wg.tasks["add3"].outputs["result"].value == 31


@pytest.mark.usefixtures("started_daemon_client")
def test_while(decorated_add, decorated_multiply, decorated_compare):
def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare):
# Create a WorkGraph will repeat itself based on the conditions
wg = WorkGraph("while_workgraph")
wg.workgraph_type = "WHILE"
wg.conditions = ["compare1.result"]
wg.context = {"n": 1}
wg.max_iteration = 10
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=50)
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=20)
multiply1 = wg.add_task(
decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2)
)
add1 = wg.add_task(decorated_add, name="add1", y=3)
add1.set_context({"result": "n"})
wg.add_link(multiply1.outputs["result"], add1.inputs["x"])
wg.submit(wait=True, timeout=100)
assert wg.execution_count == 4
assert wg.tasks["add1"].outputs["result"].value == 61
assert wg.execution_count == 3
assert wg.tasks["add1"].outputs["result"].value == 29


@pytest.mark.usefixtures("started_daemon_client")
def test_while_graph_builder(decorated_add, decorated_multiply, decorated_compare):
# Create a WorkGraph will repeat itself based on the conditions
@task.graph_builder(outputs=[{"name": "result", "from": "context.n"}])
def my_while(n=0, limit=100):
wg = WorkGraph("while_workgraph")
wg.workgraph_type = "WHILE"
wg.conditions = ["compare1.result"]
wg.context = {"n": n}
wg.max_iteration = 10
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=orm.Int(limit))
multiply1 = wg.add_task(
decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2)
)
add1 = wg.add_task(decorated_add, name="add1", y=3)
add1.set_context({"result": "n"})
wg.add_link(multiply1.outputs["result"], add1.inputs["x"])
return wg

# -----------------------------------------
wg = WorkGraph("while")
add1 = wg.add_task(decorated_add, name="add1", x=orm.Int(25), y=orm.Int(25))
my_while1 = wg.add_task(my_while, n=orm.Int(1))
add2 = wg.add_task(decorated_add, name="add2", y=orm.Int(2))
wg.add_link(add1.outputs["result"], my_while1.inputs["limit"])
wg.add_link(my_while1.outputs["result"], add2.inputs["x"])
wg.submit(wait=True, timeout=100)
assert add2.outputs["result"].value == 63
assert my_while1.node.outputs.execution_count == 4
assert my_while1.outputs["result"].value == 61

"""Test the while WorkGraph in graph builder.
Also test the max_iteration parameter."""

def test_while_max_iteration(decorated_add, decorated_multiply, decorated_compare):
# Create a WorkGraph will repeat itself based on the conditions
@task.graph_builder(outputs=[{"name": "result", "from": "context.n"}])
def my_while(n=0, limit=100):
wg = WorkGraph("while_workgraph")
wg.workgraph_type = "WHILE"
wg.max_iteration = 3
wg.max_iteration = 2
wg.conditions = ["compare1.result"]
wg.context = {"n": n}
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=orm.Int(limit))
Expand All @@ -113,5 +84,5 @@ def my_while(n=0, limit=100):
wg.add_link(add1.outputs["result"], my_while1.inputs["limit"])
wg.add_link(my_while1.outputs["result"], add2.inputs["x"])
wg.submit(wait=True, timeout=100)
assert add2.outputs["result"].value < 63
assert my_while1.node.outputs.execution_count == 3
assert add2.outputs["result"].value < 31
assert my_while1.node.outputs.execution_count == 2
Loading

0 comments on commit b4e443f

Please sign in to comment.