Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more test #378

Merged
merged 5 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/gallery/autogen/quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,22 @@ def multiply(x, y):
generate_node_graph(wg.pk)


######################################################################
# One can also set task inputs from an AiiDA process builder directly.
#

from aiida.calculations.arithmetic.add import ArithmeticAddCalculation

builder = ArithmeticAddCalculation.get_builder()
builder.code = code
builder.x = Int(2)
builder.y = Int(3)

wg = WorkGraph("test_set_inputs_from_builder")
add1 = wg.add_task(ArithmeticAddCalculation, name="add1")
add1.set_from_builder(builder)


######################################################################
# Graph builder
# -------------
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"
"workgraph.aiida_bool" = "aiida_workgraph.properties.builtins:PropertyAiiDABool"
"workgraph.aiida_int_vector" = "aiida_workgraph.properties.builtins:PropertyAiiDAIntVector"
"workgraph.aiida_float_vector" = "aiida_workgraph.properties.builtins:PropertyAiiDAFloatVector"
"workgraph.aiida_aiida_dict" = "aiida_workgraph.properties.builtins:PropertyAiiDADict"
"workgraph.aiida_list" = "aiida_workgraph.properties.builtins:PropertyAiiDAList"
"workgraph.aiida_dict" = "aiida_workgraph.properties.builtins:PropertyAiiDADict"
"workgraph.aiida_structuredata" = "aiida_workgraph.properties.builtins:PropertyStructureData"

[project.entry-points."aiida_workgraph.socket"]
Expand All @@ -138,6 +139,8 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"
"workgraph.aiida_bool" = "aiida_workgraph.sockets.builtins:SocketAiiDABool"
"workgraph.aiida_int_vector" = "aiida_workgraph.sockets.builtins:SocketAiiDAIntVector"
"workgraph.aiida_float_vector" = "aiida_workgraph.sockets.builtins:SocketAiiDAFloatVector"
"workgraph.aiida_list" = "aiida_workgraph.sockets.builtins:SocketAiiDAList"
"workgraph.aiida_dict" = "aiida_workgraph.sockets.builtins:SocketAiiDADict"
"workgraph.aiida_structuredata" = "aiida_workgraph.sockets.builtins:SocketStructureData"


Expand Down
2 changes: 2 additions & 0 deletions src/aiida_workgraph/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER

WORKGRAPH_EXTRA_KEY = "_workgraph"


def load_config() -> dict:
"""Load the configuration from the config file."""
Expand Down
2 changes: 2 additions & 0 deletions src/aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
orm.Float: "workgraph.aiida_float",
orm.Str: "workgraph.aiida_string",
orm.Bool: "workgraph.aiida_bool",
orm.List: "workgraph.aiida_list",
orm.Dict: "workgraph.aiida_dict",
orm.StructureData: "workgraph.aiida_structuredata",
}

Expand Down
3 changes: 2 additions & 1 deletion src/aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,9 @@ def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None:
def read_wgdata_from_base(self) -> t.Dict[str, t.Any]:
"""Read workgraph data from base.extras."""
from aiida_workgraph.orm.function_data import PickledLocalFunction
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY

wgdata = self.node.base.extras.get("_workgraph")
wgdata = self.node.base.extras.get(WORKGRAPH_EXTRA_KEY)
for name, task in wgdata["tasks"].items():
wgdata["tasks"][name] = deserialize_unsafe(task)
for _, input in wgdata["tasks"][name]["inputs"].items():
Expand Down
12 changes: 12 additions & 0 deletions src/aiida_workgraph/properties/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ def validate(self, value: any) -> None:
)


class PropertyAiiDAList(TaskProperty):
"""A new class for List type."""

identifier: str = "workgraph.aiida_list"
allowed_types = (list, orm.List, str, type(None))

def set_value(self, value: Union[list, orm.List, str] = None) -> None:
if isinstance(value, (list)):
value = orm.List(list=value)
super().set_value(value)


class PropertyAiiDADict(TaskProperty):
"""A new class for Dict type."""

Expand Down
14 changes: 14 additions & 0 deletions src/aiida_workgraph/sockets/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ class SocketAiiDABool(TaskSocket):
property_identifier: str = "workgraph.aiida_bool"


class SocketAiiDAList(TaskSocket):
"""AiiDAList socket."""

identifier: str = "workgraph.aiida_list"
property_identifier: str = "workgraph.aiida_list"


class SocketAiiDADict(TaskSocket):
"""AiiDADict socket."""

identifier: str = "workgraph.aiida_dict"
property_identifier: str = "workgraph.aiida_dict"


class SocketAiiDAIntVector(TaskSocket):
"""Socket with a AiiDAIntVector property."""

