From 935857162f399775c54cf8c048b987336633275d Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Fri, 9 Aug 2024 09:44:52 +0200 Subject: [PATCH] Allow nonfunctional usage of decorators (#191) The decorators that could be used as @decorator_factory() def add(x, y): return x+y can now be used nonfunctional @decorator_factory def add(x, y): return x+y Implements feature request in issue #191 --- aiida_workgraph/decorator.py | 42 +++++++++- docs/source/concept/task.ipynb | 4 +- tests/test_decorator.py | 140 +++++++++++++++++++++++++++++++-- 3 files changed, 175 insertions(+), 11 deletions(-) diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py index dab8ebd8..2d2b56c7 100644 --- a/aiida_workgraph/decorator.py +++ b/aiida_workgraph/decorator.py @@ -416,6 +416,38 @@ def build_task_from_workgraph(wg: any) -> Task: return task +def nonfunctional_usage(callable: Callable): + """ + This is a decorator for a decorator factory (a function that returns a decorator). + It allows the usage of the decorator factory in a nonfunctional way. So a decorator + factory that has been decorated by this decorator that could only be used befor like + this + + .. code-block:: python + + @decorator_factory() + def foo(): + pass + + can now be also used like this + + .. code-block:: python + + @decorator_factory + def foo(): + pass + + """ + + def decorator_task_wrapper(*args, **kwargs): + if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0: + return callable()(args[0]) + else: + return callable(*args, **kwargs) + + return decorator_task_wrapper + + def generate_tdata( func: Callable, identifier: str, @@ -462,6 +494,7 @@ class TaskDecoratorCollection: # decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor @staticmethod + @nonfunctional_usage def decorator_task( identifier: Optional[str] = None, task_type: str = "Normal", @@ -511,6 +544,7 @@ def decorator(func): # decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor @staticmethod + @nonfunctional_usage def decorator_graph_builder( identifier: Optional[str] = None, properties: Optional[List[Tuple[str, str]]] = None, @@ -561,6 +595,7 @@ def decorator(func): return decorator @staticmethod + @nonfunctional_usage def calcfunction(**kwargs: Any) -> Callable: def decorator(func): # First, apply the calcfunction decorator @@ -579,6 +614,7 @@ def decorator(func): return decorator @staticmethod + @nonfunctional_usage def workfunction(**kwargs: Any) -> Callable: def decorator(func): # First, apply the workfunction decorator @@ -597,6 +633,7 @@ def decorator(func): return decorator @staticmethod + @nonfunctional_usage def pythonjob(**kwargs: Any) -> Callable: def decorator(func): # first create a task from the function @@ -622,7 +659,10 @@ def decorator(func): def __call__(self, *args, **kwargs): # This allows using '@task' to directly apply the decorator_task functionality - return self.decorator_task(*args, **kwargs) + if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0: + return self.decorator_task()(args[0]) + else: + return self.decorator_task(*args, **kwargs) task = TaskDecoratorCollection() diff --git a/docs/source/concept/task.ipynb b/docs/source/concept/task.ipynb index 3ea57cbc..9c3e40c3 100644 --- a/docs/source/concept/task.ipynb +++ b/docs/source/concept/task.ipynb @@ -45,12 +45,12 @@ "from aiida import orm\n", "\n", "# define add task\n", - "@task()\n", + "@task # this is equivalent to passing no arguments @task()\n", "def add(x, y):\n", " return x + y\n", "\n", "# define multiply calcfunction task\n", - "@task.calcfunction()\n", + "@task.calcfunction # this is equivalent to passing no arguments @task.calculation()\n", "def multiply(x, y):\n", " return orm.Float(x + y)\n", "\n", diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 664a73d8..a92478df 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,25 +1,115 @@ import pytest from aiida_workgraph import WorkGraph from typing import Callable +from aiida_workgraph import task -def test_args() -> None: - from aiida_workgraph import task +@pytest.fixture(params=["decorator_factory", "decorator"]) +def task_calcfunction(request): + if request.param == "decorator_factory": - @task.calcfunction() - def test(a, b=1, **c): - print(a, b, c) + @task.calcfunction() + def test(a, b=1, **c): + print(a, b, c) + elif request.param == "decorator": + + @task.calcfunction + def test(a, b=1, **c): + print(a, b, c) + + else: + raise ValueError(f"{request.param} not supported.") + return test + + +def test_decorators_calcfunction_args(task_calcfunction) -> None: + metadata_kwargs = set( + [ + f"metadata.{key}" + for key in task_calcfunction.process_class.spec() + .inputs.ports["metadata"] + .ports.keys() + ] + ) + kwargs = set(task_calcfunction.process_class.spec().inputs.ports.keys()).union( + metadata_kwargs + ) + kwargs.remove("a") + # + n = task_calcfunction.task() + assert n.args == ["a"] + assert set(n.kwargs) == set(kwargs) + assert n.var_args is None + assert n.var_kwargs == "c" + assert n.outputs.keys() == ["result", "_outputs", "_wait"] + + +@pytest.fixture(params=["decorator_factory", "decorator"]) +def task_function(request): + if request.param == "decorator_factory": + + @task() + def test(a, b=1, **c): + print(a, b, c) + + elif request.param == "decorator": + + @task + def test(a, b=1, **c): + print(a, b, c) + + else: + raise ValueError(f"{request.param} not supported.") + return test + + +def test_decorators_task_args(task_function): + + tdata = task_function.tdata + assert tdata["args"] == ["a"] + assert tdata["kwargs"] == ["b"] + assert tdata["var_args"] is None + assert tdata["var_kwargs"] == "c" + assert set([output["name"] for output in tdata["outputs"]]) == set( + ["result", "_outputs", "_wait"] + ) + + +@pytest.fixture(params=["decorator_factory", "decorator"]) +def task_workfunction(request): + if request.param == "decorator_factory": + + @task.workfunction() + def test(a, b=1, **c): + print(a, b, c) + + elif request.param == "decorator": + + @task.workfunction + def test(a, b=1, **c): + print(a, b, c) + + else: + raise ValueError(f"{request.param} not supported.") + return test + + +def test_decorators_workfunction_args(task_workfunction) -> None: metadata_kwargs = set( [ f"metadata.{key}" - for key in test.process_class.spec().inputs.ports["metadata"].ports.keys() + for key in task_workfunction.process_class.spec() + .inputs.ports["metadata"] + .ports.keys() ] ) - kwargs = set(test.process_class.spec().inputs.ports.keys()).union(metadata_kwargs) + kwargs = set(task_workfunction.process_class.spec().inputs.ports.keys()).union( + metadata_kwargs + ) kwargs.remove("a") # - n = test.task() + n = task_workfunction.task() assert n.args == ["a"] assert set(n.kwargs) == set(kwargs) assert n.var_args is None @@ -27,6 +117,40 @@ def test(a, b=1, **c): assert n.outputs.keys() == ["result", "_outputs", "_wait"] +@pytest.fixture(params=["decorator_factory", "decorator"]) +def task_graph_builder(request): + if request.param == "decorator_factory": + + @task.graph_builder() + def add_multiply_group(a, b=1, **c): + wg = WorkGraph("add_multiply_group") + print(a, b, c) + return wg + + elif request.param == "decorator": + + @task.graph_builder + def add_multiply_group(a, b=1, **c): + wg = WorkGraph("add_multiply_group") + print(a, b, c) + return wg + + else: + raise ValueError(f"{request.param} not supported.") + + return add_multiply_group + + +def test_decorators_graph_builder_args(task_graph_builder) -> None: + assert task_graph_builder.identifier == "add_multiply_group" + n = task_graph_builder.task() + assert n.args == ["a"] + assert n.kwargs == ["b"] + assert n.var_args is None + assert n.var_kwargs == "c" + assert set(n.outputs.keys()) == set(["_outputs", "_wait"]) + + def test_inputs_outputs_workchain() -> None: from aiida.workflows.arithmetic.multiply_add import MultiplyAddWorkChain