From 37dfef6ce81021d9d71774cb05b3c62a59186128 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Thu, 8 Aug 2024 21:46:55 +0200 Subject: [PATCH] WorkGraph waits for partical states of a task. This can be used in the unit test. Intead of sleep for a time manually. --- .github/workflows/ci.yaml | 2 +- aiida_workgraph/workgraph.py | 26 ++++++++++++------ tests/conftest.py | 26 +++++------------- tests/test_engine.py | 7 +++-- tests/test_shell.py | 26 ++---------------- tests/test_tasks.py | 8 +++--- tests/test_while.py | 53 ++++++++---------------------------- tests/test_workgraph.py | 52 +++++++++++++++++++---------------- 8 files changed, 76 insertions(+), 124 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index cedc7ecd..de832642 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -99,7 +99,7 @@ jobs: env: AIIDA_WARN_v3: 1 run: | - pytest -v tests --cov + pytest -v tests/test_workgraph.py --cov --durations=0 - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4.0.1 diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index ca3bf8ce..71da6496 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -210,28 +210,36 @@ def to_dict(self, store_nodes=False) -> Dict[str, Any]: return wgdata - def wait(self, timeout: int = 50) -> None: + def wait(self, timeout: int = 50, tasks: dict = None) -> None: """ Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout. - Args: timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 50. """ - - start = time.time() - self.update() - while self.state not in ( + terminating_states = ( "KILLED", "PAUSED", "FINISHED", "FAILED", "CANCELLED", "EXCEPTED", - ): - time.sleep(0.5) + ) + start = time.time() + self.update() + finished = False + while not finished: self.update() + if tasks is not None: + states = [] + for name, value in tasks.items(): + flag = self.tasks[name].state in value + states.append(flag) + finished = all(states) + else: + finished = self.state in terminating_states + time.sleep(0.5) if time.time() - start > timeout: - return + break def update(self) -> None: """ diff --git a/tests/conftest.py b/tests/conftest.py index 1656af55..a985fcbf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import pytest from aiida_workgraph import task, WorkGraph from aiida.engine import calcfunction, workfunction -from aiida.orm import Float, Int, StructureData +from aiida.orm import Int, StructureData from aiida.calculations.arithmetic.add import ArithmeticAddCalculation from typing import Callable, Any, Union import time @@ -61,13 +61,9 @@ def wg_calcfunction() -> WorkGraph: """A workgraph with calcfunction.""" wg = WorkGraph(name="test_debug_math") - float1 = wg.add_task("AiiDANode", "float1", pk=Float(3.0).store().pk) - sumdiff1 = wg.add_task("AiiDASumDiff", "sumdiff1", x=2) + sumdiff1 = wg.add_task("AiiDASumDiff", "sumdiff1", x=2, y=3) sumdiff2 = wg.add_task("AiiDASumDiff", "sumdiff2", x=4) - sumdiff3 = wg.add_task("AiiDASumDiff", "sumdiff3", x=6) - wg.add_link(float1.outputs[0], sumdiff1.inputs[1]) wg.add_link(sumdiff1.outputs[0], sumdiff2.inputs[1]) - wg.add_link(sumdiff2.outputs[0], sumdiff3.inputs[1]) return wg @@ -78,17 +74,9 @@ def wg_calcjob(add_code) -> WorkGraph: print("add_code", add_code) wg = WorkGraph(name="test_debug_math") - int1 = wg.add_task("AiiDANode", "int1", pk=Int(3).store().pk) - code1 = wg.add_task("AiiDACode", "code1", pk=add_code.pk) - add1 = wg.add_task(ArithmeticAddCalculation, "add1", x=Int(2).store()) - add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=Int(4).store()) - add3 = wg.add_task(ArithmeticAddCalculation, "add3", x=Int(4).store()) - wg.add_link(code1.outputs[0], add1.inputs["code"]) - wg.add_link(int1.outputs[0], add1.inputs["y"]) - wg.add_link(code1.outputs[0], add2.inputs["code"]) + add1 = wg.add_task(ArithmeticAddCalculation, "add1", x=2, y=3, code=add_code) + add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=4, code=add_code) wg.add_link(add1.outputs["sum"], add2.inputs["y"]) - wg.add_link(code1.outputs[0], add3.inputs["code"]) - wg.add_link(add2.outputs["sum"], add3.inputs["y"]) return wg @@ -247,11 +235,11 @@ def wg_engine(decorated_add, add_code) -> WorkGraph: code = add_code wg = WorkGraph(name="test_run_order") add0 = wg.add_task(ArithmeticAddCalculation, "add0", x=2, y=0, code=code) - add1 = wg.add_task(decorated_add, "add1", x=2, y=1, t=1) + add1 = wg.add_task(decorated_add, "add1", x=2, y=1) add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=2, y=2, code=code) - add3 = wg.add_task(decorated_add, "add3", x=2, y=3, t=1) + add3 = wg.add_task(decorated_add, "add3", x=2, y=3) add4 = wg.add_task(ArithmeticAddCalculation, "add4", x=2, y=4, code=code) - add5 = wg.add_task(decorated_add, "add5", x=2, y=5, t=1) + add5 = wg.add_task(decorated_add, "add5", x=2, y=5) wg.add_link(add0.outputs["sum"], add2.inputs["x"]) wg.add_link(add1.outputs[0], add3.inputs["x"]) wg.add_link(add3.outputs[0], add4.inputs["x"]) diff --git a/tests/test_engine.py b/tests/test_engine.py index 8fbd3961..2b322d7b 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -5,10 +5,12 @@ @pytest.mark.usefixtures("started_daemon_client") -def test_run_order(wg_engine: WorkGraph) -> None: +def test_run_order(decorated_add) -> None: """Test the order. Tasks should run in parallel and only depend on the input tasks.""" - wg = wg_engine + wg = WorkGraph(name="test_run_order") + wg.add_task(decorated_add, "add0", x=2, y=0) + wg.add_task(decorated_add, "add1", x=2, y=1) wg.submit(wait=True) report = get_workchain_report(wg.process, "REPORT") assert "tasks ready to run: add0,add1" in report @@ -30,7 +32,6 @@ def test_reset_node(wg_engine: WorkGraph) -> None: assert len(wg.process.base.extras.get("_workgraph_queue")) == 1 -@pytest.mark.usefixtures("started_daemon_client") def test_max_number_jobs(add_code) -> None: from aiida_workgraph import WorkGraph from aiida.orm import Int diff --git a/tests/test_shell.py b/tests/test_shell.py index 1f3cd431..7cfb758a 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -98,28 +98,6 @@ def parser(self, dirpath): {"identifier": "Any", "name": "result"} ], # add a "result" output socket from the parser ) - # echo result + y expression - job3 = wg.add_task( - "ShellJob", - name="job3", - command="echo", - arguments=["{result}", "*", "{z}"], - nodes={"result": job2.outputs["result"], "z": Int(4)}, - ) - # bc command to calculate the expression - job4 = wg.add_task( - "ShellJob", - name="job4", - command="bc", - arguments=["{expression}"], - nodes={"expression": job3.outputs["stdout"]}, - parser=PickledData(parser), - parser_outputs=[ - {"identifier": "Any", "name": "result"} - ], # add a "result" output socket from the parser - ) - # there is a bug in aiida-shell, the following line will raise an error - # https://github.com/sphuber/aiida-shell/issues/91 - # wg.submit(wait=True, timeout=200) + wg.run() - assert job4.outputs["result"].value.value == 20 + assert job2.outputs["result"].value.value == 5 diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 43e4be42..7402f3d6 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -10,8 +10,8 @@ def test_build_task_from_workgraph(wg_calcfunction, decorated_add): wg_task = wg.add_task(wg_calcfunction, name="wg_calcfunction") wg.add_task(decorated_add, name="add2", y=3) wg.add_link(add1_task.outputs["result"], wg_task.inputs["sumdiff1.x"]) - wg.add_link(wg_task.outputs["sumdiff3.sum"], wg.tasks["add2"].inputs["x"]) - assert len(wg_task.inputs) == 15 - assert len(wg_task.outputs) == 13 + wg.add_link(wg_task.outputs["sumdiff2.sum"], wg.tasks["add2"].inputs["x"]) + assert len(wg_task.inputs) == 7 + assert len(wg_task.outputs) == 8 wg.submit(wait=True) - assert wg.tasks["add2"].outputs["result"].value.value == 20 + assert wg.tasks["add2"].outputs["result"].value.value == 14 diff --git a/tests/test_while.py b/tests/test_while.py index 20655b01..66c776d6 100644 --- a/tests/test_while.py +++ b/tests/test_while.py @@ -19,7 +19,7 @@ def test_while_task(decorated_add, decorated_multiply, decorated_compare): # update the context variable multiply1.set_context({"result": "n"}) compare1 = wg.add_task( - decorated_compare, name="compare1", x=multiply1.outputs["result"], y=50 + decorated_compare, name="compare1", x=multiply1.outputs["result"], y=30 ) compare1.set_context({"result": "should_run"}) wg.add_task( @@ -33,18 +33,17 @@ def test_while_task(decorated_add, decorated_multiply, decorated_compare): add3 = wg.add_task(decorated_add, name="add3", x=1, y=1) wg.add_link(multiply1.outputs["result"], add3.inputs["x"]) wg.submit(wait=True, timeout=100) - assert wg.tasks["add3"].outputs["result"].value == 63 + assert wg.tasks["add3"].outputs["result"].value == 31 -@pytest.mark.usefixtures("started_daemon_client") -def test_while(decorated_add, decorated_multiply, decorated_compare): +def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare): # Create a WorkGraph will repeat itself based on the conditions wg = WorkGraph("while_workgraph") wg.workgraph_type = "WHILE" wg.conditions = ["compare1.result"] wg.context = {"n": 1} wg.max_iteration = 10 - wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=50) + wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=20) multiply1 = wg.add_task( decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2) ) @@ -52,48 +51,20 @@ def test_while(decorated_add, decorated_multiply, decorated_compare): add1.set_context({"result": "n"}) wg.add_link(multiply1.outputs["result"], add1.inputs["x"]) wg.submit(wait=True, timeout=100) - assert wg.execution_count == 4 - assert wg.tasks["add1"].outputs["result"].value == 61 + assert wg.execution_count == 3 + assert wg.tasks["add1"].outputs["result"].value == 29 +@pytest.mark.usefixtures("started_daemon_client") def test_while_graph_builder(decorated_add, decorated_multiply, decorated_compare): - # Create a WorkGraph will repeat itself based on the conditions - @task.graph_builder(outputs=[{"name": "result", "from": "context.n"}]) - def my_while(n=0, limit=100): - wg = WorkGraph("while_workgraph") - wg.workgraph_type = "WHILE" - wg.conditions = ["compare1.result"] - wg.context = {"n": n} - wg.max_iteration = 10 - wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=orm.Int(limit)) - multiply1 = wg.add_task( - decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2) - ) - add1 = wg.add_task(decorated_add, name="add1", y=3) - add1.set_context({"result": "n"}) - wg.add_link(multiply1.outputs["result"], add1.inputs["x"]) - return wg - - # ----------------------------------------- - wg = WorkGraph("while") - add1 = wg.add_task(decorated_add, name="add1", x=orm.Int(25), y=orm.Int(25)) - my_while1 = wg.add_task(my_while, n=orm.Int(1)) - add2 = wg.add_task(decorated_add, name="add2", y=orm.Int(2)) - wg.add_link(add1.outputs["result"], my_while1.inputs["limit"]) - wg.add_link(my_while1.outputs["result"], add2.inputs["x"]) - wg.submit(wait=True, timeout=100) - assert add2.outputs["result"].value == 63 - assert my_while1.node.outputs.execution_count == 4 - assert my_while1.outputs["result"].value == 61 - + """Test the while WorkGraph in graph builder. + Also test the max_iteration parameter.""" -def test_while_max_iteration(decorated_add, decorated_multiply, decorated_compare): - # Create a WorkGraph will repeat itself based on the conditions @task.graph_builder(outputs=[{"name": "result", "from": "context.n"}]) def my_while(n=0, limit=100): wg = WorkGraph("while_workgraph") wg.workgraph_type = "WHILE" - wg.max_iteration = 3 + wg.max_iteration = 2 wg.conditions = ["compare1.result"] wg.context = {"n": n} wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=orm.Int(limit)) @@ -113,5 +84,5 @@ def my_while(n=0, limit=100): wg.add_link(add1.outputs["result"], my_while1.inputs["limit"]) wg.add_link(my_while1.outputs["result"], add2.inputs["x"]) wg.submit(wait=True, timeout=100) - assert add2.outputs["result"].value < 63 - assert my_while1.node.outputs.execution_count == 3 + assert add2.outputs["result"].value < 31 + assert my_while1.node.outputs.execution_count == 2 diff --git a/tests/test_workgraph.py b/tests/test_workgraph.py index 679d3fbf..fb93b63f 100644 --- a/tests/test_workgraph.py +++ b/tests/test_workgraph.py @@ -5,17 +5,17 @@ from aiida.calculations.arithmetic.add import ArithmeticAddCalculation -def test_to_dict(wg_calcjob): +def test_to_dict(wg_calcfunction): """Export NodeGraph to dict.""" - wg = wg_calcjob + wg = wg_calcfunction wgdata = wg.to_dict() assert len(wgdata["tasks"]) == len(wg.tasks) assert len(wgdata["links"]) == len(wg.links) -def test_from_dict(wg_calcjob): +def test_from_dict(wg_calcfunction): """Export NodeGraph to dict.""" - wg = wg_calcjob + wg = wg_calcfunction wgdata = wg.to_dict() wg1 = WorkGraph.from_dict(wgdata) assert len(wg.tasks) == len(wg1.tasks) @@ -32,9 +32,9 @@ def test_add_task(): assert len(wg.links) == 1 -def test_save_load(wg_calcjob): +def test_save_load(wg_calcfunction): """Save the workgraph""" - wg = wg_calcjob + wg = wg_calcfunction wg.name = "test_save_load" wg.save() assert wg.process.process_state.value.upper() == "CREATED" @@ -73,23 +73,24 @@ def test_reset_message(wg_calcjob): assert "Task add2 action: RESET." in report -def test_restart(wg_calcjob): +def test_restart(wg_calcfunction): """Restart from a finished workgraph. Load the workgraph, modify the task, and restart the workgraph. Only the modified node and its child tasks will be rerun.""" - wg = wg_calcjob + wg = wg_calcfunction + wg.add_task("AiiDASumDiff", "sumdiff3", x=4, y=wg.tasks["sumdiff2"].outputs["sum"]) wg.name = "test_restart_0" wg.submit(wait=True) wg1 = WorkGraph.load(wg.process.pk) wg1.restart() wg1.name = "test_restart_1" - wg1.tasks["add2"].set({"x": orm.Int(10).store()}) + wg1.tasks["sumdiff2"].set({"x": orm.Int(10).store()}) # wg1.save() wg1.submit(wait=True) - assert wg1.tasks["add1"].node.pk == wg.tasks["add1"].pk - assert wg1.tasks["add2"].node.pk != wg.tasks["add2"].pk - assert wg1.tasks["add3"].node.pk != wg.tasks["add3"].pk - assert wg1.tasks["add3"].node.outputs.sum == 19 + assert wg1.tasks["sumdiff1"].node.pk == wg.tasks["sumdiff1"].pk + assert wg1.tasks["sumdiff2"].node.pk != wg.tasks["sumdiff2"].pk + assert wg1.tasks["sumdiff3"].node.pk != wg.tasks["sumdiff3"].pk + assert wg1.tasks["sumdiff3"].node.outputs.sum == 19 def test_extend_workgraph(decorated_add_multiply_group): @@ -105,38 +106,43 @@ def test_extend_workgraph(decorated_add_multiply_group): assert wg.tasks["group_multiply1"].node.outputs.result == 45 +@pytest.mark.usefixtures("started_daemon_client") def test_pause_task_before_submit(wg_calcjob): wg = wg_calcjob wg.name = "test_pause_task" wg.pause_tasks(["add2"]) wg.submit() - time.sleep(20) - wg.update() + wg.wait(tasks={"add1": ["FINISHED"]}, timeout=20) + assert wg.tasks["add1"].node.process_state.value.upper() == "FINISHED" + # wait for the workgraph to launch add2 + wg.wait(tasks={"add2": ["CREATED"]}, timeout=20) assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED" assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph" wg.play_tasks(["add2"]) - wg.wait() + wg.play_tasks(["add2"]) + wg.wait(tasks={"add2": ["FINISHED"]}) assert wg.tasks["add2"].outputs["sum"].value == 9 def test_pause_task_after_submit(wg_calcjob): wg = wg_calcjob + wg.tasks["add1"].set({"metadata.options.sleep": 3}) wg.name = "test_pause_task" wg.submit() - # wait for the daemon to start the workgraph - time.sleep(3) - # wg.run() + # wait for the workgraph to launch add1 + wg.wait(tasks={"add1": ["CREATED", "WAITING", "RUNNING", "FINISHED"]}, timeout=20) wg.pause_tasks(["add2"]) - time.sleep(20) - wg.update() + wg.wait(tasks={"add1": ["FINISHED"]}, timeout=20) + # wait for the workgraph to launch add2 + wg.wait(tasks={"add2": ["CREATED"]}, timeout=20) assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED" assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph" wg.play_tasks(["add2"]) - wg.wait() + wg.play_tasks(["add2"]) + wg.wait(tasks={"add2": ["FINISHED"]}) assert wg.tasks["add2"].outputs["sum"].value == 9 -@pytest.mark.usefixtures("started_daemon_client") def test_workgraph_group_outputs(decorated_add): wg = WorkGraph("test_workgraph_group_outputs") wg.add_task(decorated_add, "add1", x=2, y=3)