diff --git a/docs/gallery/autogen/quick_start.py b/docs/gallery/autogen/quick_start.py index ad41ad1f..d7f0b7a8 100644 --- a/docs/gallery/autogen/quick_start.py +++ b/docs/gallery/autogen/quick_start.py @@ -357,6 +357,22 @@ def multiply(x, y): generate_node_graph(wg.pk) +###################################################################### +# One can also set task inputs from an AiiDA process builder directly. +# + +from aiida.calculations.arithmetic.add import ArithmeticAddCalculation + +builder = ArithmeticAddCalculation.get_builder() +builder.code = code +builder.x = Int(2) +builder.y = Int(3) + +wg = WorkGraph("test_set_inputs_from_builder") +add1 = wg.add_task(ArithmeticAddCalculation, name="add1") +add1.set_from_builder(builder) + + ###################################################################### # Graph builder # ------------- diff --git a/pyproject.toml b/pyproject.toml index 0ef55569..5659d500 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,8 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" "workgraph.aiida_bool" = "aiida_workgraph.properties.builtins:PropertyAiiDABool" "workgraph.aiida_int_vector" = "aiida_workgraph.properties.builtins:PropertyAiiDAIntVector" "workgraph.aiida_float_vector" = "aiida_workgraph.properties.builtins:PropertyAiiDAFloatVector" -"workgraph.aiida_aiida_dict" = "aiida_workgraph.properties.builtins:PropertyAiiDADict" +"workgraph.aiida_list" = "aiida_workgraph.properties.builtins:PropertyAiiDAList" +"workgraph.aiida_dict" = "aiida_workgraph.properties.builtins:PropertyAiiDADict" "workgraph.aiida_structuredata" = "aiida_workgraph.properties.builtins:PropertyStructureData" [project.entry-points."aiida_workgraph.socket"] @@ -138,6 +139,8 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" "workgraph.aiida_bool" = "aiida_workgraph.sockets.builtins:SocketAiiDABool" "workgraph.aiida_int_vector" = "aiida_workgraph.sockets.builtins:SocketAiiDAIntVector" "workgraph.aiida_float_vector" = "aiida_workgraph.sockets.builtins:SocketAiiDAFloatVector" +"workgraph.aiida_list" = "aiida_workgraph.sockets.builtins:SocketAiiDAList" +"workgraph.aiida_dict" = "aiida_workgraph.sockets.builtins:SocketAiiDADict" "workgraph.aiida_structuredata" = "aiida_workgraph.sockets.builtins:SocketStructureData" diff --git a/src/aiida_workgraph/config.py b/src/aiida_workgraph/config.py index b97fb8cd..519747ec 100644 --- a/src/aiida_workgraph/config.py +++ b/src/aiida_workgraph/config.py @@ -1,6 +1,8 @@ import json from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER +WORKGRAPH_EXTRA_KEY = "_workgraph" + def load_config() -> dict: """Load the configuration from the config file.""" diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index c5551b48..2ce1b753 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -29,6 +29,8 @@ orm.Float: "workgraph.aiida_float", orm.Str: "workgraph.aiida_string", orm.Bool: "workgraph.aiida_bool", + orm.List: "workgraph.aiida_list", + orm.Dict: "workgraph.aiida_dict", orm.StructureData: "workgraph.aiida_structuredata", } diff --git a/src/aiida_workgraph/engine/workgraph.py b/src/aiida_workgraph/engine/workgraph.py index 5b79c27a..6cfe0635 100644 --- a/src/aiida_workgraph/engine/workgraph.py +++ b/src/aiida_workgraph/engine/workgraph.py @@ -308,8 +308,9 @@ def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None: def read_wgdata_from_base(self) -> t.Dict[str, t.Any]: """Read workgraph data from base.extras.""" from aiida_workgraph.orm.function_data import PickledLocalFunction + from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY - wgdata = self.node.base.extras.get("_workgraph") + wgdata = self.node.base.extras.get(WORKGRAPH_EXTRA_KEY) for name, task in wgdata["tasks"].items(): wgdata["tasks"][name] = deserialize_unsafe(task) for _, input in wgdata["tasks"][name]["inputs"].items(): diff --git a/src/aiida_workgraph/properties/builtins.py b/src/aiida_workgraph/properties/builtins.py index 84ef1fb1..04fc87ef 100644 --- a/src/aiida_workgraph/properties/builtins.py +++ b/src/aiida_workgraph/properties/builtins.py @@ -102,6 +102,18 @@ def validate(self, value: any) -> None: ) +class PropertyAiiDAList(TaskProperty): + """A new class for List type.""" + + identifier: str = "workgraph.aiida_list" + allowed_types = (list, orm.List, str, type(None)) + + def set_value(self, value: Union[list, orm.List, str] = None) -> None: + if isinstance(value, (list)): + value = orm.List(list=value) + super().set_value(value) + + class PropertyAiiDADict(TaskProperty): """A new class for Dict type.""" diff --git a/src/aiida_workgraph/sockets/builtins.py b/src/aiida_workgraph/sockets/builtins.py index 9e946982..3d31e5a2 100644 --- a/src/aiida_workgraph/sockets/builtins.py +++ b/src/aiida_workgraph/sockets/builtins.py @@ -71,6 +71,20 @@ class SocketAiiDABool(TaskSocket): property_identifier: str = "workgraph.aiida_bool" +class SocketAiiDAList(TaskSocket): + """AiiDAList socket.""" + + identifier: str = "workgraph.aiida_list" + property_identifier: str = "workgraph.aiida_list" + + +class SocketAiiDADict(TaskSocket): + """AiiDADict socket.""" + + identifier: str = "workgraph.aiida_dict" + property_identifier: str = "workgraph.aiida_dict" + + class SocketAiiDAIntVector(TaskSocket): """Socket with a AiiDAIntVector property.""" diff --git a/src/aiida_workgraph/task.py b/src/aiida_workgraph/task.py index 69ef8885..085c778e 100644 --- a/src/aiida_workgraph/task.py +++ b/src/aiida_workgraph/task.py @@ -81,14 +81,25 @@ def set_context(self, context: Dict[str, Any]) -> None: raise ValueError(msg) self.context_mapping.update(context) + def set_from_builder(self, builder: Any) -> None: + """Set the task inputs from a AiiDA ProcessBuilder.""" + from aiida_workgraph.utils import get_dict_from_builder + + data = get_dict_from_builder(builder) + self.set(data) + def set_from_protocol(self, *args: Any, **kwargs: Any) -> None: """Set the task inputs from protocol data.""" - from aiida_workgraph.utils import get_executor, get_dict_from_builder + from aiida_workgraph.utils import get_executor executor = get_executor(self.get_executor())[0] + # check if the executor has the get_builder_from_protocol method + if not hasattr(executor, "get_builder_from_protocol"): + raise AttributeError( + f"Executor {executor.__name__} does not have the get_builder_from_protocol method." + ) builder = executor.get_builder_from_protocol(*args, **kwargs) - data = get_dict_from_builder(builder) - self.set(data) + self.set_from_builder(builder) @classmethod def new( diff --git a/src/aiida_workgraph/utils/__init__.py b/src/aiida_workgraph/utils/__init__.py index 55d2fb2e..9b6e91d6 100644 --- a/src/aiida_workgraph/utils/__init__.py +++ b/src/aiida_workgraph/utils/__init__.py @@ -322,10 +322,11 @@ def get_workgraph_data(process: Union[int, orm.Node]) -> Optional[Dict[str, Any] """Get the workgraph data from the process node.""" from aiida.orm.utils.serialize import deserialize_unsafe from aiida.orm import load_node + from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY if isinstance(process, int): process = load_node(process) - wgdata = process.base.extras.get("_workgraph", None) + wgdata = process.base.extras.get(WORKGRAPH_EXTRA_KEY, None) if wgdata is None: return for name, task in wgdata["tasks"].items(): diff --git a/src/aiida_workgraph/utils/analysis.py b/src/aiida_workgraph/utils/analysis.py index 477a2f27..f9d43716 100644 --- a/src/aiida_workgraph/utils/analysis.py +++ b/src/aiida_workgraph/utils/analysis.py @@ -3,6 +3,7 @@ # import datetime from aiida.orm import ProcessNode from aiida.orm.utils.serialize import serialize, deserialize_unsafe +from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY class WorkGraphSaver: @@ -223,7 +224,7 @@ def insert_workgraph_to_db(self) -> None: # nodes is a copy of tasks, so we need to pop it out self.wgdata["error_handlers"] = serialize(self.wgdata["error_handlers"]) self.wgdata["context"] = serialize(self.wgdata["context"]) - self.process.base.extras.set("_workgraph", self.wgdata) + self.process.base.extras.set(WORKGRAPH_EXTRA_KEY, self.wgdata) def save_task_states(self) -> Dict: """Get task states.""" @@ -277,7 +278,7 @@ def get_wgdata_from_db( ) -> Optional[Dict]: process = self.process if process is None else process - wgdata = process.base.extras.get("_workgraph", None) + wgdata = process.base.extras.get(WORKGRAPH_EXTRA_KEY, None) if wgdata is None: print("No workgraph data found in the process node.") return @@ -318,7 +319,7 @@ def exist_in_db(self) -> bool: Returns: bool: _description_ """ - if self.process.base.extras.get("_workgraph", None) is not None: + if self.process.base.extras.get(WORKGRAPH_EXTRA_KEY, None) is not None: return True return False diff --git a/src/aiida_workgraph/workgraph.py b/src/aiida_workgraph/workgraph.py index aad7b10e..dfd371bf 100644 --- a/src/aiida_workgraph/workgraph.py +++ b/src/aiida_workgraph/workgraph.py @@ -362,8 +362,7 @@ def load(cls, pk: int) -> Optional["WorkGraph"]: process = aiida.orm.load_node(pk) wgdata = get_workgraph_data(process) if wgdata is None: - print("No workgraph data found in the process node.") - return + raise ValueError(f"WorkGraph data not found for process {pk}.") wg = cls.from_dict(wgdata) wg.process = process wg.update() diff --git a/tests/conftest.py b/tests/conftest.py index 31b2b776..7facdb73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,7 +31,10 @@ def add_code(fixture_localhost): from aiida.orm import InstalledCode code = InstalledCode( - label="add", computer=fixture_localhost, filepath_executable="/bin/bash" + label="add", + computer=fixture_localhost, + filepath_executable="/bin/bash", + default_calc_job_plugin="arithmetic.add", ) code.store() return code diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 00000000..17764b89 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,6 @@ +def test_load_config(): + from aiida_workgraph.config import load_config + + config = load_config() + assert isinstance(config, dict) + assert config == {} diff --git a/tests/test_error_handler.py b/tests/test_error_handler.py index a7e7d89d..d1eff81f 100644 --- a/tests/test_error_handler.py +++ b/tests/test_error_handler.py @@ -38,6 +38,7 @@ def handle_negative_sum(task: Task): } }, ) + assert len(wg.error_handlers) == 1 wg.submit( inputs={ "add1": {"code": add_code, "x": orm.Int(1), "y": orm.Int(-2)}, diff --git a/tests/test_socket.py b/tests/test_socket.py index 264ecc67..835ec99b 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -19,6 +19,8 @@ (orm.Str, "abc", "workgraph.aiida_string"), (orm.Bool, True, "workgraph.aiida_bool"), (orm.Bool, "{{variable}}", "workgraph.aiida_bool"), + (orm.List, [1, 2, 3], "workgraph.aiida_list"), + (orm.Dict, {"a": 1}, "workgraph.aiida_dict"), ), ) def test_type_mapping(data_type, data, identifier) -> None: @@ -46,10 +48,14 @@ def test_vector_socket() -> None: "vector2d", property_data={"size": 2, "default": [1, 2]}, ) - try: + assert t.inputs["vector2d"].property.get_metadata() == { + "size": 2, + "default": [1, 2], + } + with pytest.raises(ValueError, match="Invalid size: Expected 2, got 3 instead."): t.inputs["vector2d"].value = [1, 2, 3] - except Exception as e: - assert "Invalid size: Expected 2, got 3 instead." in str(e) + with pytest.raises(ValueError, match="Invalid item type: Expected "): + t.inputs["vector2d"].value = [1.1, 2.2] def test_aiida_data_socket() -> None: diff --git a/tests/test_tasks.py b/tests/test_tasks.py index eafc7145..e8bd8051 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -87,3 +87,24 @@ def test_set_inputs(decorated_add: Callable) -> None: ] is False ) + + +def test_set_inputs_from_builder(add_code) -> None: + """Test setting inputs of a task from a builder function.""" + from aiida.calculations.arithmetic.add import ArithmeticAddCalculation + + wg = WorkGraph(name="test_set_inputs_from_builder") + add1 = wg.add_task(ArithmeticAddCalculation, "add1") + # create the builder + builder = add_code.get_builder() + builder.x = 1 + builder.y = 2 + add1.set_from_builder(builder) + assert add1.inputs["x"].value == 1 + assert add1.inputs["y"].value == 2 + assert add1.inputs["code"].value == add_code + with pytest.raises( + AttributeError, + match=f"Executor {ArithmeticAddCalculation.__name__} does not have the get_builder_from_protocol method.", + ): + add1.set_from_protocol(code=add_code, protocol="fast") diff --git a/tests/test_workgraph.py b/tests/test_workgraph.py index 1fa1366b..f65a7329 100644 --- a/tests/test_workgraph.py +++ b/tests/test_workgraph.py @@ -28,8 +28,29 @@ def test_add_task(): assert len(wg.links) == 1 +def test_show_state(wg_calcfunction): + from io import StringIO + import sys + + # Redirect stdout to capture prints + captured_output = StringIO() + sys.stdout = captured_output + # Call the method + wg_calcfunction.name = "test_show_state" + wg_calcfunction.show() + # Reset stdout + sys.stdout = sys.__stdout__ + # Check the output + output = captured_output.getvalue() + assert "WorkGraph: test_show_state, PK: None, State: CREATED" in output + assert "sumdiff1" in output + assert "PLANNED" in output + + def test_save_load(wg_calcfunction): """Save the workgraph""" + from aiida_workgraph.config import WORKGRAPH_EXTRA_KEY + wg = wg_calcfunction wg.name = "test_save_load" wg.save() @@ -38,6 +59,12 @@ def test_save_load(wg_calcfunction): assert wg.process.label == "test_save_load" wg2 = WorkGraph.load(wg.process.pk) assert len(wg.tasks) == len(wg2.tasks) + # remove the extra + wg.process.base.extras.delete(WORKGRAPH_EXTRA_KEY) + with pytest.raises( + ValueError, match=f"WorkGraph data not found for process {wg.process.pk}." + ): + WorkGraph.load(wg.process.pk) def test_organize_nested_inputs(): @@ -86,7 +113,7 @@ def test_reset_message(wg_calcjob): assert "Action: reset. {'add2'}" in report -def test_restart(wg_calcfunction): +def test_restart_and_reset(wg_calcfunction): """Restart from a finished workgraph. Load the workgraph, modify the task, and restart the workgraph. Only the modified node and its child tasks will be rerun.""" @@ -109,6 +136,10 @@ def test_restart(wg_calcfunction): assert wg1.tasks["sumdiff2"].node.pk != wg.tasks["sumdiff2"].pk assert wg1.tasks["sumdiff3"].node.pk != wg.tasks["sumdiff3"].pk assert wg1.tasks["sumdiff3"].node.outputs.sum == 19 + wg1.reset() + assert wg1.process is None + assert wg1.tasks["sumdiff3"].process is None + assert wg1.tasks["sumdiff3"].state == "PLANNED" def test_extend_workgraph(decorated_add_multiply_group): diff --git a/tests/widget/test_widget.py b/tests/widget/test_widget.py index 3c934a5b..697d2dc5 100644 --- a/tests/widget/test_widget.py +++ b/tests/widget/test_widget.py @@ -14,6 +14,8 @@ def test_workgraph_widget(wg_calcfunction): # to_html data = wg.to_html() assert isinstance(data, IFrame) + # check _repr_mimebundle_ is working + data = wg._repr_mimebundle_() def test_workgraph_task(wg_calcfunction): @@ -26,3 +28,5 @@ def test_workgraph_task(wg_calcfunction): # to html data = wg.tasks["sumdiff2"].to_html() assert isinstance(data, IFrame) + # check _repr_mimebundle_ is working + data = wg.tasks["sumdiff2"]._repr_mimebundle_()