From 8047817f7cff24e3cd1632482f3c7d16d2e809f2 Mon Sep 17 00:00:00 2001
From: Xing Wang
Date: Thu, 12 Sep 2024 16:18:56 +0200
Subject: [PATCH] Refactor task's data, bump `node-graph` to 0.0.19 (#304)

This PR bumps `node-graph` to 0.0.19 so that we no longer need to pickle the
`node_class`:

1) convert a task's inputs and outputs from lists to dicts, keyed by socket name
2) keep each property inside its input socket
3) reference the `node_class` and the executor by `module` and `name`
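For illustration, here is a rough before/after sketch of a task's serialized
data under this refactor (a hypothetical `add` calcfunction task; the field
values are illustrative, not exhaustive):

    # before (node-graph 0.0.18): list-based sockets, pickled node class
    tdata = {
        "task_type": "CALCFUNCTION",
        "node_class": Task,  # the class object itself had to be pickled
        "inputs": [{"identifier": "workgraph.any", "name": "x"}],
        "properties": {"x": {"identifier": "workgraph.any", "value": 1}},
    }

    # after (node-graph 0.0.19): dict-based sockets keyed by socket name,
    # the property lives inside its input socket, and the node class is
    # referenced by module and name, so nothing needs to be pickled
    tdata = {
        "metadata": {
            "task_type": "CALCFUNCTION",
            "node_class": {"module": "aiida_workgraph.task", "name": "Task"},
        },
        "inputs": {
            "x": {
                "identifier": "workgraph.any",
                "name": "x",
                "list_index": 1,  # preserves the original socket order
                "property": {"identifier": "workgraph.any", "value": 1},
            },
        },
    }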
---
 aiida_workgraph/decorator.py                  | 147 ++++++++++--------
 aiida_workgraph/engine/utils.py               |  33 +++-
 aiida_workgraph/engine/workgraph.py           |  94 ++++++-----
 aiida_workgraph/executors/monitors.py         |   7 +-
 aiida_workgraph/property.py                   |   4 +-
 aiida_workgraph/socket.py                     |   4 +-
 aiida_workgraph/task.py                       |   4 +-
 aiida_workgraph/tasks/builtins.py             |  18 +--
 aiida_workgraph/tasks/monitors.py             |   9 +-
 aiida_workgraph/tasks/pythonjob.py            |   2 +-
 aiida_workgraph/tasks/test.py                 |   6 +-
 aiida_workgraph/utils/__init__.py             | 132 +++++++++++-----
 aiida_workgraph/utils/analysis.py             |  25 +--
 aiida_workgraph/web/backend/app/utils.py      |   2 +-
 aiida_workgraph/widget/src/widget/__init__.py |   2 +-
 aiida_workgraph/workgraph.py                  |   7 +-
 docs/gallery/concept/autogen/task.py          |   2 +-
 pyproject.toml                                |   2 +-
 tests/datas/test_calcfunction.yaml            |  31 ++--
 tests/test_awaitable_task.py                  |   4 +-
 tests/test_decorator.py                       |   4 +-
 tests/test_link.py                            |  20 +--
 tests/test_python.py                          |   2 +
 tests/test_workgraph.py                       |  16 +-
 24 files changed, 341 insertions(+), 236 deletions(-)

diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py
index 0af0173b..a11304e2 100644
--- a/aiida_workgraph/decorator.py
+++ b/aiida_workgraph/decorator.py
@@ -36,7 +36,7 @@ def create_task(tdata):
     from node_graph.decorator import create_node
 
     tdata["type_mapping"] = type_mapping
-    tdata["node_type"] = tdata.pop("task_type")
+    tdata["metadata"]["node_type"] = tdata["metadata"].pop("task_type")
     return create_node(tdata)
@@ -132,10 +132,10 @@ def build_task(
         return build_task_from_workgraph(executor)
     elif isinstance(executor, str):
         (
-            path,
+            module,
             executor_name,
         ) = executor.rsplit(".", 1)
-        executor, _ = get_executor({"path": path, "name": executor_name})
+        executor, _ = get_executor({"module": module, "name": executor_name})
     if callable(executor):
         return build_task_from_callable(executor, inputs=inputs, outputs=outputs)
@@ -162,24 +162,26 @@ def build_task_from_callable(
         and issubclass(executor, Task)
     ):
         return executor
-    tdata = {}
+    tdata = {"metadata": {}}
     if inspect.isfunction(executor):
         # calcfunction and workfunction
         if getattr(executor, "node_class", False):
-            tdata["task_type"] = task_types.get(executor.node_class, "NORMAL")
+            tdata["metadata"]["task_type"] = task_types.get(
+                executor.node_class, "NORMAL"
+            )
             tdata["executor"] = executor
             return build_task_from_AiiDA(tdata, inputs=inputs, outputs=outputs)[0]
         else:
-            tdata["task_type"] = "NORMAL"
+            tdata["metadata"]["task_type"] = "NORMAL"
             tdata["executor"] = executor
             return build_task_from_function(executor, inputs=inputs, outputs=outputs)
     else:
         if issubclass(executor, CalcJob):
-            tdata["task_type"] = "CALCJOB"
+            tdata["metadata"]["task_type"] = "CALCJOB"
            tdata["executor"] = executor
             return build_task_from_AiiDA(tdata, inputs=inputs, outputs=outputs)[0]
         elif issubclass(executor, WorkChain):
-            tdata["task_type"] = "WORKCHAIN"
+            tdata["metadata"]["task_type"] = "WORKCHAIN"
             tdata["executor"] = executor
             return build_task_from_AiiDA(tdata, inputs=inputs, outputs=outputs)[0]
     raise ValueError("The executor is not supported.")
@@ -203,9 +205,8 @@ def build_task_from_AiiDA(
 ) -> Task:
     """Register a task from an AiiDA component. For example: CalcJob, WorkChain, CalcFunction, WorkFunction."""
-    from aiida_workgraph.task import Task
-
-    # print(executor)
+    tdata.setdefault("metadata", {})
     inputs = [] if inputs is None else inputs
     outputs = [] if outputs is None else outputs
     executor = tdata["executor"]
@@ -250,10 +251,10 @@ def build_task_from_AiiDA(
     tdata["identifier"] = tdata.pop("identifier", tdata["executor"].__name__)
     tdata["executor"] = {
         "executor": pickle.dumps(executor),
-        "type": tdata["task_type"],
+        "type": tdata["metadata"]["task_type"],
         "is_pickle": True,
     }
-    if tdata["task_type"].upper() in ["CALCFUNCTION", "WORKFUNCTION"]:
+    if tdata["metadata"]["task_type"].upper() in ["CALCFUNCTION", "WORKFUNCTION"]:
         outputs = (
             [{"identifier": "workgraph.any", "name": "result"}]
             if not outputs
@@ -261,13 +262,13 @@ def build_task_from_AiiDA(
         )
         # build executor from the function
         tdata["executor"] = PickledFunction.build_executor(executor)
-        # tdata["executor"]["type"] = tdata["task_type"]
+        # tdata["executor"]["type"] = tdata["metadata"]["task_type"]
     # print("kwargs: ", kwargs)
     # add built-in sockets
     outputs.append({"identifier": "workgraph.any", "name": "_outputs"})
     outputs.append({"identifier": "workgraph.any", "name": "_wait"})
     inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6})
-    tdata["node_class"] = Task
+    tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"}
     tdata["args"] = args
     tdata["kwargs"] = kwargs
     tdata["inputs"] = inputs
@@ -280,7 +281,6 @@ def build_task_from_AiiDA(
 def build_pythonjob_task(func: Callable) -> Task:
     """Build PythonJob task from function."""
     from aiida_workgraph.calculations.python import PythonJob
-    from aiida_workgraph.tasks.pythonjob import PythonJob as PythonJobTask
     from copy import deepcopy
 
     # if the function is not a task, build a task from the function
         raise ValueError(
             "GraphBuilder task cannot be run remotely. Please remove 'PythonJob'."
) - tdata = {"executor": PythonJob, "task_type": "CALCJOB"} + tdata = { + "metadata": {"task_type": "PYTHONJOB"}, + "executor": PythonJob, + } _, tdata_py = build_task_from_AiiDA(tdata) tdata = deepcopy(func.tdata) # merge the inputs and outputs from the PythonJob task to the function task # skip the already existed inputs and outputs - inputs = tdata["inputs"] - inputs.extend( - [ - {"identifier": "workgraph.string", "name": "computer"}, - {"identifier": "workgraph.string", "name": "code_label"}, - {"identifier": "workgraph.string", "name": "code_path"}, - {"identifier": "workgraph.string", "name": "prepend_text"}, - ] - ) - outputs = tdata["outputs"] - for input in tdata_py["inputs"]: - if input not in inputs: - inputs.append(input) - for output in tdata_py["outputs"]: - if output not in outputs: - outputs.append(output) - outputs.append({"identifier": "workgraph.any", "name": "exit_code"}) + for input in [ + {"identifier": "workgraph.string", "name": "computer"}, + {"identifier": "workgraph.string", "name": "code_label"}, + {"identifier": "workgraph.string", "name": "code_path"}, + {"identifier": "workgraph.string", "name": "prepend_text"}, + ]: + input["list_index"] = len(tdata["inputs"]) + 1 + tdata["inputs"][input["name"]] = input + for name, input in tdata_py["inputs"].items(): + if name not in tdata["inputs"]: + input["list_index"] = len(tdata["inputs"]) + 1 + tdata["inputs"][name] = input + for name, output in tdata_py["outputs"].items(): + if name not in tdata["outputs"]: + output["list_index"] = len(tdata["outputs"]) + 1 + tdata["outputs"][name] = output + for output in [{"identifier": "workgraph.any", "name": "exit_code"}]: + output["list_index"] = len(tdata["outputs"]) + 1 + tdata["outputs"][output["name"]] = output # change "copy_files" link_limit to 1e6 - for input in inputs: - if input["name"] == "copy_files": - input["link_limit"] = 1e6 + tdata["inputs"]["copy_files"]["link_limit"] = 1e6 # append the kwargs of the PythonJob task to the function task kwargs = tdata["kwargs"] kwargs.extend(["computer", "code_label", "code_path", "prepend_text"]) kwargs.extend(tdata_py["kwargs"]) - tdata["inputs"] = inputs - tdata["outputs"] = outputs tdata["kwargs"] = kwargs - tdata["task_type"] = "PYTHONJOB" + tdata["metadata"]["task_type"] = "PYTHONJOB" tdata["identifier"] = "workgraph.pythonjob" - tdata["node_class"] = PythonJobTask + tdata["metadata"]["node_class"] = { + "module": "aiida_workgraph.tasks.pythonjob", + "name": "PythonJob", + } task = create_task(tdata) task.is_aiida_component = True return task, tdata @@ -339,7 +343,10 @@ def build_shelljob_task( from aiida_shell.parsers.shell import ShellParser from node_graph.socket import NodeSocket - tdata = {"executor": ShellJob, "task_type": "SHELLJOB"} + tdata = { + "metadata": {"task_type": "SHELLJOB"}, + "executor": ShellJob, + } _, tdata = build_task_from_AiiDA(tdata) # create input sockets for the nodes, if it is linked other sockets links = {} @@ -354,16 +361,17 @@ def build_shelljob_task( # Output socket itself is not a value, so we remove the key from the nodes nodes.pop(key) for input in inputs: - if input not in tdata["inputs"]: - tdata["inputs"].append(input) + if input["name"] not in tdata["inputs"]: + input["list_index"] = len(tdata["inputs"]) + 1 + tdata["inputs"][input["name"]] = input tdata["kwargs"].append(input["name"]) # Extend the outputs - tdata["outputs"].extend( - [ - {"identifier": "workgraph.any", "name": "stdout"}, - {"identifier": "workgraph.any", "name": "stderr"}, - ] - ) + for output in [ + 
{"identifier": "workgraph.any", "name": "stdout"}, + {"identifier": "workgraph.any", "name": "stderr"}, + ]: + output["list_index"] = len(tdata["outputs"]) + 1 + tdata["outputs"][output["name"]] = output outputs = [] if outputs is None else outputs parser_outputs = [] if parser_outputs is None else parser_outputs outputs = [ @@ -373,18 +381,19 @@ def build_shelljob_task( outputs.extend(parser_outputs) # add user defined outputs for output in outputs: - if output not in tdata["outputs"]: - tdata["outputs"].append(output) + if output["name"] not in tdata["outputs"]: + output["list_index"] = len(tdata["outputs"]) + 1 + tdata["outputs"][output["name"]] = output # tdata["identifier"] = "ShellJob" - tdata["inputs"].extend( - [ - {"identifier": "workgraph.any", "name": "command"}, - {"identifier": "workgraph.any", "name": "resolve_command"}, - ] - ) + for input in [ + {"identifier": "workgraph.any", "name": "command"}, + {"identifier": "workgraph.any", "name": "resolve_command"}, + ]: + input["list_index"] = len(tdata["inputs"]) + 1 + tdata["inputs"][input["name"]] = input tdata["kwargs"].extend(["command", "resolve_command"]) - tdata["task_type"] = "SHELLJOB" + tdata["metadata"]["task_type"] = "SHELLJOB" task = create_task(tdata) task.is_aiida_component = True return task, tdata, links @@ -392,10 +401,9 @@ def build_shelljob_task( def build_task_from_workgraph(wg: any) -> Task: """Build task from workgraph.""" - from aiida_workgraph.task import Task from aiida.orm.utils.serialize import serialize - tdata = {"task_type": "workgraph"} + tdata = {"metadata": {"task_type": "workgraph"}} inputs = [] outputs = [] group_outputs = [] @@ -407,7 +415,7 @@ def build_task_from_workgraph(wg: any) -> Task: if socket.name == "_wait": continue inputs.append( - {"identifier": "workgraph.any", "name": f"{task.name}.{socket.name}"} + {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"} ) # outputs outputs.append({"identifier": "workgraph.any", "name": f"{task.name}"}) @@ -415,7 +423,7 @@ def build_task_from_workgraph(wg: any) -> Task: if socket.name in ["_wait", "_outputs"]: continue outputs.append( - {"identifier": "workgraph.any", "name": f"{task.name}.{socket.name}"} + {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"} ) group_outputs.append( { @@ -428,16 +436,16 @@ def build_task_from_workgraph(wg: any) -> Task: outputs.append({"identifier": "workgraph.any", "name": "_outputs"}) outputs.append({"identifier": "workgraph.any", "name": "_wait"}) inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6}) - tdata["node_class"] = Task + tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"} tdata["kwargs"] = kwargs tdata["inputs"] = inputs tdata["outputs"] = outputs tdata["identifier"] = wg.name executor = { - "path": "aiida_workgraph.engine.workgraph", + "module": "aiida_workgraph.engine.workgraph", "name": "WorkGraphEngine", "wgdata": serialize(wg.to_dict(store_nodes=True)), - "type": tdata["task_type"], + "type": tdata["metadata"]["task_type"], "is_pickle": False, } tdata["executor"] = executor @@ -490,7 +498,6 @@ def generate_tdata( ) -> Dict[str, Any]: """Generate task data for creating a task.""" from node_graph.decorator import generate_input_sockets - from aiida_workgraph.task import Task args, kwargs, var_args, var_kwargs, _inputs = generate_input_sockets( func, inputs, properties, type_mapping=type_mapping @@ -501,17 +508,19 @@ def generate_tdata( task_outputs.append({"identifier": "workgraph.any", "name": 
"_wait"}) task_outputs.append({"identifier": "workgraph.any", "name": "_outputs"}) tdata = { - "node_class": Task, "identifier": identifier, "args": args, "kwargs": kwargs, "var_args": var_args, "var_kwargs": var_kwargs, - "task_type": task_type, + "metadata": { + "task_type": task_type, + "catalog": catalog, + "node_class": {"module": "aiida_workgraph.task", "name": "Task"}, + }, "properties": properties, "inputs": _inputs, "outputs": task_outputs, - "catalog": catalog, } tdata["executor"] = PickledFunction.build_executor(func) if additional_data: diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py index 82332a78..d2c60723 100644 --- a/aiida_workgraph/engine/utils.py +++ b/aiida_workgraph/engine/utils.py @@ -5,7 +5,7 @@ def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple: """Prepare the inputs for WorkGraph task""" - from aiida_workgraph.utils import merge_properties + from aiida_workgraph.utils import merge_properties, serialize_properties from aiida.orm.utils.serialize import deserialize_unsafe wgdata = deserialize_unsafe(task["executor"]["wgdata"]) @@ -16,9 +16,12 @@ def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple: # because kwargs is updated using update_nested_dict_with_special_keys # which means the data is grouped by the task name for socket_name, value in data.items(): - wgdata["tasks"][task_name]["properties"][socket_name]["value"] = value + wgdata["tasks"][task_name]["inputs"][socket_name]["property"][ + "value" + ] = value # merge the properties merge_properties(wgdata) + serialize_properties(wgdata) metadata = {"call_link_label": task["name"]} inputs = {"wg": wgdata, "metadata": metadata} return inputs, wgdata @@ -31,14 +34,22 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict: # get the names kwargs for the PythonJob, which are the inputs before _wait function_kwargs = {} - for input in task["inputs"]: - if input["name"] == "_wait": + # TODO better way to find the function_kwargs + input_names = [ + name + for name, _ in sorted( + ((name, input["list_index"]) for name, input in task["inputs"].items()), + key=lambda x: x[1], + ) + ] + for name in input_names: + if name == "_wait": break - function_kwargs[input["name"]] = kwargs.pop(input["name"], None) + function_kwargs[name] = kwargs.pop(name, None) # if the var_kwargs is not None, we need to pop the var_kwargs from the kwargs # then update the function_kwargs if var_kwargs is not None - if task["metadata"]["var_kwargs"] is not None: - function_kwargs.pop(task["metadata"]["var_kwargs"], None) + if task["var_kwargs"] is not None: + function_kwargs.pop(task["var_kwargs"], None) if var_kwargs: # var_kwargs can be AttributeDict if it get data from the previous task output if isinstance(var_kwargs, (dict, AttributeDict)): @@ -88,7 +99,13 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict: + task["executor"]["function_source_code_without_decorator"] ) # outputs - function_outputs = task["outputs"] + function_outputs = [ + output + for output, _ in sorted( + ((output, output["list_index"]) for output in task["outputs"].values()), + key=lambda x: x[1], + ) + ] # serialize the kwargs into AiiDA Data function_kwargs = serialize_to_aiida_nodes(function_kwargs) # transfer the args to kwargs diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index ba0afcc6..11c85fb4 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -459,7 +459,6 @@ def 
         super().on_create()
         wgdata = self.inputs.wg._dict
-        # print("wgdata: ", wgdata)
         restart_process = (
             orm.load_node(wgdata["restart_process"].value)
             if wgdata.get("restart_process")
@@ -515,7 +514,10 @@ def read_wgdata_from_base(self) -> t.Dict[str, t.Any]:
         wgdata = self.node.base.extras.get("_workgraph")
         for name, task in wgdata["tasks"].items():
             wgdata["tasks"][name] = deserialize_unsafe(task)
-            for _, prop in wgdata["tasks"][name]["properties"].items():
+            for _, input in wgdata["tasks"][name]["inputs"].items():
+                if input["property"] is None:
+                    continue
+                prop = input["property"]
                 if isinstance(prop["value"], PickledLocalFunction):
                     prop["value"] = prop["value"].value
         wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"])
@@ -546,6 +548,7 @@ def update_task(self, task: Task):
         This is used in error handlers to update the task parameters."""
         tdata = task.to_dict()
         self.ctx._tasks[task.name]["properties"] = tdata["properties"]
+        self.ctx._tasks[task.name]["inputs"] = tdata["inputs"]
         self.reset_task(task.name)
 
     def get_task_state_info(self, name: str, key: str) -> str:
@@ -736,7 +739,13 @@ def update_task_state(self, name: str) -> None:
                 self.report(f"Task: {name} failed.")
                 self.run_error_handlers(name)
         elif isinstance(node, orm.Data):
+            #
+            output_name = [
+                output_name
+                for output_name in list(task["outputs"].keys())
+                if output_name not in ["_wait", "_outputs"]
+            ][0]
-            task["results"] = {task["outputs"][0]["name"]: node}
+            task["results"] = {output_name: node}
             self.set_task_state_info(task["name"], "state", "FINISHED")
             self.task_set_context(name)
             self.report(f"Task: {name} finished.")
@@ -749,16 +758,24 @@ def set_normal_task_results(self, name, results):
         """Set the results of a normal task.
         A normal task is created by decorating a function with @task().
         """
+        from aiida_workgraph.utils import get_sorted_names
+
         task = self.ctx._tasks[name]
         if isinstance(results, tuple):
             if len(task["outputs"]) != len(results):
                 return self.exit_codes.OUTPUS_NOT_MATCH_RESULTS
-            for i in range(len(task["outputs"])):
-                task["results"][task["outputs"][i]["name"]] = results[i]
+            output_names = get_sorted_names(task["outputs"])
+            for i, output_name in enumerate(output_names):
+                task["results"][output_name] = results[i]
         elif isinstance(results, dict):
             task["results"] = results
         else:
-            task["results"][task["outputs"][0]["name"]] = results
+            output_name = [
+                output_name
+                for output_name in list(task["outputs"].keys())
+                if output_name not in ["_wait", "_outputs"]
+            ][0]
+            task["results"][output_name] = results
         self.task_set_context(name)
         self.set_task_state_info(name, "state", "FINISHED")
         self.report(f"Task: {name} finished.")
@@ -785,10 +802,8 @@ def update_while_task_state(self, name: str) -> None:
             f"While Task {name}: this iteration finished. Try to reset for the next iteration."
         )
         # reset the condition tasks
-        for input in self.ctx._tasks[name]["inputs"]:
-            if input["name"].upper() == "CONDITIONS":
-                for link in input["links"]:
-                    self.reset_task(link["from_node"], recursive=False)
+        for link in self.ctx._tasks[name]["inputs"]["conditions"]["links"]:
+            self.reset_task(link["from_node"], recursive=False)
         # reset the task and all its children, so that the task can run again
         # do not reset the execution count
         self.reset_task(name, reset_execution_count=False)
@@ -806,7 +821,7 @@ def should_run_while_task(self, name: str) -> tuple[bool, t.Any]:
         # check the conditions of the while task
         not_excess_max_iterations = (
             self.ctx._tasks[name]["execution_count"]
-            < self.ctx._tasks[name]["properties"]["max_iterations"]["value"]
+            < self.ctx._tasks[name]["inputs"]["max_iterations"]["property"]["value"]
         )
         conditions = [not_excess_max_iterations]
         _, kwargs, _, _, _ = self.get_inputs(name)
@@ -981,7 +996,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
             executor, _ = get_executor(task["executor"])
             # print("executor: ", executor)
             args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(name)
-            for i, key in enumerate(self.ctx._tasks[name]["metadata"]["args"]):
+            for i, key in enumerate(self.ctx._tasks[name]["args"]):
                 kwargs[key] = args[i]
             # update the port namespace
             kwargs = update_nested_dict_with_special_keys(kwargs)
@@ -998,7 +1013,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
                 if continue_workgraph:
                     self.continue_workgraph()
             elif task["metadata"]["node_type"].upper() == "DATA":
-                for key in self.ctx._tasks[name]["metadata"]["args"]:
+                for key in self.ctx._tasks[name]["args"]:
                     kwargs.pop(key, None)
                 results = create_data_node(executor, args, kwargs)
                 self.set_task_state_info(name, "process", results)
@@ -1174,7 +1189,7 @@
                 self.update_parent_task_state(name)
                 self.continue_workgraph()
             elif task["metadata"]["node_type"].upper() in ["AWAITABLE"]:
-                for key in self.ctx._tasks[name]["metadata"]["args"]:
+                for key in self.ctx._tasks[name]["args"]:
                     kwargs.pop(key, None)
                 awaitable_target = asyncio.ensure_future(
                     self.run_executor(executor, args, kwargs, var_args, var_kwargs),
@@ -1185,10 +1200,15 @@
                 self.to_context(**{name: awaitable})
             elif task["metadata"]["node_type"].upper() in ["MONITOR"]:
-                for key in self.ctx._tasks[name]["metadata"]["args"]:
+                for key in self.ctx._tasks[name]["args"]:
                     kwargs.pop(key, None)
                 # add function and interval to the args
-                args = [executor, kwargs.pop("interval"), kwargs.pop("timeout"), *args]
+                args = [
+                    executor,
+                    kwargs.pop("interval", 1),
+                    kwargs.pop("timeout", 3600),
+                    *args,
+                ]
                 awaitable_target = asyncio.ensure_future(
                     self.run_executor(monitor, args, kwargs, var_args, var_kwargs),
                     loop=self.loop,
@@ -1200,10 +1220,10 @@
                 self.to_context(**{name: awaitable})
             elif task["metadata"]["node_type"].upper() in ["NORMAL"]:
                 # Normal task is created by decorating a function with @task()
-                if "context" in task["metadata"]["kwargs"]:
+                if "context" in task["kwargs"]:
                     self.ctx.task_name = name
                     kwargs.update({"context": self.ctx})
-                for key in self.ctx._tasks[name]["metadata"]["args"]:
+                for key in self.ctx._tasks[name]["args"]:
                     kwargs.pop(key, None)
                 try:
                     results = self.run_executor(
@@ -1262,26 +1282,22 @@ def get_inputs(
         task = self.ctx._tasks[name]
         properties = task.get("properties", {})
         inputs = {}
-        for input in task["inputs"]:
+        for name, input in task["inputs"].items():
             # print(f"input: {input['name']}")
             if len(input["links"]) == 0:
-                inputs[input["name"]] = self.update_context_variable(
-                    properties[input["name"]]["value"]
-                )
+                inputs[name] = self.update_context_variable(input["property"]["value"])
             elif len(input["links"]) == 1:
                 link = input["links"][0]
                 if self.ctx._tasks[link["from_node"]]["results"] is None:
-                    inputs[input["name"]] = None
+                    inputs[name] = None
                 else:
                     # handle the special socket _wait, _outputs
                     if link["from_socket"] == "_wait":
                         continue
                     elif link["from_socket"] == "_outputs":
-                        inputs[input["name"]] = self.ctx._tasks[link["from_node"]][
-                            "results"
-                        ]
+                        inputs[name] = self.ctx._tasks[link["from_node"]]["results"]
                     else:
-                        inputs[input["name"]] = get_nested_dict(
+                        inputs[name] = get_nested_dict(
                             self.ctx._tasks[link["from_node"]]["results"],
                             link["from_socket"],
                         )
             elif len(input["links"]) > 1:
                 value = {}
                 for link in input["links"]:
-                    name = f'{link["from_node"]}_{link["from_socket"]}'
+                    item_name = f'{link["from_node"]}_{link["from_socket"]}'
                     # handle the special socket _wait, _outputs
                     if link["from_socket"] == "_wait":
                         continue
                     if self.ctx._tasks[link["from_node"]]["results"] is None:
-                        value[name] = None
+                        value[item_name] = None
                     else:
-                        value[name] = self.ctx._tasks[link["from_node"]]["results"][
-                            link["from_socket"]
-                        ]
-                inputs[input["name"]] = value
-        for name in task["metadata"].get("args", []):
+                        value[item_name] = self.ctx._tasks[link["from_node"]][
+                            "results"
+                        ][link["from_socket"]]
+                inputs[name] = value
+        for name in task.get("args", []):
             if name in inputs:
                 args.append(inputs[name])
                 args_dict[name] = inputs[name]
             else:
                 value = self.update_context_variable(properties[name]["value"])
                 args.append(value)
                 args_dict[name] = value
-        for name in task["metadata"].get("kwargs", []):
+        for name in task.get("kwargs", []):
             if name in inputs:
                 kwargs[name] = inputs[name]
             else:
                 value = self.update_context_variable(properties[name]["value"])
                 kwargs[name] = value
-        if task["metadata"]["var_args"] is not None:
-            name = task["metadata"]["var_args"]
+        if task["var_args"] is not None:
+            name = task["var_args"]
             if name in inputs:
                 var_args = inputs[name]
             else:
                 value = self.update_context_variable(properties[name]["value"])
                 var_args = value
-        if task["metadata"]["var_kwargs"] is not None:
-            name = task["metadata"]["var_kwargs"]
+        if task["var_kwargs"] is not None:
+            name = task["var_kwargs"]
             if name in inputs:
                 var_kwargs = inputs[name]
             else:
diff --git a/aiida_workgraph/executors/monitors.py b/aiida_workgraph/executors/monitors.py
index e82d5299..20edad67 100644
--- a/aiida_workgraph/executors/monitors.py
+++ b/aiida_workgraph/executors/monitors.py
@@ -13,17 +13,20 @@ async def monitor(function, interval, timeout, *args, **kwargs):
         await asyncio.sleep(interval)
 
 
-def file_monitor(filename):
+def file_monitor(filename: str):
     """Check if the file exists."""
     import os
 
     return os.path.exists(filename)
 
 
-def time_monitor(time):
+def time_monitor(time: str):
     """Return True if the current time is greater than the given time."""
     import datetime
 
+    # load the time string
+    time = datetime.datetime.strptime(time, "%Y-%m-%d %H:%M:%S.%f")
+
     return datetime.datetime.now() > time
diff --git a/aiida_workgraph/property.py b/aiida_workgraph/property.py
index 5d803ce3..4098e729 100644
--- a/aiida_workgraph/property.py
+++ b/aiida_workgraph/property.py
@@ -56,12 +56,12 @@ def set_value(self, value: Any) -> None:
             raise Exception("{} is not an {}.".format(value, DataClass.__name__))
 
     def get_serialize(self) -> Dict[str, str]:
-        serialize = {"path": "aiida.orm.utils.serialize", "name": "serialize"}
+        serialize = {"module": "aiida.orm.utils.serialize", "name": "serialize"}
         return serialize
 
     def get_deserialize(self) -> Dict[str, str]:
         deserialize = {
-            "path": "aiida.orm.utils.serialize",
+            "module": "aiida.orm.utils.serialize",
             "name": "deserialize_unsafe",
         }
         return deserialize
diff --git a/aiida_workgraph/socket.py b/aiida_workgraph/socket.py
index c1a8899c..0211c4b9 100644
--- a/aiida_workgraph/socket.py
+++ b/aiida_workgraph/socket.py
@@ -32,12 +32,12 @@ def __init__(
         self.add_property(DataClass, name, **kwargs)
 
     def get_serialize(self) -> dict:
-        serialize = {"path": "aiida.orm.utils.serialize", "name": "serialize"}
+        serialize = {"module": "aiida.orm.utils.serialize", "name": "serialize"}
         return serialize
 
     def get_deserialize(self) -> dict:
         deserialize = {
-            "path": "aiida.orm.utils.serialize",
+            "module": "aiida.orm.utils.serialize",
             "name": "deserialize_unsafe",
         }
         return deserialize
diff --git a/aiida_workgraph/task.py b/aiida_workgraph/task.py
index 55ca15b2..5b8b761e 100644
--- a/aiida_workgraph/task.py
+++ b/aiida_workgraph/task.py
@@ -111,9 +111,11 @@ def from_dict(cls, data: Dict[str, Any], task_pool: Optional[Any] = None) -> "Ta
         Returns:
             Node: An instance of Node initialized with the provided data."""
-        from aiida_workgraph.tasks import task_pool
+        from aiida_workgraph.tasks import task_pool as workgraph_task_pool
         from aiida.orm.utils.serialize import deserialize_unsafe
 
+        if task_pool is None:
+            task_pool = workgraph_task_pool
         task = GraphNode.from_dict(data, node_pool=task_pool)
         task.context_mapping = data.get("context_mapping", {})
         task.waiting_on.add(data.get("wait", []))
diff --git a/aiida_workgraph/tasks/builtins.py b/aiida_workgraph/tasks/builtins.py
index 48f7e386..3dde1fa3 100644
--- a/aiida_workgraph/tasks/builtins.py
+++ b/aiida_workgraph/tasks/builtins.py
@@ -91,7 +91,7 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida_workgraph.executors.builtins",
+            "module": "aiida_workgraph.executors.builtins",
             "name": "GatherWorkChain",
         }
@@ -143,7 +143,7 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida.orm",
+            "module": "aiida.orm",
             "name": "Int",
         }
@@ -162,7 +162,7 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida.orm",
+            "module": "aiida.orm",
             "name": "Float",
         }
@@ -181,7 +181,7 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida.orm",
+            "module": "aiida.orm",
             "name": "Str",
         }
@@ -202,7 +202,7 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida.orm",
+            "module": "aiida.orm",
             "name": "List",
         }
@@ -223,7 +223,7 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida.orm",
+            "module": "aiida.orm",
             "name": "Dict",
         }
@@ -251,7 +251,7 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida.orm",
+            "module": "aiida.orm",
             "name": "load_node",
         }
@@ -276,7 +276,7 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida.orm",
+            "module": "aiida.orm",
             "name": "load_code",
         }
@@ -303,6 +303,6 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida_workgraph.executors.builtins",
+            "module": "aiida_workgraph.executors.builtins",
             "name": "select",
         }
diff --git a/aiida_workgraph/tasks/monitors.py b/aiida_workgraph/tasks/monitors.py
index 4d5f8793..7cdc5268 100644
--- a/aiida_workgraph/tasks/monitors.py
+++ b/aiida_workgraph/tasks/monitors.py
@@ -22,11 +22,12 @@ def create_sockets(self) -> None:
         inp.add_property("workgraph.any", default=86400.0)
         inp = self.inputs.new("workgraph.any", "_wait")
         inp.link_limit = 100000
+        self.outputs.new("workgraph.any", "result")
         self.outputs.new("workgraph.any", "_wait")
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida_workgraph.executors.monitors",
+            "module": "aiida_workgraph.executors.monitors",
             "name": "time_monitor",
         }
@@ -51,12 +52,12 @@ def create_sockets(self) -> None:
         inp.add_property("workgraph.any", default=86400.0)
         inp = self.inputs.new("workgraph.any", "_wait")
         inp.link_limit = 100000
+        self.outputs.new("workgraph.any", "result")
         self.outputs.new("workgraph.any", "_wait")
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida_workgraph.executors.monitors",
+            "module": "aiida_workgraph.executors.monitors",
             "name": "file_monitor",
         }
@@ -82,10 +84,11 @@ def create_sockets(self) -> None:
         inp.add_property("workgraph.any", default=86400.0)
         inp = self.inputs.new("workgraph.any", "_wait")
         inp.link_limit = 100000
+        self.outputs.new("workgraph.any", "result")
         self.outputs.new("workgraph.any", "_wait")
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida_workgraph.executors.monitors",
+            "module": "aiida_workgraph.executors.monitors",
             "name": "task_monitor",
         }
diff --git a/aiida_workgraph/tasks/pythonjob.py b/aiida_workgraph/tasks/pythonjob.py
index 0e93ea64..a1ee8866 100644
--- a/aiida_workgraph/tasks/pythonjob.py
+++ b/aiida_workgraph/tasks/pythonjob.py
@@ -12,7 +12,7 @@ class PythonJob(Task):
     @classmethod
     def get_function_kwargs(cls, data) -> Dict[str, Any]:
         input_kwargs = set()
-        for name in data["metadata"]["kwargs"]:
+        for name in data["kwargs"]:
             # all the kwargs after computer are inputs for the PythonJob and should be AiiDA Data nodes
             if name == "computer":
                 break
diff --git a/aiida_workgraph/tasks/test.py b/aiida_workgraph/tasks/test.py
index e0c2bfa4..6997d779 100644
--- a/aiida_workgraph/tasks/test.py
+++ b/aiida_workgraph/tasks/test.py
@@ -26,7 +26,7 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida_workgraph.executors.test",
+            "module": "aiida_workgraph.executors.test",
             "name": "add",
         }
@@ -51,7 +51,7 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida_workgraph.executors.test",
+            "module": "aiida_workgraph.executors.test",
             "name": "greater",
         }
@@ -81,7 +81,7 @@ def create_sockets(self) -> None:
 
     def get_executor(self) -> Dict[str, str]:
         return {
-            "path": "aiida_workgraph.executors.test",
+            "module": "aiida_workgraph.executors.test",
             "name": "sum_diff",
         }
diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py
index 5752c03d..8f795de4 100644
--- a/aiida_workgraph/utils/__init__.py
+++ b/aiida_workgraph/utils/__init__.py
@@ -5,6 +5,35 @@
 from aiida.engine.runners import Runner
 
 
+def get_sorted_names(data: dict) -> list:
+    """Get the sorted names from a dictionary."""
+    sorted_names = [
+        name
+        for name, _ in sorted(
+            ((name, item["list_index"]) for name, item in data.items()),
+            key=lambda x: x[1],
+        )
+    ]
+    return sorted_names
+
+
+def store_nodes_recursely(data: Any) -> None:
+    """Recurse through a data structure and store any unstored nodes that are found along the way
+    :param data: a data structure potentially containing unstored nodes
+    """
+    from aiida.orm import Node
+    import collections.abc
+
+    if isinstance(data, Node) and not data.is_stored:
+        data.store()
+    elif isinstance(data, collections.abc.Mapping):
+        for _, value in data.items():
+            store_nodes_recursely(value)
+    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
+        for value in data:
+            store_nodes_recursely(value)
+
+
 def get_executor(data: Dict[str, Any]) -> Union[Process, Any]:
     """Import executor from path and return the executor and type."""
     import importlib
@@ -27,10 +57,10 @@ def get_executor(data: Dict[str, Any]) -> Union[Process, Any]:
         executor = CalculationFactory(data["name"])
     elif type == "DataFactory":
         executor = DataFactory(data["name"])
-    elif data["name"] == "" and data["path"] == "":
+    elif data["name"] == "" and data["module"] == "":
         executor = None
     else:
-        module = importlib.import_module("{}".format(data.get("path", "")))
+        module = importlib.import_module("{}".format(data.get("module", "")))
         executor = getattr(module, data["name"])
 
     return executor, type
@@ -131,13 +161,21 @@ def merge_properties(wgdata: Dict[str, Any]) -> None:
         "code": 1}}
     So that no "." in the key name.
     """
-    for name, task in wgdata["tasks"].items():
+    for _, task in wgdata["tasks"].items():
         for key, prop in task["properties"].items():
             if "." in key and prop["value"] not in [None, {}]:
                 root, key = key.split(".", 1)
-                update_nested_dict(
-                    task["properties"][root]["value"], key, prop["value"]
-                )
+                root_prop = task["properties"][root]
+                update_nested_dict(root_prop["value"], key, prop["value"])
                 prop["value"] = None
+        for key, input in task["inputs"].items():
+            if input["property"] is None:
+                continue
+            prop = input["property"]
+            if "." in key and prop["value"] not in [None, {}]:
in key and prop["value"] not in [None, {}]: + root, key = key.split(".", 1) + root_prop = task["inputs"][root]["property"] + update_nested_dict(root_prop["value"], key, prop["value"]) prop["value"] = None @@ -359,7 +397,10 @@ def serialize_properties(wgdata): import inspect for _, task in wgdata["tasks"].items(): - for _, prop in task["properties"].items(): + for _, input in task["inputs"].items(): + if input["property"] is None: + continue + prop = input["property"] if inspect.isfunction(prop["value"]): prop["value"] = PickledLocalFunction(prop["value"]).store() @@ -532,39 +573,50 @@ def recursive_to_dict(attr_dict): return attr_dict -def process_properties(properties: Dict) -> Dict: +def get_raw_value(identifier, value: Any) -> Any: + """Get the raw value from a Data node.""" + if identifier in [ + "workgraph.int", + "workgraph.float", + "workgraph.string", + "workgraph.bool", + ]: + if value is not None: + return value + elif identifier in [ + "workgraph.aiida_int", + "workgraph.aiida_float", + "workgraph.aiida_string", + "workgraph.aiida_bool", + ]: + if value is not None and isinstance(value, orm.Data): + return value.value + else: + return value + elif ( + identifier == "workgraph.aiida_structure" + and value is not None + and isinstance(value, orm.StructureData) + ): + content = value.backend_entity.attributes + content["node_type"] = value.node_type + return content + + +def process_properties(task: Dict) -> Dict: """Extract raw values.""" result = {} - for key, prop in properties.items(): + for name, prop in task["properties"].items(): identifier = prop["identifier"] value = prop.get("value") - - if identifier in [ - "workgraph.int", - "workgraph.float", - "workgraph.string", - "workgraph.bool", - ]: - if value is not None: - result[key] = value - elif identifier in [ - "workgraph.aiida_int", - "workgraph.aiida_float", - "workgraph.aiida_string", - "workgraph.aiida_bool", - ]: - if value is not None and isinstance(value, orm.Data): - result[key] = value.value - else: - result[key] = value - elif ( - identifier == "workgraph.aiida_structure" - and value is not None - and isinstance(value, orm.StructureData) - ): - content = value.backend_entity.attributes - content["node_type"] = value.node_type - result[key] = content + result[name] = get_raw_value(identifier, value) + # + for name, input in task["inputs"].items(): + if input["property"] is not None: + prop = input["property"] + identifier = prop["identifier"] + value = prop.get("value") + result[name] = get_raw_value(identifier, value) return result @@ -583,12 +635,8 @@ def workgraph_to_short_json( # for name, task in wgdata["tasks"].items(): # Add required inputs to nodes - inputs = [ - input - for input in task["inputs"] - if input["name"] in task["metadata"]["args"] - ] - properties = process_properties(task.get("properties", {})) + inputs = [{"name": name} for name in task["inputs"] if name in task["args"]] + properties = process_properties(task) wgdata_short["nodes"][name] = { "label": task["name"], "node_type": task["metadata"]["node_type"].upper(), diff --git a/aiida_workgraph/utils/analysis.py b/aiida_workgraph/utils/analysis.py index 0ce58d46..477a2f27 100644 --- a/aiida_workgraph/utils/analysis.py +++ b/aiida_workgraph/utils/analysis.py @@ -83,20 +83,24 @@ def build_task_link(self) -> None: """ # reset task input links for name, task in self.wgdata["tasks"].items(): - for input in task["inputs"]: + for _, input in task["inputs"].items(): input["links"] = [] - for output in task["outputs"]: + for _, output in 
task["outputs"].items(): output["links"] = [] for link in self.wgdata["links"]: to_socket = [ socket - for socket in self.wgdata["tasks"][link["to_node"]]["inputs"] - if socket["name"] == link["to_socket"] + for name, socket in self.wgdata["tasks"][link["to_node"]][ + "inputs" + ].items() + if name == link["to_socket"] ][0] from_socket = [ socket - for socket in self.wgdata["tasks"][link["from_node"]]["outputs"] - if socket["name"] == link["from_socket"] + for name, socket in self.wgdata["tasks"][link["from_node"]][ + "outputs" + ].items() + if name == link["from_socket"] ][0] to_socket["links"].append(link) from_socket["links"].append(link) @@ -134,7 +138,7 @@ def find_zone_inputs(self, name: str) -> None: """Find the input and outputs tasks for the zone.""" task = self.wgdata["tasks"][name] input_tasks = [] - for input in self.wgdata["tasks"][name]["inputs"]: + for _, input in self.wgdata["tasks"][name]["inputs"].items(): for link in input["links"]: input_tasks.append(link["from_node"]) # find all the input tasks @@ -152,7 +156,7 @@ def find_zone_inputs(self, name: str) -> None: else: # if the child task is not a zone, get the input tasks of the child task # find all the input tasks which outside the while zone - for input in self.wgdata["tasks"][child_task]["inputs"]: + for _, input in self.wgdata["tasks"][child_task]["inputs"].items(): for link in input["links"]: input_tasks.append(link["from_node"]) # find the input tasks which are not in the zone @@ -209,7 +213,10 @@ def insert_workgraph_to_db(self) -> None: self.process.base.extras.set("_workgraph_short", short_wgdata) self.save_task_states() for name, task in self.wgdata["tasks"].items(): - for _, prop in task["properties"].items(): + for _, input in task["inputs"].items(): + if input["property"] is None: + continue + prop = input["property"] if inspect.isfunction(prop["value"]): prop["value"] = PickledLocalFunction(prop["value"]).store() self.wgdata["tasks"][name] = serialize(task) diff --git a/aiida_workgraph/web/backend/app/utils.py b/aiida_workgraph/web/backend/app/utils.py index 25c35751..e7cd4199 100644 --- a/aiida_workgraph/web/backend/app/utils.py +++ b/aiida_workgraph/web/backend/app/utils.py @@ -83,7 +83,7 @@ def node_to_short_json(workgraph_pk: int, tdata: Dict[str, Any]) -> Dict[str, An "metadata": [ ["name", tdata["name"]], ["node_type", tdata["metadata"]["node_type"]], - ["identifier", tdata["metadata"]["identifier"]], + ["identifier", tdata["identifier"]], ], "executor": executor, } diff --git a/aiida_workgraph/widget/src/widget/__init__.py b/aiida_workgraph/widget/src/widget/__init__.py index 0fafbb78..ee6113e6 100644 --- a/aiida_workgraph/widget/src/widget/__init__.py +++ b/aiida_workgraph/widget/src/widget/__init__.py @@ -43,7 +43,7 @@ def from_node(self, node: Any) -> None: tdata.pop("executor", None) tdata.pop("node_class", None) tdata.pop("process", None) - tdata["label"] = tdata["metadata"]["identifier"] + tdata["label"] = tdata["identifier"] wgdata = {"name": node.name, "nodes": {node.name: tdata}, "links": []} self.value = wgdata diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index 339435a0..bb1c580e 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -177,6 +177,7 @@ def save_to_base(self, wgdata: Dict[str, Any]) -> None: def to_dict(self, store_nodes=False) -> Dict[str, Any]: import cloudpickle as pickle + from aiida_workgraph.utils import store_nodes_recursely wgdata = super().to_dict() # save the sequence and context @@ -200,11 +201,7 @@ def 
         wgdata["error_handlers"] = pickle.dumps(self.error_handlers)
         wgdata["tasks"] = wgdata.pop("nodes")
         if store_nodes:
-            for task in wgdata["tasks"].values():
-                for prop in task["properties"].values():
-                    if isinstance(prop["value"], aiida.orm.Node):
-                        prop["value"].store()
-
+            store_nodes_recursely(wgdata)
         return wgdata
 
     def wait(self, timeout: int = 50, tasks: dict = None) -> None:
diff --git a/docs/gallery/concept/autogen/task.py b/docs/gallery/concept/autogen/task.py
index 1b01b289..366f7f93 100644
--- a/docs/gallery/concept/autogen/task.py
+++ b/docs/gallery/concept/autogen/task.py
@@ -169,7 +169,7 @@ def create_sockets(self):
 
     def get_executor(self):
         return {
-            "path": "aiida_workgraph.executors.test",
+            "module": "aiida_workgraph.executors.test",
             "name": "add",
         }
diff --git a/pyproject.toml b/pyproject.toml
index d6963fb9..52683623 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
     "numpy~=1.21",
     "scipy",
     "ase",
-    "node-graph>=0.0.18",
+    "node-graph>=0.0.19",
    "aiida-core>=2.3",
    "cloudpickle",
    "aiida-shell",
diff --git a/tests/datas/test_calcfunction.yaml b/tests/datas/test_calcfunction.yaml
index 0723f109..e7d9de2f 100644
--- a/tests/datas/test_calcfunction.yaml
+++ b/tests/datas/test_calcfunction.yaml
@@ -7,21 +7,28 @@ metadata:
 tasks:
   - identifier: workgraph.aiida_float
     name: float1
-    properties:
-      value: 3.0
+    inputs:
+      - name: value
+        property:
+          value: 3.0
   - identifier: workgraph.test_sum_diff
     name: sumdiff1
-    properties:
-      x: 2.0
     inputs:
-      - to_socket: y
-        from_node: float1
-        from_socket: 0
+      - name: x
+        property:
+          value: 2.0
   - identifier: workgraph.test_sum_diff
     name: sumdiff2
-    properties:
-      x: 4.0
     inputs:
-      - to_socket: y
-        from_node: sumdiff1
-        from_socket: 0
+      - name: x
+        property:
+          value: 4.0
+links:
+  - to_node: sumdiff1
+    from_node: float1
+    to_socket: y
+    from_socket: result
+  - to_node: sumdiff2
+    from_node: sumdiff1
+    to_socket: y
+    from_socket: sum
diff --git a/tests/test_awaitable_task.py b/tests/test_awaitable_task.py
index 6d118027..83f51339 100644
--- a/tests/test_awaitable_task.py
+++ b/tests/test_awaitable_task.py
@@ -30,7 +30,7 @@ def test_time_monitor(decorated_add):
     monitor1 = wg.add_task(
         "workgraph.time_monitor",
         "monitor1",
-        datetime=datetime.datetime.now() + datetime.timedelta(seconds=10),
+        datetime=str(datetime.datetime.now() + datetime.timedelta(seconds=10)),
     )
     add1 = wg.add_task(decorated_add, "add1", x=1, y=2)
     add1.waiting_on.add(monitor1)
@@ -49,7 +49,7 @@ async def create_test_file(filepath="/tmp/test_file_monitor.txt", t=2):
         with open(filepath, "w") as f:
             f.write("test")
 
-    monitor_file_path = tmp_path / "test_file_monitor.txt"
+    monitor_file_path = str(tmp_path / "test_file_monitor.txt")
 
     wg = WorkGraph(name="test_file_monitor")
     monitor1 = wg.add_task(
diff --git a/tests/test_decorator.py b/tests/test_decorator.py
index d1ea476b..2d96c708 100644
--- a/tests/test_decorator.py
+++ b/tests/test_decorator.py
@@ -70,9 +70,7 @@ def test_decorators_task_args(task_function):
     assert tdata["kwargs"] == ["b"]
     assert tdata["var_args"] is None
     assert tdata["var_kwargs"] == "c"
-    assert set([output["name"] for output in tdata["outputs"]]) == set(
-        ["result", "_outputs", "_wait"]
-    )
+    assert set(tdata["outputs"].keys()) == set(["result", "_outputs", "_wait"])
 
 
 @pytest.fixture(params=["decorator_factory", "decorator"])
diff --git a/tests/test_link.py b/tests/test_link.py
index e87c98b3..e7f2a3ba 100644
--- a/tests/test_link.py
+++ b/tests/test_link.py
@@ -6,24 +6,24 @@ def test_multiply_link() -> None:
     """Test multiply link."""
     from aiida_workgraph import task, WorkGraph
-    from aiida.orm import Float, load_node
+    from aiida.orm import Float
 
     @task.calcfunction()
-    def sum(datas):
+    def sum(**datas):
         total = 0
-        for data in datas:
-            total += load_node(data).value
+        for _, data in datas.items():
+            total += data.value
         return Float(total)
 
     wg = WorkGraph(name="test_multiply_link")
     float1 = wg.add_task("workgraph.aiida_node", pk=Float(1.0).store().pk)
     float2 = wg.add_task("workgraph.aiida_node", pk=Float(2.0).store().pk)
     float3 = wg.add_task("workgraph.aiida_node", pk=Float(3.0).store().pk)
-    gather1 = wg.add_task("workgraph.gather", "gather1")
     sum1 = wg.add_task(sum, "sum1")
-    wg.add_link(float1.outputs[0], gather1.inputs[0])
-    wg.add_link(float2.outputs[0], gather1.inputs[0])
-    wg.add_link(float3.outputs[0], gather1.inputs[0])
-    wg.add_link(gather1.outputs["result"], sum1.inputs["datas"])
-    wg.submit(wait=True)
+    sum1.inputs["datas"].link_limit = 100
+    wg.add_link(float1.outputs[0], sum1.inputs["datas"])
+    wg.add_link(float2.outputs[0], sum1.inputs["datas"])
+    wg.add_link(float3.outputs[0], sum1.inputs["datas"])
+    # wg.submit(wait=True)
+    wg.run()
     assert sum1.node.outputs.result.value == 6
diff --git a/tests/test_python.py b/tests/test_python.py
index c760666f..2e7ddba2 100644
--- a/tests/test_python.py
+++ b/tests/test_python.py
@@ -486,6 +486,8 @@ def add(x: str, y: str) -> str:
     )
     assert wg.tasks["add"].outputs["result"].value.value == "Hello, World!"
     wg = WorkGraph.load(wg.pk)
+    wg.tasks["add"].inputs["x"].value = "Hello, "
+    wg.tasks["add"].inputs["y"].value = "World!"
 
 
 def test_exit_code(fixture_localhost, python_executable_path):
diff --git a/tests/test_workgraph.py b/tests/test_workgraph.py
index 88792980..d136c164 100644
--- a/tests/test_workgraph.py
+++ b/tests/test_workgraph.py
@@ -5,17 +5,13 @@
 from aiida.calculations.arithmetic.add import ArithmeticAddCalculation
 
 
-def test_to_dict(wg_calcfunction):
+def test_from_dict(decorated_add):
     """Export NodeGraph to dict."""
-    wg = wg_calcfunction
-    wgdata = wg.to_dict()
-    assert len(wgdata["tasks"]) == len(wg.tasks)
-    assert len(wgdata["links"]) == len(wg.links)
-
-
-def test_from_dict(wg_calcfunction):
-    """Export NodeGraph to dict."""
-    wg = wg_calcfunction
+    wg = WorkGraph("test_from_dict")
+    task1 = wg.add_task(decorated_add, x=2, y=3)
+    wg.add_task(
+        "workgraph.test_sum_diff", name="sumdiff2", x=4, y=task1.outputs["result"]
+    )
     wgdata = wg.to_dict()
     wg1 = WorkGraph.from_dict(wgdata)
     assert len(wg.tasks) == len(wg1.tasks)