Skip to content

Commit

Permalink
Allow nonfunctional usage of decorators (#191)
Browse files Browse the repository at this point in the history
The decorators that could be used as

@decorator_factory()
def add(x, y):
    return x+y

can now be used nonfunctional

@decorator_factory
def add(x, y):
    return x+y

    Implements feature request in issue #191
  • Loading branch information
agoscinski committed Aug 9, 2024
1 parent d41a8ae commit 9358571
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 11 deletions.
42 changes: 41 additions & 1 deletion aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,38 @@ def build_task_from_workgraph(wg: any) -> Task:
return task


def nonfunctional_usage(callable: Callable):
"""
This is a decorator for a decorator factory (a function that returns a decorator).
It allows the usage of the decorator factory in a nonfunctional way. So a decorator
factory that has been decorated by this decorator that could only be used befor like
this
.. code-block:: python
@decorator_factory()
def foo():
pass
can now be also used like this
.. code-block:: python
@decorator_factory
def foo():
pass
"""

def decorator_task_wrapper(*args, **kwargs):
if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0:
return callable()(args[0])
else:
return callable(*args, **kwargs)

return decorator_task_wrapper


def generate_tdata(
func: Callable,
identifier: str,
Expand Down Expand Up @@ -462,6 +494,7 @@ class TaskDecoratorCollection:

# decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor
@staticmethod
@nonfunctional_usage
def decorator_task(
identifier: Optional[str] = None,
task_type: str = "Normal",
Expand Down Expand Up @@ -511,6 +544,7 @@ def decorator(func):

# decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor
@staticmethod
@nonfunctional_usage
def decorator_graph_builder(
identifier: Optional[str] = None,
properties: Optional[List[Tuple[str, str]]] = None,
Expand Down Expand Up @@ -561,6 +595,7 @@ def decorator(func):
return decorator

@staticmethod
@nonfunctional_usage
def calcfunction(**kwargs: Any) -> Callable:
def decorator(func):
# First, apply the calcfunction decorator
Expand All @@ -579,6 +614,7 @@ def decorator(func):
return decorator

@staticmethod
@nonfunctional_usage
def workfunction(**kwargs: Any) -> Callable:
def decorator(func):
# First, apply the workfunction decorator
Expand All @@ -597,6 +633,7 @@ def decorator(func):
return decorator

@staticmethod
@nonfunctional_usage
def pythonjob(**kwargs: Any) -> Callable:
def decorator(func):
# first create a task from the function
Expand All @@ -622,7 +659,10 @@ def decorator(func):

def __call__(self, *args, **kwargs):
# This allows using '@task' to directly apply the decorator_task functionality
return self.decorator_task(*args, **kwargs)
if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0:
return self.decorator_task()(args[0])
else:
return self.decorator_task(*args, **kwargs)


task = TaskDecoratorCollection()
4 changes: 2 additions & 2 deletions docs/source/concept/task.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@
"from aiida import orm\n",
"\n",
"# define add task\n",
"@task()\n",
"@task # this is equivalent to passing no arguments @task()\n",
"def add(x, y):\n",
" return x + y\n",
"\n",
"# define multiply calcfunction task\n",
"@task.calcfunction()\n",
"@task.calcfunction # this is equivalent to passing no arguments @task.calculation()\n",
"def multiply(x, y):\n",
" return orm.Float(x + y)\n",
"\n",
Expand Down
140 changes: 132 additions & 8 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,156 @@
import pytest
from aiida_workgraph import WorkGraph
from typing import Callable
from aiida_workgraph import task


def test_args() -> None:
from aiida_workgraph import task
@pytest.fixture(params=["decorator_factory", "decorator"])
def task_calcfunction(request):
if request.param == "decorator_factory":

@task.calcfunction()
def test(a, b=1, **c):
print(a, b, c)
@task.calcfunction()
def test(a, b=1, **c):
print(a, b, c)

elif request.param == "decorator":

@task.calcfunction
def test(a, b=1, **c):
print(a, b, c)

else:
raise ValueError(f"{request.param} not supported.")
return test


def test_decorators_calcfunction_args(task_calcfunction) -> None:
metadata_kwargs = set(
[
f"metadata.{key}"
for key in task_calcfunction.process_class.spec()
.inputs.ports["metadata"]
.ports.keys()
]
)
kwargs = set(task_calcfunction.process_class.spec().inputs.ports.keys()).union(
metadata_kwargs
)
kwargs.remove("a")
#
n = task_calcfunction.task()
assert n.args == ["a"]
assert set(n.kwargs) == set(kwargs)
assert n.var_args is None
assert n.var_kwargs == "c"
assert n.outputs.keys() == ["result", "_outputs", "_wait"]


@pytest.fixture(params=["decorator_factory", "decorator"])
def task_function(request):
if request.param == "decorator_factory":

@task()
def test(a, b=1, **c):
print(a, b, c)

elif request.param == "decorator":

@task
def test(a, b=1, **c):
print(a, b, c)

else:
raise ValueError(f"{request.param} not supported.")
return test


def test_decorators_task_args(task_function):

tdata = task_function.tdata
assert tdata["args"] == ["a"]
assert tdata["kwargs"] == ["b"]
assert tdata["var_args"] is None
assert tdata["var_kwargs"] == "c"
assert set([output["name"] for output in tdata["outputs"]]) == set(
["result", "_outputs", "_wait"]
)


@pytest.fixture(params=["decorator_factory", "decorator"])
def task_workfunction(request):
if request.param == "decorator_factory":

@task.workfunction()
def test(a, b=1, **c):
print(a, b, c)

elif request.param == "decorator":

@task.workfunction
def test(a, b=1, **c):
print(a, b, c)

else:
raise ValueError(f"{request.param} not supported.")
return test


def test_decorators_workfunction_args(task_workfunction) -> None:
metadata_kwargs = set(
[
f"metadata.{key}"
for key in test.process_class.spec().inputs.ports["metadata"].ports.keys()
for key in task_workfunction.process_class.spec()
.inputs.ports["metadata"]
.ports.keys()
]
)
kwargs = set(test.process_class.spec().inputs.ports.keys()).union(metadata_kwargs)
kwargs = set(task_workfunction.process_class.spec().inputs.ports.keys()).union(
metadata_kwargs
)
kwargs.remove("a")
#
n = test.task()
n = task_workfunction.task()
assert n.args == ["a"]
assert set(n.kwargs) == set(kwargs)
assert n.var_args is None
assert n.var_kwargs == "c"
assert n.outputs.keys() == ["result", "_outputs", "_wait"]


@pytest.fixture(params=["decorator_factory", "decorator"])
def task_graph_builder(request):
if request.param == "decorator_factory":

@task.graph_builder()
def add_multiply_group(a, b=1, **c):
wg = WorkGraph("add_multiply_group")
print(a, b, c)
return wg

elif request.param == "decorator":

@task.graph_builder
def add_multiply_group(a, b=1, **c):
wg = WorkGraph("add_multiply_group")
print(a, b, c)
return wg

else:
raise ValueError(f"{request.param} not supported.")

return add_multiply_group


def test_decorators_graph_builder_args(task_graph_builder) -> None:
assert task_graph_builder.identifier == "add_multiply_group"
n = task_graph_builder.task()
assert n.args == ["a"]
assert n.kwargs == ["b"]
assert n.var_args is None
assert n.var_kwargs == "c"
assert set(n.outputs.keys()) == set(["_outputs", "_wait"])


def test_inputs_outputs_workchain() -> None:
from aiida.workflows.arithmetic.multiply_add import MultiplyAddWorkChain

Expand Down

0 comments on commit 9358571

Please sign in to comment.