Skip to content

Commit

Permalink
Refactor task's data, bumps node-graph to 0.0.19 (#304)
Browse files Browse the repository at this point in the history
This PR bumps `node-graph` to 0.0.19 so that we no longer need to pickle the `node_class`:
1) convert inputs and outputs to dicts
2) keep the property inside its input
3) use the `module` and `name` fields to reference the `node_class` and executor
  • Loading branch information
superstar54 authored and agoscinski committed Sep 19, 2024
1 parent a24b6b7 commit 8047817
Show file tree
Hide file tree
Showing 24 changed files with 341 additions and 236 deletions.
147 changes: 78 additions & 69 deletions aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def create_task(tdata):
from node_graph.decorator import create_node

tdata["type_mapping"] = type_mapping
tdata["node_type"] = tdata.pop("task_type")
tdata["metadata"]["node_type"] = tdata["metadata"].pop("task_type")
return create_node(tdata)


Expand Down Expand Up @@ -132,10 +132,10 @@ def build_task(
return build_task_from_workgraph(executor)
elif isinstance(executor, str):
(
path,
module,
executor_name,
) = executor.rsplit(".", 1)
executor, _ = get_executor({"path": path, "name": executor_name})
executor, _ = get_executor({"module": module, "name": executor_name})
if callable(executor):
return build_task_from_callable(executor, inputs=inputs, outputs=outputs)

Expand All @@ -162,24 +162,26 @@ def build_task_from_callable(
and issubclass(executor, Task)
):
return executor
tdata = {}
tdata = {"metadata": {}}
if inspect.isfunction(executor):
# calcfunction and workfunction
if getattr(executor, "node_class", False):
tdata["task_type"] = task_types.get(executor.node_class, "NORMAL")
tdata["metadata"]["task_type"] = task_types.get(
executor.node_class, "NORMAL"
)
tdata["executor"] = executor
return build_task_from_AiiDA(tdata, inputs=inputs, outputs=outputs)[0]
else:
tdata["task_type"] = "NORMAL"
tdata["metadata"]["task_type"] = "NORMAL"
tdata["executor"] = executor
return build_task_from_function(executor, inputs=inputs, outputs=outputs)
else:
if issubclass(executor, CalcJob):
tdata["task_type"] = "CALCJOB"
tdata["metadata"]["task_type"] = "CALCJOB"
tdata["executor"] = executor
return build_task_from_AiiDA(tdata, inputs=inputs, outputs=outputs)[0]
elif issubclass(executor, WorkChain):
tdata["task_type"] = "WORKCHAIN"
tdata["metadata"]["task_type"] = "WORKCHAIN"
tdata["executor"] = executor
return build_task_from_AiiDA(tdata, inputs=inputs, outputs=outputs)[0]
raise ValueError("The executor is not supported.")
Expand All @@ -203,9 +205,8 @@ def build_task_from_AiiDA(
) -> Task:
"""Register a task from a AiiDA component.
For example: CalcJob, WorkChain, CalcFunction, WorkFunction."""
from aiida_workgraph.task import Task

# print(executor)
tdata.setdefault("metadata", {})
inputs = [] if inputs is None else inputs
outputs = [] if outputs is None else outputs
executor = tdata["executor"]
Expand Down Expand Up @@ -250,24 +251,24 @@ def build_task_from_AiiDA(
tdata["identifier"] = tdata.pop("identifier", tdata["executor"].__name__)
tdata["executor"] = {
"executor": pickle.dumps(executor),
"type": tdata["task_type"],
"type": tdata["metadata"]["task_type"],
"is_pickle": True,
}
if tdata["task_type"].upper() in ["CALCFUNCTION", "WORKFUNCTION"]:
if tdata["metadata"]["task_type"].upper() in ["CALCFUNCTION", "WORKFUNCTION"]:
outputs = (
[{"identifier": "workgraph.any", "name": "result"}]
if not outputs
else outputs
)
# build executor from the function
tdata["executor"] = PickledFunction.build_executor(executor)
# tdata["executor"]["type"] = tdata["task_type"]
# tdata["executor"]["type"] = tdata["metadata"]["task_type"]
# print("kwargs: ", kwargs)
# add built-in sockets
outputs.append({"identifier": "workgraph.any", "name": "_outputs"})
outputs.append({"identifier": "workgraph.any", "name": "_wait"})
inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6})
tdata["node_class"] = Task
tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"}
tdata["args"] = args
tdata["kwargs"] = kwargs
tdata["inputs"] = inputs
Expand All @@ -280,7 +281,6 @@ def build_task_from_AiiDA(
def build_pythonjob_task(func: Callable) -> Task:
"""Build PythonJob task from function."""
from aiida_workgraph.calculations.python import PythonJob
from aiida_workgraph.tasks.pythonjob import PythonJob as PythonJobTask
from copy import deepcopy

# if the function is not a task, build a task from the function
Expand All @@ -290,42 +290,46 @@ def build_pythonjob_task(func: Callable) -> Task:
raise ValueError(
"GraphBuilder task cannot be run remotely. Please remove 'PythonJob'."
)
tdata = {"executor": PythonJob, "task_type": "CALCJOB"}
tdata = {
"metadata": {"task_type": "PYTHONJOB"},
"executor": PythonJob,
}
_, tdata_py = build_task_from_AiiDA(tdata)
tdata = deepcopy(func.tdata)
# merge the inputs and outputs from the PythonJob task to the function task
# skip the already existed inputs and outputs
inputs = tdata["inputs"]
inputs.extend(
[
{"identifier": "workgraph.string", "name": "computer"},
{"identifier": "workgraph.string", "name": "code_label"},
{"identifier": "workgraph.string", "name": "code_path"},
{"identifier": "workgraph.string", "name": "prepend_text"},
]
)
outputs = tdata["outputs"]
for input in tdata_py["inputs"]:
if input not in inputs:
inputs.append(input)
for output in tdata_py["outputs"]:
if output not in outputs:
outputs.append(output)
outputs.append({"identifier": "workgraph.any", "name": "exit_code"})
for input in [
{"identifier": "workgraph.string", "name": "computer"},
{"identifier": "workgraph.string", "name": "code_label"},
{"identifier": "workgraph.string", "name": "code_path"},
{"identifier": "workgraph.string", "name": "prepend_text"},
]:
input["list_index"] = len(tdata["inputs"]) + 1
tdata["inputs"][input["name"]] = input
for name, input in tdata_py["inputs"].items():
if name not in tdata["inputs"]:
input["list_index"] = len(tdata["inputs"]) + 1
tdata["inputs"][name] = input
for name, output in tdata_py["outputs"].items():
if name not in tdata["outputs"]:
output["list_index"] = len(tdata["outputs"]) + 1
tdata["outputs"][name] = output
for output in [{"identifier": "workgraph.any", "name": "exit_code"}]:
output["list_index"] = len(tdata["outputs"]) + 1
tdata["outputs"][output["name"]] = output
# change "copy_files" link_limit to 1e6
for input in inputs:
if input["name"] == "copy_files":
input["link_limit"] = 1e6
tdata["inputs"]["copy_files"]["link_limit"] = 1e6
# append the kwargs of the PythonJob task to the function task
kwargs = tdata["kwargs"]
kwargs.extend(["computer", "code_label", "code_path", "prepend_text"])
kwargs.extend(tdata_py["kwargs"])
tdata["inputs"] = inputs
tdata["outputs"] = outputs
tdata["kwargs"] = kwargs
tdata["task_type"] = "PYTHONJOB"
tdata["metadata"]["task_type"] = "PYTHONJOB"
tdata["identifier"] = "workgraph.pythonjob"
tdata["node_class"] = PythonJobTask
tdata["metadata"]["node_class"] = {
"module": "aiida_workgraph.tasks.pythonjob",
"name": "PythonJob",
}
task = create_task(tdata)
task.is_aiida_component = True
return task, tdata
Expand All @@ -339,7 +343,10 @@ def build_shelljob_task(
from aiida_shell.parsers.shell import ShellParser
from node_graph.socket import NodeSocket

tdata = {"executor": ShellJob, "task_type": "SHELLJOB"}
tdata = {
"metadata": {"task_type": "SHELLJOB"},
"executor": ShellJob,
}
_, tdata = build_task_from_AiiDA(tdata)
# create input sockets for the nodes, if it is linked other sockets
links = {}
Expand All @@ -354,16 +361,17 @@ def build_shelljob_task(
# Output socket itself is not a value, so we remove the key from the nodes
nodes.pop(key)
for input in inputs:
if input not in tdata["inputs"]:
tdata["inputs"].append(input)
if input["name"] not in tdata["inputs"]:
input["list_index"] = len(tdata["inputs"]) + 1
tdata["inputs"][input["name"]] = input
tdata["kwargs"].append(input["name"])
# Extend the outputs
tdata["outputs"].extend(
[
{"identifier": "workgraph.any", "name": "stdout"},
{"identifier": "workgraph.any", "name": "stderr"},
]
)
for output in [
{"identifier": "workgraph.any", "name": "stdout"},
{"identifier": "workgraph.any", "name": "stderr"},
]:
output["list_index"] = len(tdata["outputs"]) + 1
tdata["outputs"][output["name"]] = output
outputs = [] if outputs is None else outputs
parser_outputs = [] if parser_outputs is None else parser_outputs
outputs = [
Expand All @@ -373,29 +381,29 @@ def build_shelljob_task(
outputs.extend(parser_outputs)
# add user defined outputs
for output in outputs:
if output not in tdata["outputs"]:
tdata["outputs"].append(output)
if output["name"] not in tdata["outputs"]:
output["list_index"] = len(tdata["outputs"]) + 1
tdata["outputs"][output["name"]] = output
#
tdata["identifier"] = "ShellJob"
tdata["inputs"].extend(
[
{"identifier": "workgraph.any", "name": "command"},
{"identifier": "workgraph.any", "name": "resolve_command"},
]
)
for input in [
{"identifier": "workgraph.any", "name": "command"},
{"identifier": "workgraph.any", "name": "resolve_command"},
]:
input["list_index"] = len(tdata["inputs"]) + 1
tdata["inputs"][input["name"]] = input
tdata["kwargs"].extend(["command", "resolve_command"])
tdata["task_type"] = "SHELLJOB"
tdata["metadata"]["task_type"] = "SHELLJOB"
task = create_task(tdata)
task.is_aiida_component = True
return task, tdata, links


def build_task_from_workgraph(wg: any) -> Task:
"""Build task from workgraph."""
from aiida_workgraph.task import Task
from aiida.orm.utils.serialize import serialize

tdata = {"task_type": "workgraph"}
tdata = {"metadata": {"task_type": "workgraph"}}
inputs = []
outputs = []
group_outputs = []
Expand All @@ -407,15 +415,15 @@ def build_task_from_workgraph(wg: any) -> Task:
if socket.name == "_wait":
continue
inputs.append(
{"identifier": "workgraph.any", "name": f"{task.name}.{socket.name}"}
{"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"}
)
# outputs
outputs.append({"identifier": "workgraph.any", "name": f"{task.name}"})
for socket in task.outputs:
if socket.name in ["_wait", "_outputs"]:
continue
outputs.append(
{"identifier": "workgraph.any", "name": f"{task.name}.{socket.name}"}
{"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"}
)
group_outputs.append(
{
Expand All @@ -428,16 +436,16 @@ def build_task_from_workgraph(wg: any) -> Task:
outputs.append({"identifier": "workgraph.any", "name": "_outputs"})
outputs.append({"identifier": "workgraph.any", "name": "_wait"})
inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6})
tdata["node_class"] = Task
tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"}
tdata["kwargs"] = kwargs
tdata["inputs"] = inputs
tdata["outputs"] = outputs
tdata["identifier"] = wg.name
executor = {
"path": "aiida_workgraph.engine.workgraph",
"module": "aiida_workgraph.engine.workgraph",
"name": "WorkGraphEngine",
"wgdata": serialize(wg.to_dict(store_nodes=True)),
"type": tdata["task_type"],
"type": tdata["metadata"]["task_type"],
"is_pickle": False,
}
tdata["executor"] = executor
Expand Down Expand Up @@ -490,7 +498,6 @@ def generate_tdata(
) -> Dict[str, Any]:
"""Generate task data for creating a task."""
from node_graph.decorator import generate_input_sockets
from aiida_workgraph.task import Task

args, kwargs, var_args, var_kwargs, _inputs = generate_input_sockets(
func, inputs, properties, type_mapping=type_mapping
Expand All @@ -501,17 +508,19 @@ def generate_tdata(
task_outputs.append({"identifier": "workgraph.any", "name": "_wait"})
task_outputs.append({"identifier": "workgraph.any", "name": "_outputs"})
tdata = {
"node_class": Task,
"identifier": identifier,
"args": args,
"kwargs": kwargs,
"var_args": var_args,
"var_kwargs": var_kwargs,
"task_type": task_type,
"metadata": {
"task_type": task_type,
"catalog": catalog,
"node_class": {"module": "aiida_workgraph.task", "name": "Task"},
},
"properties": properties,
"inputs": _inputs,
"outputs": task_outputs,
"catalog": catalog,
}
tdata["executor"] = PickledFunction.build_executor(func)
if additional_data:
Expand Down
33 changes: 25 additions & 8 deletions aiida_workgraph/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple:
"""Prepare the inputs for WorkGraph task"""
from aiida_workgraph.utils import merge_properties
from aiida_workgraph.utils import merge_properties, serialize_properties
from aiida.orm.utils.serialize import deserialize_unsafe

wgdata = deserialize_unsafe(task["executor"]["wgdata"])
Expand All @@ -16,9 +16,12 @@ def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple:
# because kwargs is updated using update_nested_dict_with_special_keys
# which means the data is grouped by the task name
for socket_name, value in data.items():
wgdata["tasks"][task_name]["properties"][socket_name]["value"] = value
wgdata["tasks"][task_name]["inputs"][socket_name]["property"][
"value"
] = value
# merge the properties
merge_properties(wgdata)
serialize_properties(wgdata)
metadata = {"call_link_label": task["name"]}
inputs = {"wg": wgdata, "metadata": metadata}
return inputs, wgdata
Expand All @@ -31,14 +34,22 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:

# get the names kwargs for the PythonJob, which are the inputs before _wait
function_kwargs = {}
for input in task["inputs"]:
if input["name"] == "_wait":
# TODO better way to find the function_kwargs
input_names = [
name
for name, _ in sorted(
((name, input["list_index"]) for name, input in task["inputs"].items()),
key=lambda x: x[1],
)
]
for name in input_names:
if name == "_wait":
break
function_kwargs[input["name"]] = kwargs.pop(input["name"], None)
function_kwargs[name] = kwargs.pop(name, None)
# if the var_kwargs is not None, we need to pop the var_kwargs from the kwargs
# then update the function_kwargs if var_kwargs is not None
if task["metadata"]["var_kwargs"] is not None:
function_kwargs.pop(task["metadata"]["var_kwargs"], None)
if task["var_kwargs"] is not None:
function_kwargs.pop(task["var_kwargs"], None)
if var_kwargs:
# var_kwargs can be AttributeDict if it get data from the previous task output
if isinstance(var_kwargs, (dict, AttributeDict)):
Expand Down Expand Up @@ -88,7 +99,13 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
+ task["executor"]["function_source_code_without_decorator"]
)
# outputs
function_outputs = task["outputs"]
function_outputs = [
output
for output, _ in sorted(
((output, output["list_index"]) for output in task["outputs"].values()),
key=lambda x: x[1],
)
]
# serialize the kwargs into AiiDA Data
function_kwargs = serialize_to_aiida_nodes(function_kwargs)
# transfer the args to kwargs
Expand Down
Loading

0 comments on commit 8047817

Please sign in to comment.