diff --git a/aiida_workgraph/sockets/builtins.py b/aiida_workgraph/sockets/builtins.py index 692fb75e..9e946982 100644 --- a/aiida_workgraph/sockets/builtins.py +++ b/aiida_workgraph/sockets/builtins.py @@ -1,4 +1,3 @@ -from typing import Optional, Any from aiida_workgraph.socket import TaskSocket @@ -6,197 +5,88 @@ class SocketAny(TaskSocket): """Any socket.""" identifier: str = "workgraph.any" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.any", name, **kwargs) + property_identifier: str = "workgraph.any" class SocketNamespace(TaskSocket): """Namespace socket.""" identifier: str = "workgraph.namespace" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - # Set the default value to an empty dictionary - kwargs.setdefault("default", {}) - self.add_property("workgraph.any", name, **kwargs) + property_identifier: str = "workgraph.any" class SocketFloat(TaskSocket): """Float socket.""" identifier: str = "workgraph.float" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.float", name, **kwargs) + property_identifier: str = "workgraph.float" class SocketInt(TaskSocket): """Int socket.""" identifier: str = "workgraph.int" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.int", name, **kwargs) + property_identifier: str = "workgraph.int" class SocketString(TaskSocket): """String socket.""" identifier: str = "workgraph.string" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.string", name, **kwargs) + property_identifier: str = "workgraph.string" class SocketBool(TaskSocket): """Bool socket.""" identifier: str = "workgraph.bool" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.bool", name, **kwargs) + property_identifier: str = "workgraph.bool" class SocketAiiDAFloat(TaskSocket): """AiiDAFloat socket.""" identifier: str = "workgraph.aiida_float" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_float", name, **kwargs) + property_identifier: str = "workgraph.aiida_float" class SocketAiiDAInt(TaskSocket): """AiiDAInt socket.""" identifier: str = "workgraph.aiida_int" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_int", name, **kwargs) + property_identifier: str = "workgraph.aiida_int" class SocketAiiDAString(TaskSocket): """AiiDAString socket.""" identifier: str = "workgraph.aiida_string" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_string", name, **kwargs) + property_identifier: str = "workgraph.aiida_string" class SocketAiiDABool(TaskSocket): """AiiDABool socket.""" identifier: str = "workgraph.aiida_bool" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_bool", name, **kwargs) + property_identifier: str = "workgraph.aiida_bool" class SocketAiiDAIntVector(TaskSocket): """Socket with a AiiDAIntVector property.""" identifier: str = "workgraph.aiida_int_vector" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_int_vector", name, **kwargs) + property_identifier: str = "workgraph.aiida_int_vector" class SocketAiiDAFloatVector(TaskSocket): """Socket with a FloatVector property.""" identifier: str = "workgraph.aiida_float_vector" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_float_vector", name, **kwargs) + property_identifier: str = "workgraph.aiida_float_vector" class SocketStructureData(TaskSocket): """Any socket.""" identifier: str = "workgraph.aiida_structuredata" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_structuredata", name, **kwargs) + property_identifier: str = "workgraph.aiida_structuredata" diff --git a/aiida_workgraph/task.py b/aiida_workgraph/task.py index 1bf48ee7..d8af66eb 100644 --- a/aiida_workgraph/task.py +++ b/aiida_workgraph/task.py @@ -57,10 +57,10 @@ def __init__( self.state = "PLANNED" self.action = "" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self, short: bool = False) -> Dict[str, Any]: from aiida.orm.utils.serialize import serialize - tdata = super().to_dict() + tdata = super().to_dict(short=short) tdata["context_mapping"] = self.context_mapping tdata["wait"] = [task.name for task in self.waiting_on] tdata["children"] = [] diff --git a/aiida_workgraph/tasks/builtins.py b/aiida_workgraph/tasks/builtins.py index 3e771a80..1fe5cee8 100644 --- a/aiida_workgraph/tasks/builtins.py +++ b/aiida_workgraph/tasks/builtins.py @@ -23,8 +23,8 @@ def create_sockets(self) -> None: inp.link_limit = 100000 self.outputs.new("workgraph.any", "_wait") - def to_dict(self) -> Dict[str, Any]: - tdata = super().to_dict() + def to_dict(self, short: bool = False) -> Dict[str, Any]: + tdata = super().to_dict(short=short) tdata["children"] = [task.name for task in self.children] return tdata @@ -164,7 +164,9 @@ class AiiDAFloat(Task): args = ["value"] def create_sockets(self) -> None: - self.inputs.new("workgraph.aiida_float", "value", default=0.0) + self.inputs.new( + "workgraph.aiida_float", "value", property_data={"default": 0.0} + ) self.outputs.new("workgraph.aiida_float", "result") diff --git a/aiida_workgraph/tasks/pythonjob.py b/aiida_workgraph/tasks/pythonjob.py index e204f68a..ad123017 100644 --- a/aiida_workgraph/tasks/pythonjob.py +++ b/aiida_workgraph/tasks/pythonjob.py @@ -17,8 +17,8 @@ def update_from_dict(self, data: Dict[str, Any], **kwargs) -> "PythonJob": self.deserialize_pythonjob_data(data) super().update_from_dict(data) - def to_dict(self) -> Dict[str, Any]: - data = super().to_dict() + def to_dict(self, short: bool = False) -> Dict[str, Any]: + data = super().to_dict(short=short) data["function_kwargs"] = self.function_kwargs return data diff --git a/pyproject.toml b/pyproject.toml index ffedd88d..25b8739e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "numpy~=1.21", "scipy", "ase", - "node-graph==0.1.2", + "node-graph==0.1.3", "aiida-core>=2.3", "cloudpickle", "aiida-shell~=0.8", diff --git a/tests/test_socket.py b/tests/test_socket.py index 6e050158..fea72d4d 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -109,7 +109,7 @@ def test(a, b=1, **kwargs): test1 = test.node() assert test1.inputs["kwargs"].link_limit == 1e6 assert test1.inputs["kwargs"].identifier == "workgraph.namespace" - assert test1.inputs["kwargs"].property.value == {} + assert test1.inputs["kwargs"].property.value is None @pytest.mark.parametrize( diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 74ef19e9..d2c5b916 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -34,7 +34,7 @@ def test_build_task_from_workgraph( wg = WorkGraph("build_task_from_workgraph") add1_task = wg.add_task(decorated_add, name="add1", x=1, y=3) wg_task = wg.add_task(wg_calcfunction, name="wg_calcfunction") - assert wg_task.inputs["sumdiff1"].value == {} + assert wg_task.inputs["sumdiff1"].value is None wg.add_task(decorated_add, name="add2", y=3) wg.add_link(add1_task.outputs["result"], wg_task.inputs["sumdiff1.x"]) wg.add_link(wg_task.outputs["sumdiff2.sum"], wg.tasks["add2"].inputs["x"])