Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up unit test and skip unstable test #205

Merged
merged 3 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ jobs:
env:
AIIDA_WARN_v3: 1
run: |
pytest -v tests --cov
pytest -v tests --cov --durations=0

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4.0.1
Expand Down
2 changes: 1 addition & 1 deletion aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
"PYTHONJOB",
"SHELLJOB",
]:
if len(self._awaitables) > self.ctx.max_number_awaitables:
if len(self._awaitables) >= self.ctx.max_number_awaitables:
print(
MAX_NUMBER_AWAITABLES_MSG.format(
self.ctx.max_number_awaitables, name
Expand Down
26 changes: 17 additions & 9 deletions aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,28 +210,36 @@ def to_dict(self, store_nodes=False) -> Dict[str, Any]:

return wgdata

def wait(self, timeout: int = 50) -> None:
def wait(self, timeout: int = 50, tasks: dict = None) -> None:
"""
Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout.

Args:
timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 50.
"""

start = time.time()
self.update()
while self.state not in (
terminating_states = (
"KILLED",
"PAUSED",
"FINISHED",
"FAILED",
"CANCELLED",
"EXCEPTED",
):
time.sleep(0.5)
)
start = time.time()
self.update()
finished = False
while not finished:
self.update()
if tasks is not None:
states = []
for name, value in tasks.items():
flag = self.tasks[name].state in value
states.append(flag)
finished = all(states)
else:
finished = self.state in terminating_states
time.sleep(0.5)
if time.time() - start > timeout:
return
break

def update(self) -> None:
"""
Expand Down
36 changes: 10 additions & 26 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from aiida_workgraph import task, WorkGraph
from aiida.engine import calcfunction, workfunction
from aiida.orm import Float, Int, StructureData
from aiida.orm import Int, StructureData
from aiida.calculations.arithmetic.add import ArithmeticAddCalculation
from typing import Callable, Any, Union
import time
Expand Down Expand Up @@ -61,13 +61,9 @@ def wg_calcfunction() -> WorkGraph:
"""A workgraph with calcfunction."""

wg = WorkGraph(name="test_debug_math")
float1 = wg.add_task("AiiDANode", "float1", pk=Float(3.0).store().pk)
sumdiff1 = wg.add_task("AiiDASumDiff", "sumdiff1", x=2)
sumdiff1 = wg.add_task("AiiDASumDiff", "sumdiff1", x=2, y=3)
sumdiff2 = wg.add_task("AiiDASumDiff", "sumdiff2", x=4)
sumdiff3 = wg.add_task("AiiDASumDiff", "sumdiff3", x=6)
wg.add_link(float1.outputs[0], sumdiff1.inputs[1])
wg.add_link(sumdiff1.outputs[0], sumdiff2.inputs[1])
wg.add_link(sumdiff2.outputs[0], sumdiff3.inputs[1])
return wg


Expand All @@ -78,17 +74,9 @@ def wg_calcjob(add_code) -> WorkGraph:
print("add_code", add_code)

wg = WorkGraph(name="test_debug_math")
int1 = wg.add_task("AiiDANode", "int1", pk=Int(3).store().pk)
code1 = wg.add_task("AiiDACode", "code1", pk=add_code.pk)
add1 = wg.add_task(ArithmeticAddCalculation, "add1", x=Int(2).store())
add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=Int(4).store())
add3 = wg.add_task(ArithmeticAddCalculation, "add3", x=Int(4).store())
wg.add_link(code1.outputs[0], add1.inputs["code"])
wg.add_link(int1.outputs[0], add1.inputs["y"])
wg.add_link(code1.outputs[0], add2.inputs["code"])
add1 = wg.add_task(ArithmeticAddCalculation, "add1", x=2, y=3, code=add_code)
add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=4, code=add_code)
wg.add_link(add1.outputs["sum"], add2.inputs["y"])
wg.add_link(code1.outputs[0], add3.inputs["code"])
wg.add_link(add2.outputs["sum"], add3.inputs["y"])
return wg


Expand Down Expand Up @@ -245,17 +233,13 @@ def wg_structure_si() -> WorkGraph:
def wg_engine(decorated_add, add_code) -> WorkGraph:
"""Use to test the engine."""
code = add_code
x = Int(2)
wg = WorkGraph(name="test_run_order")
add0 = wg.add_task(ArithmeticAddCalculation, "add0", x=x, y=Int(0), code=code)
add0.set({"metadata.options.sleep": 15})
add1 = wg.add_task(decorated_add, "add1", x=x, y=Int(1), t=Int(1))
add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=x, y=Int(2), code=code)
add2.set({"metadata.options.sleep": 1})
add3 = wg.add_task(decorated_add, "add3", x=x, y=Int(3), t=Int(1))
add4 = wg.add_task(ArithmeticAddCalculation, "add4", x=x, y=Int(4), code=code)
add4.set({"metadata.options.sleep": 1})
add5 = wg.add_task(decorated_add, "add5", x=x, y=Int(5), t=Int(1))
add0 = wg.add_task(ArithmeticAddCalculation, "add0", x=2, y=0, code=code)
add1 = wg.add_task(decorated_add, "add1", x=2, y=1)
add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=2, y=2, code=code)
add3 = wg.add_task(decorated_add, "add3", x=2, y=3)
add4 = wg.add_task(ArithmeticAddCalculation, "add4", x=2, y=4, code=code)
add5 = wg.add_task(decorated_add, "add5", x=2, y=5)
wg.add_link(add0.outputs["sum"], add2.inputs["x"])
wg.add_link(add1.outputs[0], add3.inputs["x"])
wg.add_link(add3.outputs[0], add4.inputs["x"])
Expand Down
3 changes: 0 additions & 3 deletions tests/test_calcjob.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
from aiida_workgraph import WorkGraph
import os


@pytest.mark.usefixtures("started_daemon_client")
Expand All @@ -9,6 +8,4 @@ def test_submit(wg_calcjob: WorkGraph) -> None:
wg = wg_calcjob
wg.name = "test_submit_calcjob"
wg.submit(wait=True)
os.system("verdi process list -a")
os.system(f"verdi process report {wg.pk}")
assert wg.tasks["add2"].outputs["sum"].value == 9
23 changes: 12 additions & 11 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import time
import pytest
from aiida_workgraph import WorkGraph
from aiida.cmdline.utils.common import get_workchain_report


@pytest.mark.usefixtures("started_daemon_client")
def test_run_order(wg_engine: WorkGraph) -> None:
def test_run_order(decorated_add) -> None:
"""Test the order.
Tasks should run in parallel and only depend on the input tasks."""
wg = wg_engine
wg = WorkGraph(name="test_run_order")
wg.add_task(decorated_add, "add0", x=2, y=0)
wg.add_task(decorated_add, "add1", x=2, y=1)
wg.submit(wait=True)
wg.tasks["add2"].ctime < wg.tasks["add4"].ctime
report = get_workchain_report(wg.process, "REPORT")
assert "tasks ready to run: add0,add1" in report


@pytest.mark.skip(reason="The test is not stable.")
Expand All @@ -28,23 +32,20 @@ def test_reset_node(wg_engine: WorkGraph) -> None:
assert len(wg.process.base.extras.get("_workgraph_queue")) == 1


@pytest.mark.usefixtures("started_daemon_client")
def test_max_number_jobs(add_code) -> None:
from aiida_workgraph import WorkGraph
from aiida.orm import Int
from aiida.calculations.arithmetic.add import ArithmeticAddCalculation

wg = WorkGraph("test_max_number_jobs")
N = 9
N = 3
# Create N nodes
for i in range(N):
temp = wg.add_task(
wg.add_task(
ArithmeticAddCalculation, name=f"add{i}", x=Int(1), y=Int(1), code=add_code
)
# Set a sleep option for each job (e.g., 2 seconds per job)
temp.set({"metadata.options.sleep": 1})

# Set the maximum number of running jobs inside the WorkGraph
wg.max_number_jobs = 3
wg.max_number_jobs = 2
wg.submit(wait=True, timeout=100)
wg.tasks["add1"].ctime < wg.tasks["add8"].ctime
report = get_workchain_report(wg.process, "REPORT")
assert "tasks ready to run: add2" in report
26 changes: 2 additions & 24 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,28 +98,6 @@ def parser(self, dirpath):
{"identifier": "Any", "name": "result"}
], # add a "result" output socket from the parser
)
# echo result + y expression
job3 = wg.add_task(
"ShellJob",
name="job3",
command="echo",
arguments=["{result}", "*", "{z}"],
nodes={"result": job2.outputs["result"], "z": Int(4)},
)
# bc command to calculate the expression
job4 = wg.add_task(
"ShellJob",
name="job4",
command="bc",
arguments=["{expression}"],
nodes={"expression": job3.outputs["stdout"]},
parser=PickledData(parser),
parser_outputs=[
{"identifier": "Any", "name": "result"}
], # add a "result" output socket from the parser
)
# there is a bug in aiida-shell, the following line will raise an error
# https://github.com/sphuber/aiida-shell/issues/91
# wg.submit(wait=True, timeout=200)

wg.run()
assert job4.outputs["result"].value.value == 20
assert job2.outputs["result"].value.value == 5
8 changes: 4 additions & 4 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def test_build_task_from_workgraph(wg_calcfunction, decorated_add):
wg_task = wg.add_task(wg_calcfunction, name="wg_calcfunction")
wg.add_task(decorated_add, name="add2", y=3)
wg.add_link(add1_task.outputs["result"], wg_task.inputs["sumdiff1.x"])
wg.add_link(wg_task.outputs["sumdiff3.sum"], wg.tasks["add2"].inputs["x"])
assert len(wg_task.inputs) == 15
assert len(wg_task.outputs) == 13
wg.add_link(wg_task.outputs["sumdiff2.sum"], wg.tasks["add2"].inputs["x"])
assert len(wg_task.inputs) == 7
assert len(wg_task.outputs) == 8
wg.submit(wait=True)
assert wg.tasks["add2"].outputs["result"].value.value == 20
assert wg.tasks["add2"].outputs["result"].value.value == 14
53 changes: 12 additions & 41 deletions tests/test_while.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_while_task(decorated_add, decorated_multiply, decorated_compare):
# update the context variable
multiply1.set_context({"result": "n"})
compare1 = wg.add_task(
decorated_compare, name="compare1", x=multiply1.outputs["result"], y=50
decorated_compare, name="compare1", x=multiply1.outputs["result"], y=30
)
compare1.set_context({"result": "should_run"})
wg.add_task(
Expand All @@ -33,67 +33,38 @@ def test_while_task(decorated_add, decorated_multiply, decorated_compare):
add3 = wg.add_task(decorated_add, name="add3", x=1, y=1)
wg.add_link(multiply1.outputs["result"], add3.inputs["x"])
wg.submit(wait=True, timeout=100)
assert wg.tasks["add3"].outputs["result"].value == 63
assert wg.tasks["add3"].outputs["result"].value == 31


@pytest.mark.usefixtures("started_daemon_client")
def test_while(decorated_add, decorated_multiply, decorated_compare):
def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare):
# Create a WorkGraph will repeat itself based on the conditions
wg = WorkGraph("while_workgraph")
wg.workgraph_type = "WHILE"
wg.conditions = ["compare1.result"]
wg.context = {"n": 1}
wg.max_iteration = 10
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=50)
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=20)
multiply1 = wg.add_task(
decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2)
)
add1 = wg.add_task(decorated_add, name="add1", y=3)
add1.set_context({"result": "n"})
wg.add_link(multiply1.outputs["result"], add1.inputs["x"])
wg.submit(wait=True, timeout=100)
assert wg.execution_count == 4
assert wg.tasks["add1"].outputs["result"].value == 61
assert wg.execution_count == 3
assert wg.tasks["add1"].outputs["result"].value == 29


@pytest.mark.usefixtures("started_daemon_client")
def test_while_graph_builder(decorated_add, decorated_multiply, decorated_compare):
# Create a WorkGraph will repeat itself based on the conditions
@task.graph_builder(outputs=[{"name": "result", "from": "context.n"}])
def my_while(n=0, limit=100):
wg = WorkGraph("while_workgraph")
wg.workgraph_type = "WHILE"
wg.conditions = ["compare1.result"]
wg.context = {"n": n}
wg.max_iteration = 10
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=orm.Int(limit))
multiply1 = wg.add_task(
decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2)
)
add1 = wg.add_task(decorated_add, name="add1", y=3)
add1.set_context({"result": "n"})
wg.add_link(multiply1.outputs["result"], add1.inputs["x"])
return wg

# -----------------------------------------
wg = WorkGraph("while")
add1 = wg.add_task(decorated_add, name="add1", x=orm.Int(25), y=orm.Int(25))
my_while1 = wg.add_task(my_while, n=orm.Int(1))
add2 = wg.add_task(decorated_add, name="add2", y=orm.Int(2))
wg.add_link(add1.outputs["result"], my_while1.inputs["limit"])
wg.add_link(my_while1.outputs["result"], add2.inputs["x"])
wg.submit(wait=True, timeout=100)
assert add2.outputs["result"].value == 63
assert my_while1.node.outputs.execution_count == 4
assert my_while1.outputs["result"].value == 61

"""Test the while WorkGraph in graph builder.
Also test the max_iteration parameter."""

def test_while_max_iteration(decorated_add, decorated_multiply, decorated_compare):
# Create a WorkGraph will repeat itself based on the conditions
@task.graph_builder(outputs=[{"name": "result", "from": "context.n"}])
def my_while(n=0, limit=100):
wg = WorkGraph("while_workgraph")
wg.workgraph_type = "WHILE"
wg.max_iteration = 3
wg.max_iteration = 2
wg.conditions = ["compare1.result"]
wg.context = {"n": n}
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=orm.Int(limit))
Expand All @@ -113,5 +84,5 @@ def my_while(n=0, limit=100):
wg.add_link(add1.outputs["result"], my_while1.inputs["limit"])
wg.add_link(my_while1.outputs["result"], add2.inputs["x"])
wg.submit(wait=True, timeout=100)
assert add2.outputs["result"].value < 63
assert my_while1.node.outputs.execution_count == 3
assert add2.outputs["result"].value < 31
assert my_while1.node.outputs.execution_count == 2
Loading
Loading