Skip to content

Commit

Permalink
Refactor engine: prepare inputs for PythonJob, ShellJob, WorkGraph (#93)
Browse files Browse the repository at this point in the history
* move run nodes (pythonjob, workgraph) to utils
* add ShellJob as a built-in type in the engine.
  • Loading branch information
superstar54 authored May 29, 2024
1 parent 9997068 commit df22cf7
Show file tree
Hide file tree
Showing 11 changed files with 719 additions and 697 deletions.
4 changes: 4 additions & 0 deletions aiida_workgraph/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def new(
from aiida_workgraph.decorator import (
build_node_from_callable,
build_PythonJob_node,
build_ShellJob_node,
)

# build the node on the fly if the identifier is a callable
Expand All @@ -34,6 +35,9 @@ def new(
if isinstance(identifier, str) and identifier.upper() == "PYTHONJOB":
# copy the inputs and outputs from the function node to the PythonJob node
identifier, _ = build_PythonJob_node(kwargs.pop("function"))
if isinstance(identifier, str) and identifier.upper() == "SHELLJOB":
        # build the ShellJob node on the fly, with optional user-defined outputs
identifier, _ = build_ShellJob_node(kwargs.pop("add_outputs", None))
# Call the original new method
return super().new(identifier, name, uuid, **kwargs)

Expand Down
29 changes: 29 additions & 0 deletions aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,34 @@ def build_PythonJob_node(func: Callable) -> Node:
return node, ndata


def build_ShellJob_node(outputs: list = None) -> tuple:
    """Build a ShellJob node wrapping ``aiida_shell.ShellJob``.

    :param outputs: optional list of extra output sockets, each given as
        ``[identifier, name]``; an entry is appended only if not already
        present in the node's outputs.
    :returns: the generated node class and its node-data dictionary.
    """
    from aiida_shell.calculations.shell import ShellJob

    # start from the generic CALCJOB node data derived from ShellJob
    ndata = {"executor": ShellJob, "node_type": "CALCJOB"}
    _, ndata = build_node_from_AiiDA(ndata)
    # ShellJob always exposes stdout/stderr as outputs
    ndata["outputs"].extend([["General", "stdout"], ["General", "stderr"]])
    outputs = [] if outputs is None else outputs
    # add user defined outputs (skip duplicates)
    for output in outputs:
        if output not in ndata["outputs"]:
            ndata["outputs"].append(output)
    # extra inputs/kwargs specific to ShellJob (consumed in prepare_for_shelljob)
    ndata["identifier"] = "ShellJob"
    ndata["inputs"].extend(
        [
            ["General", "command"],
            ["General", "resolve_command"],
        ]
    )
    ndata["kwargs"].extend(["command", "resolve_command"])
    ndata["node_type"] = "SHELLJOB"
    node = create_node(ndata)
    node.is_aiida_component = True
    return node, ndata


def build_node_from_workgraph(wg: any) -> Node:
"""Build node from workgraph."""
from aiida_workgraph.node import Node
Expand Down Expand Up @@ -374,6 +402,7 @@ def serialize_function(func: Callable) -> Dict[str, Any]:
"executor": pickle.dumps(func),
"type": "function",
"is_pickle": True,
"function_name": func.__name__,
"function_source_code": function_source_code,
"import_statements": import_statements,
}
Expand Down
130 changes: 130 additions & 0 deletions aiida_workgraph/engine/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from aiida_workgraph.orm.serializer import serialize_to_aiida_nodes
from aiida import orm


def prepare_for_workgraph_node(node: dict, kwargs: dict) -> tuple:
    """Assemble the inputs for launching a sub-WorkGraph node.

    Copies the node's name and group outputs into the stored workgraph data,
    applies the caller-supplied socket values onto the sub-nodes' properties,
    merges the properties, and builds the process inputs.

    :param node: the node data dict carrying the serialized workgraph.
    :param kwargs: socket values, grouped per sub-node name.
    :returns: ``(inputs, wgdata)`` — the process inputs and the updated
        workgraph data.
    """
    from aiida_workgraph.utils import merge_properties

    print("node type: workgraph.")
    wgdata = node["executor"]["wgdata"]
    wgdata["name"] = node["name"]
    wgdata["metadata"]["group_outputs"] = node["metadata"]["group_outputs"]
    # kwargs arrive grouped per sub-node (via update_nested_dict_with_special_keys),
    # so write each socket value onto the matching sub-node property.
    for sub_node_name, sockets in kwargs.items():
        properties = wgdata["nodes"][sub_node_name]["properties"]
        for socket_name, value in sockets.items():
            properties[socket_name]["value"] = value
    merge_properties(wgdata)
    inputs = {
        "wg": wgdata,
        "metadata": {"call_link_label": node["name"]},
    }
    return inputs, wgdata


def prepare_for_pythonjob(node: dict, kwargs: dict, var_kwargs: dict) -> dict:
    """Prepare the inputs for a PythonJob calculation.

    Pops the function arguments, code-selection options and upload files out
    of ``kwargs``, serializes the function arguments to AiiDA nodes and
    returns the full input dictionary for submitting a ``PythonJob``.

    :param node: the node data dict (inputs, outputs, executor, metadata).
    :param kwargs: keyword arguments supplied to the node; consumed entries
        are popped, the remainder is passed through to the process.
    :param var_kwargs: optional variable keyword arguments; when truthy, its
        ``.value`` updates the function kwargs.
    :returns: inputs dictionary for ``submit(PythonJob, **inputs)``.
    :raises ValueError: if an upload-file source is neither a path string nor
        a ``SinglefileData``/``FolderData`` node.
    """
    from aiida_workgraph.utils import get_or_create_code
    import os

    print("node type: Python.")
    # The function's own kwargs are the input sockets declared before the
    # "_wait" sentinel socket.
    function_kwargs = {}
    for inp in node["inputs"]:  # avoid shadowing the builtin ``input``
        if inp["name"] == "_wait":
            break
        function_kwargs[inp["name"]] = kwargs.pop(inp["name"], None)
    # The var_kwargs socket is handled separately: drop its placeholder from
    # the function kwargs, then merge the actual values if provided.
    if node["metadata"]["var_kwargs"] is not None:
        function_kwargs.pop(node["metadata"]["var_kwargs"], None)
        if var_kwargs:
            function_kwargs.update(var_kwargs.value)
    # setup code
    code = kwargs.pop("code", None)
    computer = kwargs.pop("computer", None)
    code_label = kwargs.pop("code_label", None)
    code_path = kwargs.pop("code_path", None)
    prepend_text = kwargs.pop("prepend_text", None)
    upload_files = kwargs.pop("upload_files", {})
    new_upload_files = {}
    # Convert path strings in the upload files to SinglefileData / FolderData.
    for key, source in upload_files.items():
        # only alphanumeric and underscores are allowed in the key,
        # so replace all "." with "_dot_"
        new_key = key.replace(".", "_dot_")
        if isinstance(source, str):
            if os.path.isfile(source):
                new_upload_files[new_key] = orm.SinglefileData(file=source)
            elif os.path.isdir(source):
                new_upload_files[new_key] = orm.FolderData(tree=source)
        elif isinstance(source, (orm.SinglefileData, orm.FolderData)):
            new_upload_files[new_key] = source
        else:
            raise ValueError(f"Invalid upload file type: {type(source)}, {source}")
    # Fall back to a "python3" code on localhost when no code was given.
    if code is None:
        code = get_or_create_code(
            computer=computer if computer else "localhost",
            code_label=code_label if code_label else "python3",
            code_path=code_path if code_path else None,
            prepend_text=prepend_text if prepend_text else None,
        )
    parent_folder = kwargs.pop("parent_folder", None)
    metadata = kwargs.pop("metadata", {})
    metadata.update({"call_link_label": node["name"]})
    # The source shipped to the remote is the import statements plus the
    # function definition itself.
    function_name = node["executor"]["function_name"]
    function_source_code = (
        node["executor"]["import_statements"]
        + "\n"
        + node["executor"]["function_source_code"]
    )
    # outputs
    output_name_list = [output["name"] for output in node["outputs"]]
    # serialize the function kwargs into AiiDA Data nodes
    function_kwargs = serialize_to_aiida_nodes(function_kwargs)
    inputs = {
        "function_source_code": orm.Str(function_source_code),
        "function_name": orm.Str(function_name),
        "code": code,
        "function_kwargs": function_kwargs,
        "upload_files": new_upload_files,
        "output_name_list": orm.List(output_name_list),
        "parent_folder": parent_folder,
        "metadata": metadata,
        **kwargs,
    }
    return inputs


def prepare_for_shelljob(node: dict, kwargs: dict) -> dict:
    """Prepare the inputs for a ShellJob.

    Resolves the ``command`` — a string is turned into a code on the target
    computer via ``prepare_code``; otherwise it must already be an
    ``AbstractCode`` — and collects the remaining ShellJob inputs.

    :param node: the node data dict; ``node["name"]`` becomes the call link
        label.
    :param kwargs: keyword arguments of the node; consumed entries are popped.
    :returns: inputs dictionary for ``submit(ShellJob, **inputs)``.
    """
    from aiida_shell.launch import prepare_code, convert_nodes_single_file_data
    from aiida.common import lang
    from aiida.orm import AbstractCode

    print("node type: ShellJob.")
    command = kwargs.pop("command", None)
    resolve_command = kwargs.pop("resolve_command", False)
    # Normalize once: a caller passing ``metadata=None`` explicitly would
    # otherwise crash on ``metadata.update`` below.
    metadata = kwargs.pop("metadata", {}) or {}
    # setup code: wrap a plain string command into a code on the fly
    if isinstance(command, str):
        computer = metadata.get("options", {}).pop("computer", None)
        code = prepare_code(command, computer, resolve_command)
    else:
        lang.type_check(command, AbstractCode)
        code = command
    metadata.update({"call_link_label": node["name"]})
    inputs = {
        "code": code,
        "nodes": convert_nodes_single_file_data(kwargs.pop("nodes", {})),
        "filenames": kwargs.pop("filenames", {}),
        "arguments": kwargs.pop("arguments", []),
        "outputs": kwargs.pop("outputs", []),
        "parser": kwargs.pop("parser", None),
        "metadata": metadata,
    }
    return inputs
115 changes: 21 additions & 94 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from aiida import orm
from aiida.orm import Node, ProcessNode, WorkChainNode
from aiida.orm.utils import load_node
from aiida_workgraph.orm.serializer import serialize_to_aiida_nodes


from aiida.engine.processes.exit_code import ExitCode
Expand Down Expand Up @@ -578,6 +577,7 @@ def update_node_state(self, name: str) -> None:
"GRAPH_BUILDER",
"WORKGRAPH",
"PYTHONJOB",
"SHELLJOB",
]
and node["state"] == "RUNNING"
):
Expand Down Expand Up @@ -685,6 +685,7 @@ def run_nodes(self, names: t.List[str], continue_workgraph: bool = True) -> None
"GRAPH_BUILDER",
"WORKGRAPH",
"PYTHONJOB",
"SHELLJOB",
]:
if len(self._awaitables) > self.ctx.max_number_awaitables:
print(
Expand Down Expand Up @@ -797,27 +798,10 @@ def run_nodes(self, names: t.List[str], continue_workgraph: bool = True) -> None
self.ctx.nodes[name]["state"] = "RUNNING"
self.to_context(**{name: process})
elif node["metadata"]["node_type"].upper() in ["WORKGRAPH"]:
from aiida_workgraph.utils import merge_properties
from .utils import prepare_for_workgraph_node
from aiida_workgraph.utils.analysis import WorkGraphSaver

print("node type: workgraph.")
wgdata = node["executor"]["wgdata"]
wgdata["name"] = name
wgdata["metadata"]["group_outputs"] = self.ctx.nodes[name]["metadata"][
"group_outputs"
]
# update the workgraph data by kwargs
for node_name, data in kwargs.items():
# because kwargs is updated using update_nested_dict_with_special_keys
# which means the data is grouped by the node name
for socket_name, value in data.items():
wgdata["nodes"][node_name]["properties"][socket_name][
"value"
] = value
# merge the properties
merge_properties(wgdata)
metadata = {"call_link_label": name}
inputs = {"wg": wgdata, "metadata": metadata}
inputs, wgdata = prepare_for_workgraph_node(node, kwargs)
process_inited = WorkGraphEngine(inputs=inputs)
process_inited.runner.persister.save_checkpoint(process_inited)
saver = WorkGraphSaver(process_inited.node, wgdata)
Expand All @@ -829,80 +813,9 @@ def run_nodes(self, names: t.List[str], continue_workgraph: bool = True) -> None
self.to_context(**{name: process})
elif node["metadata"]["node_type"].upper() in ["PYTHONJOB"]:
from aiida_workgraph.calculations.python import PythonJob
from aiida_workgraph.utils import get_or_create_code
import os

print("node type: Python.")
# get the names kwargs for the PythonJob, which are the inputs before _wait
function_kwargs = {}
for input in node["inputs"]:
if input["name"] == "_wait":
break
function_kwargs[input["name"]] = kwargs.pop(input["name"], None)
# if the var_kwargs is not None, we need to pop the var_kwargs from the kwargs
# then update the function_kwargs if var_kwargs is not None
if node["metadata"]["var_kwargs"] is not None:
function_kwargs.pop(node["metadata"]["var_kwargs"], None)
if var_kwargs:
function_kwargs.update(var_kwargs.value)
# setup code
code = kwargs.pop("code", None)
computer = kwargs.pop("computer", None)
code_label = kwargs.pop("code_label", None)
code_path = kwargs.pop("code_path", None)
prepend_text = kwargs.pop("prepend_text", None)
upload_files = kwargs.pop("upload_files", {})
new_upload_files = {}
# change the string in the upload files to SingleFileData, or FolderData
for key, source in upload_files.items():
# only alphanumeric and underscores are allowed in the key
# replace all "." with "_dot_"
new_key = key.replace(".", "_dot_")
if isinstance(source, str):
if os.path.isfile(source):
new_upload_files[new_key] = orm.SinglefileData(file=source)
elif os.path.isdir(source):
new_upload_files[new_key] = orm.FolderData(tree=source)
elif isinstance(source, (orm.SinglefileData, orm.FolderData)):
new_upload_files[new_key] = source
else:
raise ValueError(
f"Invalid upload file type: {type(source)}, {source}"
)
#
if code is None:
code = get_or_create_code(
computer=computer if computer else "localhost",
code_label=code_label if code_label else "python3",
code_path=code_path if code_path else None,
prepend_text=prepend_text if prepend_text else None,
)
parent_folder = kwargs.pop("parent_folder", None)
metadata = kwargs.pop("metadata", {})
metadata.update({"call_link_label": name})
# get the source code of the function
function_name = executor.__name__
function_source_code = (
node["executor"]["import_statements"]
+ "\n"
+ node["executor"]["function_source_code"]
)
# outputs
output_name_list = [output["name"] for output in node["outputs"]]
# serialize the kwargs into AiiDA Data
function_kwargs = serialize_to_aiida_nodes(function_kwargs)
# transfer the args to kwargs
inputs = {
"function_source_code": orm.Str(function_source_code),
"function_name": orm.Str(function_name),
"code": code,
"function_kwargs": function_kwargs,
"upload_files": new_upload_files,
"output_name_list": orm.List(output_name_list),
"parent_folder": parent_folder,
"metadata": metadata,
**kwargs,
}
from .utils import prepare_for_pythonjob

inputs = prepare_for_pythonjob(node, kwargs, var_kwargs)
# since aiida 2.5.0, we can pass inputs directly to the submit, no need to use **inputs
process = self.submit(
PythonJob,
Expand All @@ -912,6 +825,20 @@ def run_nodes(self, names: t.List[str], continue_workgraph: bool = True) -> None
node["process"] = process
self.ctx.nodes[name]["state"] = "RUNNING"
self.to_context(**{name: process})
elif node["metadata"]["node_type"].upper() in ["SHELLJOB"]:
from aiida_shell.calculations.shell import ShellJob
from .utils import prepare_for_shelljob

inputs = prepare_for_shelljob(node, kwargs)
# since aiida 2.5.0, we can pass inputs directly to the submit, no need to use **inputs
process = self.submit(
ShellJob,
**inputs,
)
process.label = name
node["process"] = process
self.ctx.nodes[name]["state"] = "RUNNING"
self.to_context(**{name: process})
elif node["metadata"]["node_type"].upper() in ["NORMAL"]:
print("node type: Normal.")
# normal function does not have a process
Expand Down
3 changes: 1 addition & 2 deletions aiida_workgraph/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from node_graph.utils import get_entries
from .builtin import AiiDAGather, AiiDAToCtx, AiiDAFromCtx, AiiDAShell
from .builtin import AiiDAGather, AiiDAToCtx, AiiDAFromCtx
from .test import (
AiiDAInt,
AiiDAFloat,
Expand All @@ -23,7 +23,6 @@
AiiDAGather,
AiiDAToCtx,
AiiDAFromCtx,
AiiDAShell,
AiiDAInt,
AiiDAFloat,
AiiDAString,
Expand Down
Loading

0 comments on commit df22cf7

Please sign in to comment.