Skip to content

Commit

Permalink
bump node-graph to 0.1.4
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 4, 2024
1 parent b182af0 commit 8f35b5d
Show file tree
Hide file tree
Showing 13 changed files with 174 additions and 143 deletions.
3 changes: 0 additions & 3 deletions docs/gallery/concept/autogen/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,12 @@ class MyAdd(Task):
"module": "aiida_workgraph.executors.test",
"name": "add",
}
kwargs = ["x", "y"]

def create_sockets(self):
self.inputs.clear()
self.outputs.clear()
inp = self.inputs.new("workgraph.Any", "x")
inp.add_property("workgraph.Any", "x", default=0.0)
inp = self.inputs.new("workgraph.Any", "y")
inp.add_property("workgraph.Any", "y", default=0.0)
self.outputs.new("workgraph.Any", "sum")


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"numpy~=1.21",
"scipy",
"ase",
"node-graph==0.1.3",
"node-graph==0.1.4",
"node-graph-widget",
"aiida-core>=2.3",
"cloudpickle",
Expand Down
115 changes: 58 additions & 57 deletions src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ def create_task(tdata):
def add_input_recursive(
inputs: List[List[Union[str, Dict[str, Any]]]],
port: PortNamespace,
args: List,
kwargs: List,
prefix: Optional[str] = None,
required: bool = True,
) -> List[List[Union[str, Dict[str, Any]]]]:
Expand All @@ -68,17 +66,13 @@ def add_input_recursive(
{
"identifier": "workgraph.namespace",
"name": port_name,
"arg_type": "kwargs",
"metadata": {"required": required, "dynamic": port.dynamic},
"property": {"identifier": "workgraph.any", "default": None},
}
)
if required:
args.append(port_name)
else:
kwargs.append(port_name)
for value in port.values():
add_input_recursive(
inputs, value, args, kwargs, prefix=port_name, required=required
)
add_input_recursive(inputs, value, prefix=port_name, required=required)
else:
if port_name not in input_names:
# port.valid_type can be a single type or a tuple of types,
Expand All @@ -89,11 +83,14 @@ def add_input_recursive(
socket_type = type_mapping.get(port.valid_type[0], "workgraph.any")
else:
socket_type = type_mapping.get(port.valid_type, "workgraph.any")
inputs.append({"identifier": socket_type, "name": port_name})
if required:
args.append(port_name)
else:
kwargs.append(port_name)
inputs.append(
{
"identifier": socket_type,
"name": port_name,
"arg_type": "kwargs",
"metadata": {"required": required},
}
)
return inputs


Expand All @@ -115,12 +112,24 @@ def add_output_recursive(
# so if you change the value of one port, the value of all the ports of other tasks will be changed
# consider to use None as default value
if port_name not in output_names:
outputs.append({"identifier": "workgraph.namespace", "name": port_name})
outputs.append(
{
"identifier": "workgraph.namespace",
"name": port_name,
"metadata": {"required": required},
}
)
for value in port.values():
add_output_recursive(outputs, value, prefix=port_name, required=required)
else:
if port_name not in output_names:
outputs.append({"identifier": "workgraph.any", "name": port_name})
outputs.append(
{
"identifier": "workgraph.any",
"name": port_name,
"metadata": {"required": required},
}
)
return outputs


Expand Down Expand Up @@ -221,12 +230,9 @@ def build_task_from_AiiDA(
outputs = [] if outputs is None else outputs
executor = tdata["executor"]
spec = executor.spec()
args = []
kwargs = []
user_defined_input_names = [input["name"] for input in inputs]
for _key, port in spec.inputs.ports.items():
add_input_recursive(inputs, port, args, kwargs, required=port.required)
for _key, port in spec.outputs.ports.items():
for _, port in spec.inputs.ports.items():
add_input_recursive(inputs, port, required=port.required)
for _, port in spec.outputs.ports.items():
add_output_recursive(outputs, port, required=port.required)
# Only check this for calcfunction and workfunction
if inspect.isfunction(executor) and spec.inputs.dynamic:
Expand All @@ -237,27 +243,18 @@ def build_task_from_AiiDA(
executor.process_class._var_keyword
or executor.process_class._var_positional
)
tdata["var_kwargs"] = name
# if user already defined the var_args in the inputs, skip it
if name not in [input["name"] for input in inputs]:
inputs.append(
{
"identifier": "workgraph.any",
"name": name,
"arg_type": "var_kwargs",
"metadata": {"dynamic": True},
"property": {"identifier": "workgraph.any", "default": {}},
}
)
# When the input is dyanmic, if user defines some input names does not included in the args and kwargs,
# which means the user define the input names manually, we must add them to the kwargs
for key in user_defined_input_names:
if key not in args and key not in kwargs:
if key == name:
continue
kwargs.append(key)

# TODO In order to reload the WorkGraph from process, "is_pickle" should be True
# so I pickled the function here, but this is not necessary
# we need to update the node_graph to support the path and name of the function

tdata["identifier"] = tdata.pop("identifier", tdata["executor"].__name__)
tdata["executor"] = build_callable(executor)
tdata["executor"]["type"] = tdata["metadata"]["task_type"]
Expand All @@ -270,10 +267,15 @@ def build_task_from_AiiDA(
# add built-in sockets
outputs.append({"identifier": "workgraph.any", "name": "_outputs"})
outputs.append({"identifier": "workgraph.any", "name": "_wait"})
inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6})
inputs.append(
{
"identifier": "workgraph.any",
"name": "_wait",
"link_limit": 1e6,
"arg_type": "none",
}
)
tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"}
tdata["args"] = args
tdata["kwargs"] = kwargs
tdata["inputs"] = inputs
tdata["outputs"] = outputs
task = create_task(tdata)
Expand Down Expand Up @@ -323,11 +325,6 @@ def build_pythonjob_task(func: Callable) -> Task:
tdata["outputs"][output["name"]] = output
# change "copy_files" link_limit to 1e6
tdata["inputs"]["copy_files"]["link_limit"] = 1e6
# append the kwargs of the PythonJob task to the function task
kwargs = tdata["kwargs"]
kwargs.extend(["computer", "command_info"])
kwargs.extend(tdata_py["kwargs"])
tdata["kwargs"] = kwargs
tdata["metadata"]["task_type"] = "PYTHONJOB"
tdata["identifier"] = "workgraph.pythonjob"
tdata["metadata"]["node_class"] = {
Expand Down Expand Up @@ -369,7 +366,6 @@ def build_shelljob_task(
if input["name"] not in tdata["inputs"]:
input["list_index"] = len(tdata["inputs"]) + 1
tdata["inputs"][input["name"]] = input
tdata["kwargs"].append(input["name"])
# Extend the outputs
for output in [
{"identifier": "workgraph.any", "name": "stdout"},
Expand Down Expand Up @@ -399,7 +395,6 @@ def build_shelljob_task(
]:
input["list_index"] = len(tdata["inputs"]) + 1
tdata["inputs"][input["name"]] = input
tdata["kwargs"].extend(["command", "resolve_command"])
tdata["metadata"]["task_type"] = "SHELLJOB"
task = create_task(tdata)
task.is_aiida_component = True
Expand Down Expand Up @@ -448,13 +443,18 @@ def build_task_from_workgraph(wg: any) -> Task:
"from": f"{task.name}.{socket.name}",
}
)
kwargs = [input["name"] for input in inputs]
# add built-in sockets
outputs.append({"identifier": "workgraph.any", "name": "_outputs"})
outputs.append({"identifier": "workgraph.any", "name": "_wait"})
inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6})
inputs.append(
{
"identifier": "workgraph.any",
"name": "_wait",
"link_limit": 1e6,
"arg_type": "none",
}
)
tdata["metadata"]["node_class"] = {"module": "aiida_workgraph.task", "name": "Task"}
tdata["kwargs"] = kwargs
tdata["inputs"] = inputs
tdata["outputs"] = outputs
tdata["identifier"] = wg.name
Expand Down Expand Up @@ -518,20 +518,23 @@ def generate_tdata(
"""Generate task data for creating a task."""
from node_graph.decorator import generate_input_sockets

args, kwargs, var_args, var_kwargs, _inputs = generate_input_sockets(
_inputs = generate_input_sockets(
func, inputs, properties, type_mapping=type_mapping
)
task_outputs = outputs
# add built-in sockets
_inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6})
_inputs.append(
{
"identifier": "workgraph.any",
"name": "_wait",
"link_limit": 1e6,
"arg_type": "none",
}
)
task_outputs.append({"identifier": "workgraph.any", "name": "_wait"})
task_outputs.append({"identifier": "workgraph.any", "name": "_outputs"})
tdata = {
"identifier": identifier,
"args": args,
"kwargs": kwargs,
"var_args": var_args,
"var_kwargs": var_kwargs,
"metadata": {
"task_type": task_type,
"catalog": catalog,
Expand All @@ -552,7 +555,7 @@ def generate_tdata(
class TaskDecoratorCollection:
"""Collection of task decorators."""

# decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor
# decorator with arguments indentifier, properties, inputs, outputs, executor
@staticmethod
@nonfunctional_usage
def decorator_task(
Expand All @@ -569,8 +572,6 @@ def decorator_task(
Attributes:
indentifier (str): task identifier
catalog (str): task catalog
args (list): task args
kwargs (dict): task kwargs
properties (list): task properties
inputs (list): task inputs
outputs (list): task outputs
Expand Down Expand Up @@ -610,7 +611,7 @@ def decorator(func):

return decorator

# decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor
# decorator with arguments indentifier, properties, inputs, outputs, executor
@staticmethod
@nonfunctional_usage
def decorator_graph_builder(
Expand Down
2 changes: 1 addition & 1 deletion src/aiida_workgraph/executors/builtins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def select(condition, true, false):
def select(condition, true=None, false=None):
"""Select the data based on the condition."""
if condition:
return true
Expand Down
4 changes: 2 additions & 2 deletions src/aiida_workgraph/executors/monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ async def monitor(function, interval, timeout, *args, **kwargs):
await asyncio.sleep(interval)


def file_monitor(filename: str):
def file_monitor(filepath: str):
"""Check if the file exists."""
import os

return os.path.exists(filename)
return os.path.exists(filepath)


def time_monitor(time: str):
Expand Down
Loading

0 comments on commit 8f35b5d

Please sign in to comment.