Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow nonfunctional usage of decorators #199

Merged
merged 2 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,38 @@ def build_task_from_workgraph(wg: any) -> Task:
return task


def nonfunctional_usage(callable: Callable):
"""
This is a decorator for a decorator factory (a function that returns a decorator).
It allows the usage of the decorator factory in a nonfunctional way. So a decorator
factory that has been decorated by this decorator that could only be used befor like
this

.. code-block:: python

@decorator_factory()
def foo():
pass

can now be also used like this

.. code-block:: python

@decorator_factory
def foo():
pass

"""

def decorator_task_wrapper(*args, **kwargs):
if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0:
return callable()(args[0])
else:
return callable(*args, **kwargs)

return decorator_task_wrapper


def generate_tdata(
func: Callable,
identifier: str,
Expand Down Expand Up @@ -462,6 +494,7 @@ class TaskDecoratorCollection:

# decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor
@staticmethod
@nonfunctional_usage
def decorator_task(
identifier: Optional[str] = None,
task_type: str = "Normal",
Expand Down Expand Up @@ -511,6 +544,7 @@ def decorator(func):

# decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor
@staticmethod
@nonfunctional_usage
def decorator_graph_builder(
identifier: Optional[str] = None,
properties: Optional[List[Tuple[str, str]]] = None,
Expand Down Expand Up @@ -561,6 +595,7 @@ def decorator(func):
return decorator

@staticmethod
@nonfunctional_usage
def calcfunction(**kwargs: Any) -> Callable:
def decorator(func):
# First, apply the calcfunction decorator
Expand All @@ -579,6 +614,7 @@ def decorator(func):
return decorator

@staticmethod
@nonfunctional_usage
def workfunction(**kwargs: Any) -> Callable:
def decorator(func):
# First, apply the workfunction decorator
Expand All @@ -597,6 +633,7 @@ def decorator(func):
return decorator

@staticmethod
@nonfunctional_usage
def pythonjob(**kwargs: Any) -> Callable:
def decorator(func):
# first create a task from the function
Expand All @@ -622,7 +659,10 @@ def decorator(func):

def __call__(self, *args, **kwargs):
# This allows using '@task' to directly apply the decorator_task functionality
return self.decorator_task(*args, **kwargs)
if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0:
return self.decorator_task()(args[0])
else:
return self.decorator_task(*args, **kwargs)


task = TaskDecoratorCollection()
4 changes: 2 additions & 2 deletions docs/source/concept/task.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@
"from aiida import orm\n",
"\n",
"# define add task\n",
"@task()\n",
"@task # this is equivalent to passing no arguments @task()\n",
"def add(x, y):\n",
" return x + y\n",
"\n",
"# define multiply calcfunction task\n",
"@task.calcfunction()\n",
"@task.calcfunction # this is equivalent to passing no arguments @task.calculation()\n",
"def multiply(x, y):\n",
" return orm.Float(x + y)\n",
"\n",
Expand Down
140 changes: 132 additions & 8 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,156 @@
import pytest
from aiida_workgraph import WorkGraph
from typing import Callable
from aiida_workgraph import task


def test_args() -> None:
from aiida_workgraph import task
@pytest.fixture(params=["decorator_factory", "decorator"])
def task_calcfunction(request):
if request.param == "decorator_factory":

@task.calcfunction()
def test(a, b=1, **c):
print(a, b, c)
@task.calcfunction()
def test(a, b=1, **c):
print(a, b, c)

elif request.param == "decorator":

@task.calcfunction
def test(a, b=1, **c):
print(a, b, c)

else:
raise ValueError(f"{request.param} not supported.")
return test


def test_decorators_calcfunction_args(task_calcfunction) -> None:
metadata_kwargs = set(
[
f"metadata.{key}"
for key in task_calcfunction.process_class.spec()
.inputs.ports["metadata"]
.ports.keys()
]
)
kwargs = set(task_calcfunction.process_class.spec().inputs.ports.keys()).union(
metadata_kwargs
)
kwargs.remove("a")
#
n = task_calcfunction.task()
assert n.args == ["a"]
assert set(n.kwargs) == set(kwargs)
assert n.var_args is None
assert n.var_kwargs == "c"
assert n.outputs.keys() == ["result", "_outputs", "_wait"]


@pytest.fixture(params=["decorator_factory", "decorator"])
def task_function(request):
if request.param == "decorator_factory":

@task()
def test(a, b=1, **c):
print(a, b, c)

elif request.param == "decorator":

@task
def test(a, b=1, **c):
print(a, b, c)

else:
raise ValueError(f"{request.param} not supported.")
return test


def test_decorators_task_args(task_function):

tdata = task_function.tdata
assert tdata["args"] == ["a"]
assert tdata["kwargs"] == ["b"]
assert tdata["var_args"] is None
assert tdata["var_kwargs"] == "c"
assert set([output["name"] for output in tdata["outputs"]]) == set(
["result", "_outputs", "_wait"]
)


@pytest.fixture(params=["decorator_factory", "decorator"])
def task_workfunction(request):
if request.param == "decorator_factory":

@task.workfunction()
def test(a, b=1, **c):
print(a, b, c)

elif request.param == "decorator":

@task.workfunction
def test(a, b=1, **c):
print(a, b, c)

else:
raise ValueError(f"{request.param} not supported.")
return test


def test_decorators_workfunction_args(task_workfunction) -> None:
metadata_kwargs = set(
[
f"metadata.{key}"
for key in test.process_class.spec().inputs.ports["metadata"].ports.keys()
for key in task_workfunction.process_class.spec()
.inputs.ports["metadata"]
.ports.keys()
]
)
kwargs = set(test.process_class.spec().inputs.ports.keys()).union(metadata_kwargs)
kwargs = set(task_workfunction.process_class.spec().inputs.ports.keys()).union(
metadata_kwargs
)
kwargs.remove("a")
#
n = test.task()
n = task_workfunction.task()
assert n.args == ["a"]
assert set(n.kwargs) == set(kwargs)
assert n.var_args is None
assert n.var_kwargs == "c"
assert n.outputs.keys() == ["result", "_outputs", "_wait"]


@pytest.fixture(params=["decorator_factory", "decorator"])
def task_graph_builder(request):
if request.param == "decorator_factory":

@task.graph_builder()
def add_multiply_group(a, b=1, **c):
wg = WorkGraph("add_multiply_group")
print(a, b, c)
return wg

elif request.param == "decorator":

@task.graph_builder
def add_multiply_group(a, b=1, **c):
wg = WorkGraph("add_multiply_group")
print(a, b, c)
return wg

else:
raise ValueError(f"{request.param} not supported.")

return add_multiply_group


def test_decorators_graph_builder_args(task_graph_builder) -> None:
assert task_graph_builder.identifier == "add_multiply_group"
n = task_graph_builder.task()
assert n.args == ["a"]
assert n.kwargs == ["b"]
assert n.var_args is None
assert n.var_kwargs == "c"
assert set(n.outputs.keys()) == set(["_outputs", "_wait"])


def test_inputs_outputs_workchain() -> None:
from aiida.workflows.arithmetic.multiply_add import MultiplyAddWorkChain

Expand Down
Loading