Expand Down
17 changes: 14 additions & 3 deletions src/aiida_workgraph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,25 @@ def set_context(self, context: Dict[str, Any]) -> None:
raise ValueError(msg)
self.context_mapping.update(context)

def set_from_builder(self, builder: Any) -> None:
"""Set the task inputs from a AiiDA ProcessBuilder."""
from aiida_workgraph.utils import get_dict_from_builder

data = get_dict_from_builder(builder)
self.set(data)

def set_from_protocol(self, *args: Any, **kwargs: Any) -> None:
"""Set the task inputs from protocol data."""
from aiida_workgraph.utils import get_executor, get_dict_from_builder
from aiida_workgraph.utils import get_executor

executor = get_executor(self.get_executor())[0]
# check if the executor has the get_builder_from_protocol method
if not hasattr(executor, "get_builder_from_protocol"):
raise AttributeError(
f"Executor {executor.__name__} does not have the get_builder_from_protocol method."
)
builder = executor.get_builder_from_protocol(*args, **kwargs)
data = get_dict_from_builder(builder)
self.set(data)
self.set_from_builder(builder)

@classmethod
def new(
Expand Down
3 changes: 2 additions & 1 deletion src/aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,11 @@ def get_workgraph_data(process: Union[int, orm.Node]) -> Optional[Dict[str, Any]
"""Get the workgraph data from the process node."""
from aiida.orm.utils.serialize import deserialize_unsafe
from aiida.orm import load_node
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY

if isinstance(process, int):
process = load_node(process)
wgdata = process.base.extras.get("_workgraph", None)
wgdata = process.base.extras.get(WORKGRAPH_EXTRA_KEY, None)
if wgdata is None:
return
for name, task in wgdata["tasks"].items():
Expand Down
7 changes: 4 additions & 3 deletions src/aiida_workgraph/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# import datetime
from aiida.orm import ProcessNode
from aiida.orm.utils.serialize import serialize, deserialize_unsafe
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY


class WorkGraphSaver:
Expand Down Expand Up @@ -223,7 +224,7 @@ def insert_workgraph_to_db(self) -> None:
# nodes is a copy of tasks, so we need to pop it out
self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"])
self.wgdata["context"] = serialize(self.wgdata["context"])
self.process.base.extras.set("_workgraph", self.wgdata)
self.process.base.extras.set(WORKGRAPH_EXTRA_KEY, self.wgdata)

def save_task_states(self) -> Dict:
"""Get task states."""
Expand Down Expand Up @@ -277,7 +278,7 @@ def get_wgdata_from_db(
) -> Optional[Dict]:

process = self.process if process is None else process
wgdata = process.base.extras.get("_workgraph", None)
wgdata = process.base.extras.get(WORKGRAPH_EXTRA_KEY, None)
if wgdata is None:
print("No workgraph data found in the process node.")
return
Expand Down Expand Up @@ -318,7 +319,7 @@ def exist_in_db(self) -> bool:
Returns:
bool: _description_
"""
if self.process.base.extras.get("_workgraph", None) is not None:
if self.process.base.extras.get(WORKGRAPH_EXTRA_KEY, None) is not None:
return True
return False

Expand Down
3 changes: 1 addition & 2 deletions src/aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,7 @@ def load(cls, pk: int) -> Optional["WorkGraph"]:
process = aiida.orm.load_node(pk)
wgdata = get_workgraph_data(process)
if wgdata is None:
print("No workgraph data found in the process node.")
return
raise ValueError(f"WorkGraph data not found for process {pk}.")
wg = cls.from_dict(wgdata)
wg.process = process
wg.update()
Expand Down
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def add_code(fixture_localhost):
from aiida.orm import InstalledCode

code = InstalledCode(
label="add", computer=fixture_localhost, filepath_executable="/bin/bash"
label="add",
computer=fixture_localhost,
filepath_executable="/bin/bash",
default_calc_job_plugin="arithmetic.add",
)
code.store()
return code
Expand Down
6 changes: 6 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def test_load_config():
from aiida_workgraph.config import load_config

config = load_config()
assert isinstance(config, dict)
assert config == {}
1 change: 1 addition & 0 deletions tests/test_error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def handle_negative_sum(task: Task):
}
},
)
assert len(wg.error_handlers) == 1
wg.submit(
inputs={
"add1": {"code": add_code, "x": orm.Int(1), "y": orm.Int(-2)},
Expand Down
12 changes: 9 additions & 3 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
(orm.Str, "abc", "workgraph.aiida_string"),
(orm.Bool, True, "workgraph.aiida_bool"),
(orm.Bool, "{{variable}}", "workgraph.aiida_bool"),
(orm.List, [1, 2, 3], "workgraph.aiida_list"),
(orm.Dict, {"a": 1}, "workgraph.aiida_dict"),
),
)
def test_type_mapping(data_type, data, identifier) -> None:
Expand Down Expand Up @@ -46,10 +48,14 @@ def test_vector_socket() -> None:
"vector2d",
property_data={"size": 2, "default": [1, 2]},
)
try:
assert t.inputs["vector2d"].property.get_metadata() == {
"size": 2,
"default": [1, 2],
}
with pytest.raises(ValueError, match="Invalid size: Expected 2, got 3 instead."):
t.inputs["vector2d"].value = [1, 2, 3]
except Exception as e:
assert "Invalid size: Expected 2, got 3 instead." in str(e)
with pytest.raises(ValueError, match="Invalid item type: Expected "):
t.inputs["vector2d"].value = [1.1, 2.2]


def test_aiida_data_socket() -> None:
Expand Down
21 changes: 21 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,24 @@ def test_set_inputs(decorated_add: Callable) -> None:
]
is False
)


def test_set_inputs_from_builder(add_code) -> None:
"""Test setting inputs of a task from a builder function."""
from aiida.calculations.arithmetic.add import ArithmeticAddCalculation

wg = WorkGraph(name="test_set_inputs_from_builder")
add1 = wg.add_task(ArithmeticAddCalculation, "add1")
# create the builder
builder = add_code.get_builder()
builder.x = 1
builder.y = 2
add1.set_from_builder(builder)
assert add1.inputs["x"].value == 1
assert add1.inputs["y"].value == 2
assert add1.inputs["code"].value == add_code
with pytest.raises(
AttributeError,
match=f"Executor {ArithmeticAddCalculation.__name__} does not have the get_builder_from_protocol method.",
):
add1.set_from_protocol(code=add_code, protocol="fast")
33 changes: 32 additions & 1 deletion tests/test_workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,29 @@ def test_add_task():
assert len(wg.links) == 1


def test_show_state(wg_calcfunction):
from io import StringIO
import sys

# Redirect stdout to capture prints
captured_output = StringIO()
sys.stdout = captured_output
# Call the method
wg_calcfunction.name = "test_show_state"
wg_calcfunction.show()
# Reset stdout
sys.stdout = sys.__stdout__
# Check the output
output = captured_output.getvalue()
assert "WorkGraph: test_show_state, PK: None, State: CREATED" in output
assert "sumdiff1" in output
assert "PLANNED" in output


def test_save_load(wg_calcfunction):
"""Save the workgraph"""
from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY

wg = wg_calcfunction
wg.name = "test_save_load"
wg.save()
Expand All @@ -38,6 +59,12 @@ def test_save_load(wg_calcfunction):
assert wg.process.label == "test_save_load"
wg2 = WorkGraph.load(wg.process.pk)
assert len(wg.tasks) == len(wg2.tasks)
# remove the extra
wg.process.base.extras.delete(WORKGRAPH_EXTRA_KEY)
with pytest.raises(
ValueError, match=f"WorkGraph data not found for process {wg.process.pk}."
):
WorkGraph.load(wg.process.pk)


def test_organize_nested_inputs():
Expand Down Expand Up @@ -86,7 +113,7 @@ def test_reset_message(wg_calcjob):
assert "Action: reset. {'add2'}" in report


def test_restart(wg_calcfunction):
def test_restart_and_reset(wg_calcfunction):
"""Restart from a finished workgraph.
Load the workgraph, modify the task, and restart the workgraph.
Only the modified node and its child tasks will be rerun."""
Expand All @@ -109,6 +136,10 @@ def test_restart(wg_calcfunction):
assert wg1.tasks["sumdiff2"].node.pk != wg.tasks["sumdiff2"].pk
assert wg1.tasks["sumdiff3"].node.pk != wg.tasks["sumdiff3"].pk
assert wg1.tasks["sumdiff3"].node.outputs.sum == 19
wg1.reset()
assert wg1.process is None
assert wg1.tasks["sumdiff3"].process is None
assert wg1.tasks["sumdiff3"].state == "PLANNED"


def test_extend_workgraph(decorated_add_multiply_group):
Expand Down
4 changes: 4 additions & 0 deletions tests/widget/test_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def test_workgraph_widget(wg_calcfunction):
# to_html
data = wg.to_html()
assert isinstance(data, IFrame)
# check _repr_mimebundle_ is working
data = wg._repr_mimebundle_()


def test_workgraph_task(wg_calcfunction):
Expand All @@ -26,3 +28,5 @@ def test_workgraph_task(wg_calcfunction):
# to html
data = wg.tasks["sumdiff2"].to_html()
assert isinstance(data, IFrame)
# check _repr_mimebundle_ is working
data = wg.tasks["sumdiff2"]._repr_mimebundle_()
Loading