Skip to content

Commit

Permalink
Allow nonfunctional usage of decorators
Browse files Browse the repository at this point in the history
The decorators that could be used as

@decorator()
def add(x, y):
    return x+y

can now be used nonfunctional

@decorator
def add(x, y):
    return x+y

Implements feature request in issue #191
  • Loading branch information
agoscinski committed Aug 5, 2024
1 parent f580bcd commit d1de081
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 19 deletions.
1 change: 0 additions & 1 deletion aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,6 @@ def decorator(func):

# Making decorator_task accessible as 'task'
task = decorator_task
#task = decorator_task()

# Making decorator_graph_builder accessible as 'graph_builder'
graph_builder = decorator_graph_builder
Expand Down
77 changes: 59 additions & 18 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,78 @@
from typing import Callable
from aiida_workgraph import task

class Helper:
@staticmethod
@task.calcfunction()
def test_callable(a, b=1, **c):
print(a, b, c)

@staticmethod
@task.calcfunction
def test(a, b=1, **c):
print(a, b, c)

@pytest.mark.parametrize("test_calcfunction", [Helper.test_callable, Helper.test])
def test_args(test_calcfunction) -> None:


@pytest.fixture(params=["decorator_factory", "decorator"])
def task_calcfunction(request):
if request.param == "decorator_factory":
@task.calcfunction()
def test(a, b=1, **c):
print(a, b, c)

elif request.param == "decorator":
@task.calcfunction
def test(a, b=1, **c):
print(a, b, c)
else:
raise ValueError(f"{request.param} not supported.")
return test


@pytest.fixture(params=["decorator_factory", "decorator"])
def task_function(request):
if request.param == "decorator_factory":
@task()
def test(a, b=1, **c):
print(a, b, c)

elif request.param == "decorator":
@task
def test(a, b=1, **c):
print(a, b, c)
else:
raise ValueError(f"{request.param} not supported.")
return test


@pytest.fixture(params=["decorator_factory", "decorator"])
def task_workfunction(request):
if request.param == "decorator_factory":
@task.workfunction()
def test(a, b=1, **c):
print(a, b, c)

elif request.param == "decorator":
@task.workfunction
def test(a, b=1, **c):
print(a, b, c)
else:
raise ValueError(f"{request.param} not supported.")
return test

def test_decorators_calcfunction_args(task_calcfunction) -> None:
metadata_kwargs = set(
[
f"metadata.{key}"
for key in test_calcfunction.process_class.spec().inputs.ports["metadata"].ports.keys()
for key in task_calcfunction.process_class.spec().inputs.ports["metadata"].ports.keys()
]
)
kwargs = set(test_calcfunction.process_class.spec().inputs.ports.keys()).union(metadata_kwargs)
kwargs = set(task_calcfunction.process_class.spec().inputs.ports.keys()).union(metadata_kwargs)
kwargs.remove("a")
#
n = test_calcfunction.task()
n = task_calcfunction.task()
assert n.args == ["a"]
assert set(n.kwargs) == set(kwargs)
assert n.var_args is None
assert n.var_kwargs == "c"
assert n.outputs.keys() == ["result", "_outputs", "_wait"]

def test_decorators_task_args(task_function):

tdata = task_function.tdata
assert tdata["args"] == ["a"]
assert tdata["kwargs"] == ["b"]
assert tdata["var_args"] is None
assert tdata["var_kwargs"] == "c"
assert set([output["name"] for output in tdata["outputs"]]) == set(["result", "_outputs", "_wait"])

def test_inputs_outputs_workchain() -> None:
from aiida.workflows.arithmetic.multiply_add import MultiplyAddWorkChain
Expand Down

0 comments on commit d1de081

Please sign in to comment.