From 7f87a26a2828d8e12ec78f725f46f10b94c2dd94 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Mon, 2 Dec 2024 15:00:40 +0100 Subject: [PATCH] Improve test coverage (#372) * Add test for the normal task, vector socket, get_or_create_code, test_get_parent_workgraphs, test_generate_node_graph, widget * Remove unused code * fix max number of running processes * fix organize_nested_inputs * remove append in the awaitable manager * increase check interval to fix unstable `play` and `pause` test --- .github/workflows/ci.yaml | 3 + .gitignore | 3 +- aiida_workgraph/engine/awaitable_manager.py | 24 +--- aiida_workgraph/engine/task_manager.py | 9 +- aiida_workgraph/engine/utils.py | 4 +- aiida_workgraph/engine/workgraph.py | 7 - aiida_workgraph/executors/builtins.py | 33 ----- aiida_workgraph/executors/test.py | 9 -- aiida_workgraph/property.py | 11 -- aiida_workgraph/socket.py | 11 -- aiida_workgraph/tasks/builtins.py | 22 --- aiida_workgraph/tasks/test.py | 24 ---- aiida_workgraph/utils/__init__.py | 147 ++++++++------------ aiida_workgraph/utils/control.py | 18 +-- aiida_workgraph/web/__init__.py | 0 aiida_workgraph/web/backend/__init__.py | 0 aiida_workgraph/workgraph.py | 32 ++--- pyproject.toml | 1 - tests/__init__.py | 0 tests/conftest.py | 17 ++- tests/test_action.py | 59 ++++++++ tests/test_engine.py | 4 +- tests/test_socket.py | 16 +++ tests/test_tasks.py | 18 ++- tests/test_utils.py | 55 +++++++- tests/test_workgraph.py | 73 ++++------ tests/utils/test_workchain.py | 55 ++++++++ tests/widget/test_widget.py | 32 +++++ 28 files changed, 367 insertions(+), 320 deletions(-) create mode 100644 aiida_workgraph/web/__init__.py create mode 100644 aiida_workgraph/web/backend/__init__.py create mode 100644 tests/__init__.py create mode 100644 tests/test_action.py create mode 100644 tests/utils/test_workchain.py create mode 100644 tests/widget/test_widget.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a2ca7596..440f0579 100644 --- 
a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -79,6 +79,9 @@ jobs: playwright install pip list + - name: Install system dependencies + run: sudo apt update && sudo apt install --no-install-recommends graphviz + - name: Create AiiDA profile run: verdi setup -n --config .github/config/profile.yaml diff --git a/.gitignore b/.gitignore index 3aa9dd8a..c45e6652 100644 --- a/.gitignore +++ b/.gitignore @@ -135,4 +135,5 @@ dmypy.json tests/work /tests/**/*.png /tests/**/*txt -.vscode/ +/tests/**/*html +.vscode diff --git a/aiida_workgraph/engine/awaitable_manager.py b/aiida_workgraph/engine/awaitable_manager.py index 414933a5..25049e7f 100644 --- a/aiida_workgraph/engine/awaitable_manager.py +++ b/aiida_workgraph/engine/awaitable_manager.py @@ -40,13 +40,9 @@ def insert_awaitable(self, awaitable: Awaitable) -> None: ctx, key = self.ctx_manager.resolve_nested_context(awaitable.key) # Already assign the awaitable itself to the location in the context container where it is supposed to end up - # once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the - # order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the - # awaitable as a placeholder, in the `_resolve_awaitable`, it can be found and replaced by the resolved value. + # once it is resolved. 
if awaitable.action == AwaitableAction.ASSIGN: ctx[key] = awaitable - elif awaitable.action == AwaitableAction.APPEND: - ctx.setdefault(key, []).append(awaitable) else: raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") @@ -67,26 +63,12 @@ def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None: if awaitable.action == AwaitableAction.ASSIGN: ctx[key] = value - elif awaitable.action == AwaitableAction.APPEND: - # Find the same awaitable inserted in the context - container = ctx[key] - for index, placeholder in enumerate(container): - if ( - isinstance(placeholder, Awaitable) - and placeholder.pk == awaitable.pk - ): - container[index] = value - break - else: - raise AssertionError( - f"Awaitable `{awaitable.pk} was not in `ctx.{awaitable.key}`" - ) else: raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") awaitable.resolved = True - # remove awaitabble from the list - self._awaitables = [a for a in self._awaitables if a.pk != awaitable.pk] + # remove awaitabble from the list, and use the same list reference + self._awaitables[:] = [a for a in self._awaitables if a.pk != awaitable.pk] if not self.process.has_terminated(): # the process may be terminated, for example, if the process was killed or excepted diff --git a/aiida_workgraph/engine/task_manager.py b/aiida_workgraph/engine/task_manager.py index c8f5153a..d0175dfe 100644 --- a/aiida_workgraph/engine/task_manager.py +++ b/aiida_workgraph/engine/task_manager.py @@ -187,7 +187,6 @@ def is_workgraph_finished(self) -> bool: def continue_workgraph(self) -> None: self.process.report("Continue workgraph.") - # self.update_workgraph_from_base() task_to_run = [] for name, task in self.ctx._tasks.items(): # update task state @@ -734,9 +733,11 @@ def update_normal_task_state(self, name, results, success=True): if success: task = self.ctx._tasks[name] if isinstance(results, tuple): - if len(task["outputs"]) != len(results): - return 
self.exit_codes.OUTPUS_NOT_MATCH_RESULTS - output_names = get_sorted_names(task["outputs"]) + # there are two built-in outputs: _wait and _outputs + if len(task["outputs"]) - 2 != len(results): + self.on_task_failed(name) + return self.process.exit_codes.OUTPUS_NOT_MATCH_RESULTS + output_names = get_sorted_names(task["outputs"])[0:-2] for i, output_name in enumerate(output_names): task["results"][output_name] = results[i] elif isinstance(results, dict): diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py index 9d86efbd..519e6bba 100644 --- a/aiida_workgraph/engine/utils.py +++ b/aiida_workgraph/engine/utils.py @@ -4,7 +4,7 @@ def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple: """Prepare the inputs for WorkGraph task""" - from aiida_workgraph.utils import merge_properties, serialize_properties + from aiida_workgraph.utils import organize_nested_inputs, serialize_properties from aiida.orm.utils.serialize import deserialize_unsafe wgdata = deserialize_unsafe(task["executor"]["wgdata"]) @@ -19,7 +19,7 @@ def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple: "value" ] = value # merge the properties - merge_properties(wgdata) + organize_nested_inputs(wgdata) serialize_properties(wgdata) metadata = {"call_link_label": task["name"]} inputs = {"wg": wgdata, "metadata": metadata} diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index f4145a9d..5b79c27a 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -322,13 +322,6 @@ def read_wgdata_from_base(self) -> t.Dict[str, t.Any]: wgdata["context"] = deserialize_unsafe(wgdata["context"]) return wgdata - def update_workgraph_from_base(self) -> None: - """Update the ctx from base.extras.""" - wgdata = self.read_wgdata_from_base() - for name, task in wgdata["tasks"].items(): - task["results"] = self.ctx._tasks[name].get("results") - self.setup_ctx_workgraph(wgdata) - def init_ctx(self, wgdata: 
t.Dict[str, t.Any]) -> None: """Init the context from the workgraph data.""" from aiida_workgraph.utils import update_nested_dict diff --git a/aiida_workgraph/executors/builtins.py b/aiida_workgraph/executors/builtins.py index 97055bcf..d28e735d 100644 --- a/aiida_workgraph/executors/builtins.py +++ b/aiida_workgraph/executors/builtins.py @@ -1,38 +1,5 @@ -from aiida.engine import WorkChain -from aiida import orm -from aiida.engine.processes.workchains.workchain import WorkChainSpec - - def select(condition, true, false): """Select the data based on the condition.""" if condition: return true return false - - -class GatherWorkChain(WorkChain): - @classmethod - def define(cls, spec: WorkChainSpec) -> None: - """Define the process specification.""" - - super().define(spec) - spec.input_namespace( - "datas", - dynamic=True, - help=('Dynamic namespace for the datas, "{key}" : {Data}".'), - ) - spec.outline( - cls.gather, - ) - spec.output( - "result", - valid_type=orm.List, - required=True, - help="A list of the uuid of the outputs.", - ) - - def gather(self) -> None: - datas = self.inputs.datas.values() - uuids = [data.uuid for data in datas] - # uuids = gather(uuids) - self.out("result", orm.List(uuids).store()) diff --git a/aiida_workgraph/executors/test.py b/aiida_workgraph/executors/test.py index 9bdb8b7f..0c5692da 100644 --- a/aiida_workgraph/executors/test.py +++ b/aiida_workgraph/executors/test.py @@ -14,15 +14,6 @@ def add( return {"sum": x + y} -@calcfunction -def greater( - x: Union[Int, Float], y: Union[Int, Float], t: Union[Int, Float] = 1.0 -) -> Dict[str, bool]: - """Compare node.""" - time.sleep(t.value) - return {"result": x > y} - - @calcfunction def sum_diff( x: Union[Int, Float], y: Union[Int, Float], t: Union[Int, Float] = 1.0 diff --git a/aiida_workgraph/property.py b/aiida_workgraph/property.py index 4098e729..7c41419d 100644 --- a/aiida_workgraph/property.py +++ b/aiida_workgraph/property.py @@ -55,15 +55,4 @@ def set_value(self, value: Any) -> 
None: else: raise Exception("{} is not an {}.".format(value, DataClass.__name__)) - def get_serialize(self) -> Dict[str, str]: - serialize = {"module": "aiida.orm.utils.serialize", "name": "serialize"} - return serialize - - def get_deserialize(self) -> Dict[str, str]: - deserialize = { - "module": "aiida.orm.utils.serialize", - "name": "deserialize_unsafe", - } - return deserialize - return AiiDATaskProperty diff --git a/aiida_workgraph/socket.py b/aiida_workgraph/socket.py index f1b201a5..ff9f677a 100644 --- a/aiida_workgraph/socket.py +++ b/aiida_workgraph/socket.py @@ -50,15 +50,4 @@ def __init__( super().__init__(name, parent, type, index, uuid=uuid) self.add_property(DataClass, name, **kwargs) - def get_serialize(self) -> dict: - serialize = {"module": "aiida.orm.utils.serialize", "name": "serialize"} - return serialize - - def get_deserialize(self) -> dict: - deserialize = { - "module": "aiida.orm.utils.serialize", - "name": "deserialize_unsafe", - } - return deserialize - return AiiDATaskSocket diff --git a/aiida_workgraph/tasks/builtins.py b/aiida_workgraph/tasks/builtins.py index 4c563a82..9e4c1e5a 100644 --- a/aiida_workgraph/tasks/builtins.py +++ b/aiida_workgraph/tasks/builtins.py @@ -73,28 +73,6 @@ def create_sockets(self) -> None: self.outputs.new("workgraph.any", "_wait") -class Gather(Task): - """Gather""" - - identifier = "workgraph.aiida_gather" - name = "Gather" - node_type = "WORKCHAIN" - catalog = "Control" - - _executor = { - "module": "aiida_workgraph.executors.builtins", - "name": "GatherWorkChain", - } - kwargs = ["datas"] - - def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - inp = self.inputs.new("workgraph.any", "datas") - inp.link_limit = 100000 - self.outputs.new("workgraph.any", "result") - - class SetContext(Task): """SetContext""" diff --git a/aiida_workgraph/tasks/test.py b/aiida_workgraph/tasks/test.py index 5c645a43..991cd690 100644 --- a/aiida_workgraph/tasks/test.py +++ 
b/aiida_workgraph/tasks/test.py @@ -28,30 +28,6 @@ def create_sockets(self) -> None: self.outputs.new("workgraph.aiida_float", "sum") -class TestGreater(Task): - - identifier: str = "workgraph.test_greater" - name = "TestGreater" - node_type = "CALCFUNCTION" - catalog = "Test" - - _executor = { - "module": "aiida_workgraph.executors.test", - "name": "greater", - } - kwargs = ["x", "y"] - - def create_properties(self) -> None: - pass - - def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("workgraph.aiida_float", "x") - self.inputs.new("workgraph.aiida_float", "y") - self.outputs.new("workgraph.aiida_bool", "result") - - class TestSumDiff(Task): identifier: str = "workgraph.test_sum_diff" diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index bee8cf6a..aacb784f 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -43,9 +43,8 @@ def build_callable(obj: Callable) -> Dict[str, Any]: return executor -def get_sorted_names(data: dict) -> list: +def get_sorted_names(data: dict) -> list[str]: """Get the sorted names from a dictionary.""" - print("data: ", data) sorted_names = [ name for name, _ in sorted( @@ -151,19 +150,21 @@ def get_nested_dict(d: Dict, name: str, **kwargs) -> Any: return current -def merge_dicts(existing: Any, new: Any) -> Any: +def merge_dicts(dict1: Any, dict2: Any) -> Any: """Recursively merges two dictionaries.""" - if isinstance(existing, dict) and isinstance(new, dict): - for k, v in new.items(): - if k in existing and isinstance(existing[k], dict) and isinstance(v, dict): - merge_dicts(existing[k], v) - else: - existing[k] = v - else: - return new + for key, value in dict2.items(): + if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict): + # Recursively merge dictionaries + dict1[key] = merge_dicts(dict1[key], value) + else: + # Overwrite or add the key + dict1[key] = value + return dict1 -def 
update_nested_dict(d: Optional[Dict[str, Any]], key: str, value: Any) -> None: +def update_nested_dict( + base: Optional[Dict[str, Any]], key_path: str, value: Any +) -> None: """ Update or create a nested dictionary structure based on a dotted key path. @@ -174,59 +175,49 @@ def update_nested_dict(d: Optional[Dict[str, Any]], key: str, value: Any) -> Non If the resulting dictionary is empty, it is set to `None`. Args: - d (Dict[str, Any] | None): The dictionary to update, which can be `None`. + base (Dict[str, Any] | None): The dictionary to update, which can be `None`. If `None`, an empty dictionary will be created. key (str): A dotted key path string representing the nested structure. value (Any): The value to set at the specified key. Example: - d = None - key = "base.pw.parameters" + base = None + key = "scf.pw.parameters" value = 2 After running: update_nested_dict(d, key, value) The result will be: - d = {"base": {"pw": {"parameters": 2}}} + base = {"scf": {"pw": {"parameters": 2}}} Edge Case: If the resulting dictionary is empty after the update, it will be set to `None`. """ - keys = key.split(".") - current = d if d is not None else {} - for k in keys[:-1]: - current = current.setdefault(k, {}) - # Handle merging instead of overwriting - last_key = keys[-1] - if ( - last_key in current - and isinstance(current[last_key], dict) - and isinstance(value, dict) - ): - merge_dicts(current[last_key], value) + if base is None: + base = {} + keys = key_path.split(".") + current_key = keys[0] + if len(keys) == 1: + # Base case: Merge dictionaries or set the value directly. 
+ if isinstance(base.get(current_key), dict) and isinstance(value, dict): + base[current_key] = merge_dicts(base[current_key], value) + else: + base[current_key] = value else: - current[last_key] = value - # if current is empty, set it to None - if not current: - current = None - return current - - -def is_empty(value: Any) -> bool: - """Check if the provided value is an empty collection.""" - import numpy as np + # Recursive case: Ensure the key exists and is a dictionary, then recurse. + if current_key not in base or not isinstance(base[current_key], dict): + base[current_key] = {} + base[current_key] = update_nested_dict( + base[current_key], ".".join(keys[1:]), value + ) - if isinstance(value, np.ndarray): - return value.size == 0 - elif isinstance(value, (dict, list, set, tuple)): - return not value - return False + return base def update_nested_dict_with_special_keys(data: Dict[str, Any]) -> Dict[str, Any]: - """Remove None and empty value""" - # data = {k: v for k, v in data.items() if v is not None and not is_empty(v)} + """Update the nested dictionary with special keys like "base.pw.parameters".""" + # Remove None data = {k: v for k, v in data.items() if v is not None} # special_keys = [k for k in data.keys() if "." in k] @@ -236,16 +227,26 @@ def update_nested_dict_with_special_keys(data: Dict[str, Any]) -> Dict[str, Any] return data -def merge_properties(wgdata: Dict[str, Any]) -> None: +def organize_nested_inputs(wgdata: Dict[str, Any]) -> None: """Merge sub properties to the root properties. - { - "base.pw.parameters": 2, - "base.pw.code": 1, - } - after merge: - {"base": {"pw": {"parameters": 2, - "code": 1}} - So that no "." in the key name. 
+ The sub properties will be set to None after being merged into the root property. + For example: + task["inputs"]["base"]["property"]["value"] = None + task["inputs"]["base.pw.parameters"]["property"]["value"] = 2 + task["inputs"]["base.pw.code"]["property"]["value"] = 1 + task["inputs"]["metadata"]["property"]["value"] = {"options": {"resources": {"num_cpus": 1}} + task["inputs"]["metadata.options"]["property"]["value"] = {"resources": {"num_machine": 1}} + After organizing: + task["inputs"]["base"]["property"]["value"] = {"base": {"pw": {"parameters": 2, + "code": 1}, + "metadata": {"options": + {"resources": {"num_cpus": 1, + "num_machine": 1}}}}, + } + task["inputs"]["base.pw.parameters"]["property"]["value"] = None + task["inputs"]["base.pw.code"]["property"]["value"] = None + task["inputs"]["metadata"]["property"]["value"] = None + task["inputs"]["metadata.options"]["property"]["value"] = None """ for _, task in wgdata["tasks"].items(): for key, prop in task["properties"].items(): @@ -307,34 +308,6 @@ def generate_node_graph( return g -def build_task_link(wgdata: Dict[str, Any]) -> None: - """Create links for tasks. - Create the links for task inputs using: - 1) workgraph links - 2) if it is a graph builder graph, expose the group inputs and outputs - sockets. 
- """ - # reset task input links - for name, task in wgdata["tasks"].items(): - for input in task["inputs"]: - input["links"] = [] - for output in task["outputs"]: - output["links"] = [] - for link in wgdata["links"]: - to_socket = [ - socket - for socket in wgdata["tasks"][link["to_node"]]["inputs"] - if socket["name"] == link["to_socket"] - ][0] - from_socket = [ - socket - for socket in wgdata["tasks"][link["from_node"]]["outputs"] - if socket["name"] == link["from_socket"] - ][0] - to_socket["links"].append(link) - from_socket["links"].append(link) - - def get_dict_from_builder(builder: Any) -> Dict: """Transform builder to pure dict.""" from aiida.engine.processes.builder import ProcessBuilderNamespace @@ -345,14 +318,6 @@ def get_dict_from_builder(builder: Any) -> Dict: return builder -def serialize_workgraph_data(wgdata: Dict[str, Any]) -> Dict[str, Any]: - from aiida.orm.utils.serialize import serialize - - for name, task in wgdata["tasks"].items(): - wgdata["tasks"][name] = serialize(task) - wgdata["error_handlers"] = serialize(wgdata["error_handlers"]) - - def get_workgraph_data(process: Union[int, orm.Node]) -> Optional[Dict[str, Any]]: """Get the workgraph data from the process node.""" from aiida.orm.utils.serialize import deserialize_unsafe @@ -380,10 +345,8 @@ def get_parent_workgraphs(pk: int) -> list: node = orm.load_node(pk) parent_workgraphs = [[node.process_label, node.pk]] links = node.base.links.get_incoming(link_type=LinkType.CALL_WORK).all() - print(links) if len(links) > 0: parent_workgraphs.extend(get_parent_workgraphs(links[0].node.pk)) - print(parent_workgraphs) return parent_workgraphs diff --git a/aiida_workgraph/utils/control.py b/aiida_workgraph/utils/control.py index 376f8fc3..72041224 100644 --- a/aiida_workgraph/utils/control.py +++ b/aiida_workgraph/utils/control.py @@ -28,21 +28,11 @@ def get_task_state_info(node, name: str, key: str) -> str: return value -def set_task_state_info(node, name: str, key: str, value: any) -> None: - 
"""Set task state info to base.extras.""" - from aiida.orm.utils.serialize import serialize - - if key == "process": - node.base.extras.set(f"_task_{key}_{name}", serialize(value)) - else: - node.base.extras.set(f"_task_{key}_{name}", value) - - -def pause_tasks(pk: int, tasks: list, timeout: int = 5, wait: bool = False): +def pause_tasks(pk: int, tasks: list[str], timeout: int = 5, wait: bool = False): """Pause task.""" node = orm.load_node(pk) if node.is_finished: - message = "Process is finished. Cannot pause tasks." + message = "WorkGraph is finished. Cannot pause tasks." print(message) return False, message elif node.process_state.value.upper() in [ @@ -70,7 +60,7 @@ def pause_tasks(pk: int, tasks: list, timeout: int = 5, wait: bool = False): def play_tasks(pk: int, tasks: list, timeout: int = 5, wait: bool = False): node = orm.load_node(pk) if node.is_finished: - message = "Process is finished. Cannot pause tasks." + message = "WorkGraph is finished. Cannot kill tasks." print(message) return False, message elif node.process_state.value.upper() in [ @@ -100,7 +90,7 @@ def play_tasks(pk: int, tasks: list, timeout: int = 5, wait: bool = False): def kill_tasks(pk: int, tasks: list, timeout: int = 5, wait: bool = False): node = orm.load_node(pk) if node.is_finished: - message = "Process is finished. Cannot pause tasks." + message = "WorkGraph is finished. Cannot kill tasks." 
print(message) return False, message elif node.process_state.value.upper() in [ diff --git a/aiida_workgraph/web/__init__.py b/aiida_workgraph/web/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiida_workgraph/web/backend/__init__.py b/aiida_workgraph/web/backend/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index d043f321..194fbf12 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -68,14 +68,16 @@ def tasks(self) -> TaskCollection: """Add alias to `nodes` for WorkGraph""" return self.nodes - def prepare_inputs(self, metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]: + def prepare_inputs( + self, metadata: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: from aiida_workgraph.utils import ( - merge_properties, + organize_nested_inputs, serialize_properties, ) wgdata = self.to_dict() - merge_properties(wgdata) + organize_nested_inputs(wgdata) serialize_properties(wgdata) metadata = metadata or {} inputs = {"wg": wgdata, "metadata": metadata} @@ -114,6 +116,7 @@ def submit( inputs: Optional[Dict[str, Any]] = None, wait: bool = False, timeout: int = 60, + interval: int = 1, metadata: Optional[Dict[str, Any]] = None, ) -> aiida.orm.ProcessNode: """Submit the AiiDA workgraph process and optionally wait for it to finish. 
@@ -138,7 +141,7 @@ def submit( # as long as we submit the process, it is a new submission, we should set restart_process to None self.restart_process = None if wait: - self.wait(timeout=timeout) + self.wait(timeout=timeout, interval=interval) return self.process def save(self, metadata: Optional[Dict[str, Any]] = None) -> None: @@ -230,7 +233,7 @@ def get_error_handlers(self) -> Dict[str, Any]: task["exit_codes"] = exit_codes return error_handlers - def wait(self, timeout: int = 50, tasks: dict = None) -> None: + def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 1) -> None: """ Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout. Args: @@ -257,7 +260,7 @@ def wait(self, timeout: int = 50, tasks: dict = None) -> None: finished = all(states) else: finished = self.state in terminating_states - time.sleep(0.5) + time.sleep(interval) if time.time() - start > timeout: break @@ -386,21 +389,18 @@ def show(self) -> None: print(tabulate(table, headers=["Name", "PK", "State"])) print("-" * 80) - def pause(self) -> None: - """Pause the workgraph.""" - # from aiida.engine.processes import control - # try: - # control.pause_processes([self.process]) - import os - - os.system("verdi process pause {}".format(self.process.pk)) + # def pause(self) -> None: + # """Pause the workgraph.""" + # from aiida.engine.processes import control + # try: + # control.pause_processes([self.process]) + # except Exception as e: + # print(f"Pause process failed: {e}") def pause_tasks(self, tasks: List[str]) -> None: """Pause the given tasks.""" from aiida_workgraph.utils.control import pause_tasks - self.update() - if self.process is None: for name in tasks: self.tasks[name].action = "PAUSE" diff --git a/pyproject.toml b/pyproject.toml index 18e33905..b586e06d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,7 +107,6 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" "workgraph.aiida_node" = 
"aiida_workgraph.tasks.builtins:AiiDANode" "workgraph.aiida_code" = "aiida_workgraph.tasks.builtins:AiiDACode" "workgraph.test_add" = "aiida_workgraph.tasks.test:TestAdd" -"workgraph.test_greater" = "aiida_workgraph.tasks.test:TestGreater" "workgraph.test_sum_diff" = "aiida_workgraph.tasks.test:TestSumDiff" "workgraph.test_arithmetic_multiply_add" = "aiida_workgraph.tasks.test:TestArithmeticMultiplyAdd" "workgraph.pythonjob" = "aiida_workgraph.tasks.pythonjob:PythonJob" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index 36b1fe61..31b2b776 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ from aiida.orm import Int, StructureData from aiida.calculations.arithmetic.add import ArithmeticAddCalculation from typing import Callable, Any, Union +from aiida.orm import WorkflowNode import time import os @@ -29,7 +30,9 @@ def fixture_localhost(aiida_localhost): def add_code(fixture_localhost): from aiida.orm import InstalledCode - code = InstalledCode(computer=fixture_localhost, filepath_executable="/bin/bash") + code = InstalledCode( + label="add", computer=fixture_localhost, filepath_executable="/bin/bash" + ) code.store() return code @@ -233,3 +236,15 @@ def wg_engine(decorated_add, add_code) -> WorkGraph: wg.add_link(add2.outputs["sum"], add5.inputs["x"]) wg.add_link(add4.outputs["sum"], add5.inputs["y"]) return wg + + +@pytest.fixture +def finished_process_node(): + """Return a finished process node.""" + + node = WorkflowNode() + node.set_process_state("finished") + node.set_exit_status(0) + node.seal() + node.store() + return node diff --git a/tests/test_action.py b/tests/test_action.py new file mode 100644 index 00000000..82ad8954 --- /dev/null +++ b/tests/test_action.py @@ -0,0 +1,59 @@ +import pytest +import time + + +@pytest.mark.skip(reason="PAUSED state is wrong for the moment.") +def test_pause_play_workgraph(wg_engine): + wg = wg_engine + 
wg.name = "test_pause_play_workgraph" + wg.submit() + time.sleep(5) + wg.pause() + wg.update() + assert wg.process.process_state.value.upper() == "PAUSED" + + +# @pytest.mark.skip(reason="pause task is not stable for the moment.") +@pytest.mark.usefixtures("started_daemon_client") +def test_pause_play_task(wg_calcjob): + wg = wg_calcjob + wg.name = "test_pause_play_task" + # pause add1 before submit + wg.pause_tasks(["add1"]) + wg.submit() + # wait for the workgraph to launch add1 + wg.wait(tasks={"add1": ["CREATED"]}, timeout=40, interval=5) + assert wg.tasks["add1"].node.process_state.value.upper() == "CREATED" + assert wg.tasks["add1"].node.process_status == "Paused through WorkGraph" + # pause add2 after submit + wg.pause_tasks(["add2"]) + wg.play_tasks(["add1"]) + # wait for the workgraph to launch add2 + wg.wait(tasks={"add2": ["CREATED"]}, timeout=40, interval=5) + assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED" + assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph" + # I disabled the following lines because the test is not stable + # Seems the daemon is not responding to the play signal + wg.play_tasks(["add2"]) + wg.wait(interval=5) + assert wg.tasks["add2"].outputs["sum"].value == 9 + + +def test_pause_play_error_handler(wg_calcjob, finished_process_node): + wg = wg_calcjob + wg.name = "test_pause_play_error_handler" + wg.process = finished_process_node + try: + wg.pause_tasks(["add1"]) + except Exception as e: + assert "WorkGraph is finished. Cannot pause tasks." in str(e) + + try: + wg.play_tasks(["add1"]) + except Exception as e: + assert "WorkGraph is finished. Cannot play tasks." in str(e) + + try: + wg.kill_tasks(["add2"]) + except Exception as e: + assert "WorkGraph is finished. Cannot kill tasks." 
in str(e) diff --git a/tests/test_engine.py b/tests/test_engine.py index 2b322d7b..04c4fea5 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -32,6 +32,7 @@ def test_reset_node(wg_engine: WorkGraph) -> None: assert len(wg.process.base.extras.get("_workgraph_queue")) == 1 +@pytest.mark.usefixtures("started_daemon_client") def test_max_number_jobs(add_code) -> None: from aiida_workgraph import WorkGraph from aiida.orm import Int @@ -46,6 +47,7 @@ def test_max_number_jobs(add_code) -> None: ) # Set the maximum number of running jobs inside the WorkGraph wg.max_number_jobs = 2 - wg.submit(wait=True, timeout=100) + wg.submit(wait=True, timeout=40) report = get_workchain_report(wg.process, "REPORT") assert "tasks ready to run: add2" in report + wg.tasks["add2"].outputs["sum"].value == 2 diff --git a/tests/test_socket.py b/tests/test_socket.py index fea72d4d..264ecc67 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -36,6 +36,22 @@ def add(x: data_type): add_task.set({"x": "{{variable}}"}) +def test_vector_socket() -> None: + """Test the vector data type.""" + from aiida_workgraph import Task + + t = Task() + t.inputs.new( + "workgraph.aiida_int_vector", + "vector2d", + property_data={"size": 2, "default": [1, 2]}, + ) + try: + t.inputs["vector2d"].value = [1, 2, 3] + except Exception as e: + assert "Invalid size: Expected 2, got 3 instead." 
in str(e) + + def test_aiida_data_socket() -> None: """Test the mapping of data types to socket types.""" from aiida import orm, load_profile diff --git a/tests/test_tasks.py b/tests/test_tasks.py index d2c5b916..eafc7145 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,9 +1,25 @@ import pytest -from aiida_workgraph import WorkGraph +from aiida_workgraph import WorkGraph, task from typing import Callable from aiida.cmdline.utils.common import get_workchain_report +def test_normal_task(decorated_add) -> None: + """Test a normal task.""" + + @task(outputs=[{"name": "sum"}, {"name": "diff"}]) + def sum_diff(x, y): + return x + y, x - y + + wg = WorkGraph("test_normal_task") + task1 = wg.add_task(sum_diff, name="sum_diff", x=2, y=3) + task2 = wg.add_task( + decorated_add, name="add", x=task1.outputs["sum"], y=task1.outputs["diff"] + ) + wg.run() + assert task2.outputs["result"].value == 4 + + def test_task_collection(decorated_add: Callable) -> None: """Test the TaskCollection class. 
Since waiting_on and children are TaskCollection instances, we test the waiting_on.""" diff --git a/tests/test_utils.py b/tests/test_utils.py index 8ecaaab6..20eee143 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,5 @@ import pytest - +from aiida import orm from aiida_workgraph.utils import validate_task_inout @@ -56,3 +56,56 @@ def test_validate_task_inout_dict_with_extra_keys(): ] result = validate_task_inout(input_list, "inputs") assert result == input_list + + +def test_get_or_create_code(fixture_localhost): + from aiida_workgraph.utils import get_or_create_code + from aiida.orm import Code + + # create a new code + code1 = get_or_create_code( + computer="localhost", + code_label="test_code", + code_path="/bin/bash", + prepend_text='echo "Hello, World!"', + ) + assert isinstance(code1, Code) + # use already created code + code2 = get_or_create_code( + computer="localhost", + code_label="test_code", + code_path="/bin/bash", + prepend_text='echo "Hello, World!"', + ) + assert code1.uuid == code2.uuid + + +def test_get_parent_workgraphs(): + from aiida.common.links import LinkType + from aiida_workgraph.utils import get_parent_workgraphs + + wn1 = orm.WorkflowNode() + wn2 = orm.WorkflowNode() + wn3 = orm.WorkflowNode() + wn3.base.links.add_incoming(wn2, link_type=LinkType.CALL_WORK, link_label="link") + wn2.base.links.add_incoming(wn1, link_type=LinkType.CALL_WORK, link_label="link") + wn1.store() + wn2.store() + wn3.store() + + parent_workgraphs = get_parent_workgraphs(wn3.pk) + assert len(parent_workgraphs) == 3 + + +def test_generate_node_graph(): + from aiida_workgraph.utils import generate_node_graph + from IPython.display import IFrame + import os + + wn1 = orm.WorkflowNode() + wn1.store() + + graph = generate_node_graph(wn1.pk) + assert isinstance(graph, IFrame) + # check file html/node_graph_{pk}.html is created + assert os.path.isfile(f"html/node_graph_{wn1.pk}.html") diff --git a/tests/test_workgraph.py b/tests/test_workgraph.py 
index 3a1de9b3..1fa1366b 100644 --- a/tests/test_workgraph.py +++ b/tests/test_workgraph.py @@ -40,16 +40,31 @@ def test_save_load(wg_calcfunction): assert len(wg.tasks) == len(wg2.tasks) -# skip this test -@pytest.mark.skip(reason="PAUSED state is wrong for the moment.") -def test_pause(wg_engine): - wg = wg_engine - wg.name = "test_pause" - wg.submit() - time.sleep(5) - wg.pause() - wg.update() - assert wg.process.process_state.value.upper() == "PAUSED" +def test_organize_nested_inputs(): + """Merge sub properties to the root properties.""" + from .utils.test_workchain import WorkChainWithNestNamespace + + wg = WorkGraph("test_organize_nested_inputs") + task1 = wg.add_task(WorkChainWithNestNamespace, name="task1") + task1.set( + { + "add": {"x": "1"}, + "add.metadata": { + "call_link_label": "nest", + "options": {"resources": {"num_cpus": 1}}, + }, + "add.metadata.options": {"resources": {"num_machines": 1}}, + } + ) + inputs = wg.prepare_inputs() + data = { + "metadata": { + "call_link_label": "nest", + "options": {"resources": {"num_cpus": 1, "num_machines": 1}}, + }, + "x": "1", + } + assert inputs["wg"]["tasks"]["task1"]["inputs"]["add"]["property"]["value"] == data @pytest.mark.usefixtures("started_daemon_client") @@ -114,44 +129,6 @@ def test_extend_workgraph(decorated_add_multiply_group): assert wg.tasks["group_multiply1"].node.outputs.result == 45 -@pytest.mark.usefixtures("started_daemon_client") -def test_pause_task_before_submit(wg_calcjob): - wg = wg_calcjob - wg.name = "test_pause_task" - wg.pause_tasks(["add2"]) - wg.submit() - # wait for the workgraph to launch add2 - wg.wait(tasks={"add2": ["CREATED"]}, timeout=20) - assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED" - assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph" - # I disabled the following lines because the test is not stable - # Seems the daemon is not responding to the play signal - # This should be a problem of AiiDA test fixtures - # 
wg.play_tasks(["add2"]) - # wg.wait(tasks={"add2": ["FINISHED"]}) - # assert wg.tasks["add2"].outputs["sum"].value == 9 - - -@pytest.mark.skip(reason="pause task is not stable for the moment.") -def test_pause_task_after_submit(wg_calcjob): - wg = wg_calcjob - wg.tasks["add1"].set({"metadata.options.sleep": 5}) - wg.name = "test_pause_task" - wg.submit() - # wait for the workgraph to launch add1 - wg.wait(tasks={"add1": ["CREATED", "WAITING", "RUNNING", "FINISHED"]}, timeout=20) - wg.pause_tasks(["add2"]) - # wait for the workgraph to launch add2 - wg.wait(tasks={"add2": ["CREATED"]}, timeout=20) - assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED" - assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph" - # I disabled the following lines because the test is not stable - # Seems the daemon is not responding to the play signal - # wg.play_tasks(["add2"]) - # wg.wait(tasks={"add2": ["FINISHED"]}) - # assert wg.tasks["add2"].outputs["sum"].value == 9 - - def test_workgraph_group_outputs(decorated_add): wg = WorkGraph("test_workgraph_group_outputs") wg.add_task(decorated_add, "add1", x=2, y=3) diff --git a/tests/utils/test_workchain.py b/tests/utils/test_workchain.py new file mode 100644 index 00000000..80341bb2 --- /dev/null +++ b/tests/utils/test_workchain.py @@ -0,0 +1,55 @@ +from aiida.engine import ToContext, WorkChain +from aiida.workflows.arithmetic.multiply_add import MultiplyAddWorkChain +from aiida.calculations.arithmetic.add import ArithmeticAddCalculation +from aiida.common import AttributeDict +from aiida.orm import Int + + +class WorkChainWithNestNamespace(WorkChain): + """WorkChain that adds two numbers, then runs a multiply-add workflow and validates the result.""" + + @classmethod + def define(cls, spec): + """Specify inputs and outputs.""" + super().define(spec) + spec.expose_inputs( + ArithmeticAddCalculation, + namespace="add", + ) + spec.expose_inputs(MultiplyAddWorkChain, namespace="multiply_add") + spec.outline( + cls.add, + cls.multiply_add, + cls.validate_result, +
cls.result, + ) + spec.output("result", valid_type=Int) + spec.expose_outputs(MultiplyAddWorkChain, namespace="multiply_add") + spec.exit_code( + 400, "ERROR_NEGATIVE_NUMBER", message="The result is a negative number." + ) + + def add(self): + """Add two numbers using the `ArithmeticAddCalculation` calculation job plugin.""" + inputs = AttributeDict(self.exposed_inputs(ArithmeticAddCalculation, "add")) + future = self.submit(ArithmeticAddCalculation, **inputs) + self.report(f"Submitted the `ArithmeticAddCalculation`: {future}") + return ToContext(addition=future) + + def multiply_add(self): + """Multiply and add two numbers using the `MultiplyAddWorkChain` workchain.""" + inputs = self.exposed_inputs(MultiplyAddWorkChain, "multiply_add") + inputs["z"] = self.ctx.addition.outputs.sum + future = self.submit(MultiplyAddWorkChain, **inputs) + self.report(f"Submitted the `MultiplyAddWorkChain`: {future}") + return ToContext(multiply_add=future) + + def validate_result(self): + """Make sure the result is not negative.""" + result = self.ctx.addition.outputs.sum + if result.value < 0: + return self.exit_codes.ERROR_NEGATIVE_NUMBER + + def result(self): + """Add the result to the outputs.""" + self.out("result", self.ctx.addition.outputs.sum) diff --git a/tests/widget/test_widget.py b/tests/widget/test_widget.py new file mode 100644 index 00000000..c5cc3028 --- /dev/null +++ b/tests/widget/test_widget.py @@ -0,0 +1,32 @@ +from IPython.display import IFrame + + +def test_workgraph_widget(wg_calcfunction): + """Test rendering a WorkGraph with the widget, including waiting_on links and HTML export.""" + + wg = wg_calcfunction + wg.name = "test_workgraph_widget" + wg.tasks["sumdiff2"].waiting_on.add(wg.tasks["sumdiff2"]) + wg._widget.from_workgraph(wg) + assert len(wg._widget.value["nodes"]) == 2 + # the waiting_on is also transformed to links + assert len(wg._widget.value["links"]) == 2 + # to_html + data = wg.to_html() + assert isinstance(data, IFrame) + + +def test_workgraph_task(wg_calcfunction): + """Test rendering a single task with the widget and its HTML export.""" + wg =
wg_calcfunction + wg.name = "test_workgraph_task" + wg.tasks["sumdiff2"]._widget.from_node(wg.tasks["sumdiff2"]) + print(wg.tasks["sumdiff2"]._widget.value) + assert len(wg.tasks["sumdiff2"]._widget.value["nodes"]) == 1 + assert len( + wg.tasks["sumdiff2"]._widget.value["nodes"]["sumdiff2"]["inputs"] + ) == len(wg.tasks["sumdiff2"].inputs) + assert len(wg.tasks["sumdiff2"]._widget.value["links"]) == 0 + # to html + data = wg.tasks["sumdiff2"].to_html() + assert isinstance(data, IFrame)