From 4110f485cddf43abfc4f885b2510393f2bafe458 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Mon, 16 Dec 2024 16:13:33 +0100 Subject: [PATCH] Re-introduce working changes after merge --- src/aiida_workgraph/collection.py | 6 ++ src/aiida_workgraph/decorator.py | 93 ++++++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 2 deletions(-) diff --git a/src/aiida_workgraph/collection.py b/src/aiida_workgraph/collection.py index e73dbf59..db7a8d4e 100644 --- a/src/aiida_workgraph/collection.py +++ b/src/aiida_workgraph/collection.py @@ -4,6 +4,8 @@ ) from typing import Any, Callable, Optional, Union +from aiida.engine import ProcessBuilder + class TaskCollection(NodeCollection): def _new( @@ -18,6 +20,7 @@ def _new( build_pythonjob_task, build_shelljob_task, build_task_from_workgraph, + build_task_from_builder, ) from aiida_workgraph.workgraph import WorkGraph @@ -41,6 +44,9 @@ def _new( return task if isinstance(identifier, WorkGraph): identifier = build_task_from_workgraph(identifier) + if isinstance(identifier, ProcessBuilder): + task = build_task_from_builder(identifier) + return task return super()._new(identifier, name, uuid, **kwargs) diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index 263d4445..1edce96e 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -2,9 +2,9 @@ from typing import Any, Callable, Dict, List, Optional, Union, Tuple from aiida_workgraph.utils import get_executor -from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain +from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain, ProcessBuilder from aiida_workgraph.task import Task -from aiida_workgraph.utils import build_callable, validate_task_inout +from aiida_workgraph.utils import build_callable, validate_task_inout, get_dict_from_builder import inspect from aiida_workgraph.config import builtin_inputs, builtin_outputs, task_types from aiida_workgraph.orm.mapping import type_mapping @@ -299,6 +299,95 @@ def build_task_from_workgraph(wg: any) -> Task: return task +def build_task_from_builder(builder: ProcessBuilder) -> Task: + """Build task from an aiida-core ProcessBuilder.""" + from aiida.orm.utils.serialize import serialize + + tdata = {"metadata": {"task_type": "builder"}} + inputs = [] + outputs = [] + group_outputs = [] + + process_class = builder._process_class + + # executor = get_executor(self.get_executor())[0] + # builder = executor.get_builder_from_protocol(*args, **kwargs) + # TODO: Instantiate AiiDA object from the builder, and pass that, rather than having to manually construct the tdata + # TODO: here again + # data = get_dict_from_builder(builder) + + # # data.pop('identifier') + # data = {**data, **tdata} + + # data['identifier'] = 'a' + # # tdata["identifier"] = wg.name + # task = Task.from_dict(data=data) + return task + + # def add_task( + # self, identifier: Union[str, callable], name: str = None, **kwargs + # ) -> Task: + + # self.set(data) + + # add all the inputs/outputs from the tasks in the workgraph + + # for task in wg.tasks: + # # inputs + # inputs.append( + # { + # "identifier": "workgraph.namespace", + # "name": f"{task.name}", + # } + # ) + # for socket in task.inputs: + # if socket.name == "_wait": + # continue + # inputs.append( + # {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"} + # ) + # # outputs + # outputs.append( + # { + # "identifier": "workgraph.namespace", + # "name": f"{task.name}", + # } + # ) + # for socket in task.outputs: + # if socket.name in ["_wait", "_outputs"]: + # continue + # outputs.append( + # {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"} + # ) + # group_outputs.append( + # { + # "name": f"{task.name}.{socket.name}", + # "from": f"{task.name}.{socket.name}", + # } + # ) + # kwargs = [input["name"] for input in inputs] + # # add built-in sockets + # outputs.append({"identifier": "workgraph.any", "name": "_outputs"}) + # outputs.append({"identifier": "workgraph.any", "name": "_wait"}) + # inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6}) + # tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"} + # tdata["kwargs"] = kwargs + # tdata["inputs"] = inputs + # tdata["outputs"] = outputs + # tdata["identifier"] = wg.name + # executor = { + # "module": "aiida_workgraph.engine.workgraph", + # "name": "WorkGraphEngine", + # "wgdata": serialize(wg.to_dict(store_nodes=True)), + # "type": tdata["metadata"]["task_type"], + # "is_pickle": False, + # } + # tdata["metadata"]["group_outputs"] = group_outputs + # tdata["executor"] = executor + # task = create_task(tdata) + # return task + + def nonfunctional_usage(callable: Callable): """ This is a decorator for a decorator factory (a function that returns a decorator).