diff --git a/aiida_workgraph/properties/builtins.py b/aiida_workgraph/properties/builtins.py index bddd647e..5315bcd7 100644 --- a/aiida_workgraph/properties/builtins.py +++ b/aiida_workgraph/properties/builtins.py @@ -1,10 +1,19 @@ from typing import Dict, List, Union, Callable from aiida_workgraph.property import TaskProperty -from node_graph.serializer import SerializeJson -from node_graph.properties.builtins import PropertyVector, PropertyAny +from node_graph.serializer import SerializeJson, SerializePickle from aiida import orm +class PropertyAny(TaskProperty, SerializePickle): + """A new class for Any type.""" + + identifier: str = "workgraph.any" + data_type = "Any" + + def __init__(self, name, description="", default=None, update=None) -> None: + super().__init__(name, description, default, update) + + class PropertyInt(TaskProperty, SerializeJson): """A new class for integer type.""" @@ -278,6 +287,25 @@ def set_value(self, value: Union[Dict, orm.Dict, str]) -> None: raise Exception("{} is not a dict.".format(value)) +# ==================================== +class PropertyVector(TaskProperty, SerializePickle): + """Vector property""" + + identifier: str = "workgraph.vector" + data_type = "Vector" + + def __init__(self, name, description="", size=3, default=[], update=None) -> None: + super().__init__(name, description, default, update) + self.size = size + + def copy(self): + p = self.__class__( + self.name, self.description, self.size, self.value, self.update + ) + p.value = self.value + return p + + class PropertyAiiDAIntVector(PropertyVector): """A new class for integer vector type.""" diff --git a/tests/test_socket.py b/tests/test_socket.py index 7abfb8d7..59865060 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -1,11 +1,13 @@ import pytest from aiida_workgraph import WorkGraph, task from aiida import orm +from typing import Any @pytest.mark.parametrize( - "data_type, socket_type", + "data_type, identifier", ( + (Any, "workgraph.any"), (int, "workgraph.int"), (float, "workgraph.float"), (bool, "workgraph.bool"), @@ -16,14 +18,15 @@ (orm.Bool, "workgraph.aiida_bool"), ), ) -def test_type_mapping(data_type, socket_type) -> None: +def test_type_mapping(data_type, identifier) -> None: """Test the mapping of data types to socket types.""" @task() def add(x: data_type): pass - assert add.task().inputs["x"].identifier == socket_type + assert add.task().inputs["x"].identifier == identifier + assert add.task().inputs["x"].property.identifier == identifier def test_socket(decorated_multiply) -> None: