diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py index dab8ebd8..4fdcb6b2 100644 --- a/aiida_workgraph/decorator.py +++ b/aiida_workgraph/decorator.py @@ -415,6 +415,13 @@ def build_task_from_workgraph(wg: any) -> Task: task.group_outputs = group_outputs return task +def decorator_wrapper(callable: Callable): + def decorator_task_wrapper(*args, **kwargs): + if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0: + return callable()(args[0]) + else: + return callable(*args, **kwargs) + return decorator_task_wrapper def generate_tdata( func: Callable, @@ -459,9 +466,10 @@ def generate_tdata( class TaskDecoratorCollection: """Collection of task decorators.""" - + # decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor @staticmethod + @decorator_wrapper def decorator_task( identifier: Optional[str] = None, task_type: str = "Normal", @@ -481,7 +489,6 @@ def decorator_task( inputs (list): task inputs outputs (list): task outputs """ - def decorator(func): nonlocal identifier, task_type @@ -511,6 +518,7 @@ def decorator(func): # decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor @staticmethod + @decorator_wrapper def decorator_graph_builder( identifier: Optional[str] = None, properties: Optional[List[Tuple[str, str]]] = None, @@ -561,6 +569,7 @@ def decorator(func): return decorator @staticmethod + @decorator_wrapper def calcfunction(**kwargs: Any) -> Callable: def decorator(func): # First, apply the calcfunction decorator @@ -579,6 +588,7 @@ def decorator(func): return decorator @staticmethod + @decorator_wrapper def workfunction(**kwargs: Any) -> Callable: def decorator(func): # First, apply the workfunction decorator @@ -597,6 +607,7 @@ def decorator(func): return decorator @staticmethod + @decorator_wrapper def pythonjob(**kwargs: Any) -> Callable: def decorator(func): # first create a task from the function @@ -612,17 +623,21 @@ def decorator(func): return func - return decorator + return decorator # Making decorator_task accessible as 'task' task = decorator_task + #task = decorator_task() # Making decorator_graph_builder accessible as 'graph_builder' graph_builder = decorator_graph_builder def __call__(self, *args, **kwargs): # This allows using '@task' to directly apply the decorator_task functionality - return self.decorator_task(*args, **kwargs) + if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0: + return self.decorator_task()(args[0]) + else: + return self.decorator_task(*args, **kwargs) task = TaskDecoratorCollection() diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 664a73d8..de7467c2 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,25 +1,33 @@ import pytest from aiida_workgraph import WorkGraph from typing import Callable +from aiida_workgraph import task - -def test_args() -> None: - from aiida_workgraph import task - +class Helper: + @staticmethod @task.calcfunction() + def test_callable(a, b=1, **c): + print(a, b, c) + + @staticmethod + @task.calcfunction def test(a, b=1, **c): print(a, b, c) +@pytest.mark.parametrize("test_calcfunction", [Helper.test_callable, Helper.test]) +def test_args(test_calcfunction) -> None: + + metadata_kwargs = set( [ f"metadata.{key}" - for key in test.process_class.spec().inputs.ports["metadata"].ports.keys() + for key in test_calcfunction.process_class.spec().inputs.ports["metadata"].ports.keys() ] ) - kwargs = set(test.process_class.spec().inputs.ports.keys()).union(metadata_kwargs) + kwargs = set(test_calcfunction.process_class.spec().inputs.ports.keys()).union(metadata_kwargs) kwargs.remove("a") # - n = test.task() + n = test_calcfunction.task() assert n.args == ["a"] assert set(n.kwargs) == set(kwargs) assert n.var_args is None