Skip to content

Commit

Permalink
Re-introduce working changes after merge
Browse files Browse the repository at this point in the history
  • Loading branch information
GeigerJ2 committed Dec 16, 2024
1 parent 9280f02 commit 4110f48
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 2 deletions.
6 changes: 6 additions & 0 deletions src/aiida_workgraph/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
)
from typing import Any, Callable, Optional, Union

from aiida.engine import ProcessBuilder


class TaskCollection(NodeCollection):
def _new(
Expand All @@ -18,6 +20,7 @@ def _new(
build_pythonjob_task,
build_shelljob_task,
build_task_from_workgraph,
build_task_from_builder,
)
from aiida_workgraph.workgraph import WorkGraph

Expand All @@ -41,6 +44,9 @@ def _new(
return task
if isinstance(identifier, WorkGraph):
identifier = build_task_from_workgraph(identifier)
if isinstance(identifier, ProcessBuilder):
task = build_task_from_builder(identifier)
return task
return super()._new(identifier, name, uuid, **kwargs)


Expand Down
93 changes: 91 additions & 2 deletions src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from aiida_workgraph.utils import get_executor
from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain
from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain, ProcessBuilder
from aiida_workgraph.task import Task
from aiida_workgraph.utils import build_callable, validate_task_inout
from aiida_workgraph.utils import build_callable, validate_task_inout, get_dict_from_builder
import inspect
from aiida_workgraph.config import builtin_inputs, builtin_outputs, task_types
from aiida_workgraph.orm.mapping import type_mapping
Expand Down Expand Up @@ -299,6 +299,95 @@ def build_task_from_workgraph(wg: any) -> Task:
return task


def build_task_from_builder(builder: ProcessBuilder) -> Task:
"""Build task from an aiida-core ProcessBuilder."""
from aiida.orm.utils.serialize import serialize

tdata = {"metadata": {"task_type": "builder"}}
inputs = []
outputs = []
group_outputs = []

process_class = builder._process_class

# executor = get_executor(self.get_executor())[0]
# builder = executor.get_builder_from_protocol(*args, **kwargs)
# TODO: Instantiate AiiDA object from the builder, and pass that, rather than having to manually construct the tdata
# TODO: here again
# data = get_dict_from_builder(builder)

# # data.pop('identifier')
# data = {**data, **tdata}

# data['identifier'] = 'a'
# # tdata["identifier"] = wg.name
# task = Task.from_dict(data=data)
return task

# def add_task(
# self, identifier: Union[str, callable], name: str = None, **kwargs
# ) -> Task:

# self.set(data)

# add all the inputs/outputs from the tasks in the workgraph

# for task in wg.tasks:
# # inputs
# inputs.append(
# {
# "identifier": "workgraph.namespace",
# "name": f"{task.name}",
# }
# )
# for socket in task.inputs:
# if socket.name == "_wait":
# continue
# inputs.append(
# {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"}
# )
# # outputs
# outputs.append(
# {
# "identifier": "workgraph.namespace",
# "name": f"{task.name}",
# }
# )
# for socket in task.outputs:
# if socket.name in ["_wait", "_outputs"]:
# continue
# outputs.append(
# {"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"}
# )
# group_outputs.append(
# {
# "name": f"{task.name}.{socket.name}",
# "from": f"{task.name}.{socket.name}",
# }
# )
# kwargs = [input["name"] for input in inputs]
# # add built-in sockets
# outputs.append({"identifier": "workgraph.any", "name": "_outputs"})
# outputs.append({"identifier": "workgraph.any", "name": "_wait"})
# inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6})
# tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"}
# tdata["kwargs"] = kwargs
# tdata["inputs"] = inputs
# tdata["outputs"] = outputs
# tdata["identifier"] = wg.name
# executor = {
# "module": "aiida_workgraph.engine.workgraph",
# "name": "WorkGraphEngine",
# "wgdata": serialize(wg.to_dict(store_nodes=True)),
# "type": tdata["metadata"]["task_type"],
# "is_pickle": False,
# }
# tdata["metadata"]["group_outputs"] = group_outputs
# tdata["executor"] = executor
# task = create_task(tdata)
# return task


def nonfunctional_usage(callable: Callable):
"""
This is a decorator for a decorator factory (a function that returns a decorator).
Expand Down

0 comments on commit 4110f48

Please sign in to comment.