Skip to content

Commit

Permalink
adapt decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Aug 5, 2024
1 parent a7f0538 commit f580bcd
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
23 changes: 19 additions & 4 deletions aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,13 @@ def build_task_from_workgraph(wg: any) -> Task:
task.group_outputs = group_outputs
return task

def decorator_wrapper(callable: Callable):
def decorator_task_wrapper(*args, **kwargs):
if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0:
return callable()(args[0])
else:
return callable(*args, **kwargs)
return decorator_task_wrapper

def generate_tdata(
func: Callable,
Expand Down Expand Up @@ -459,9 +466,10 @@ def generate_tdata(

class TaskDecoratorCollection:
"""Collection of task decorators."""

# decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor
@staticmethod
@decorator_wrapper
def decorator_task(
identifier: Optional[str] = None,
task_type: str = "Normal",
Expand All @@ -481,7 +489,6 @@ def decorator_task(
inputs (list): task inputs
outputs (list): task outputs
"""

def decorator(func):
nonlocal identifier, task_type

Expand Down Expand Up @@ -511,6 +518,7 @@ def decorator(func):

# decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor
@staticmethod
@decorator_wrapper
def decorator_graph_builder(
identifier: Optional[str] = None,
properties: Optional[List[Tuple[str, str]]] = None,
Expand Down Expand Up @@ -561,6 +569,7 @@ def decorator(func):
return decorator

@staticmethod
@decorator_wrapper
def calcfunction(**kwargs: Any) -> Callable:
def decorator(func):
# First, apply the calcfunction decorator
Expand All @@ -579,6 +588,7 @@ def decorator(func):
return decorator

@staticmethod
@decorator_wrapper
def workfunction(**kwargs: Any) -> Callable:
def decorator(func):
# First, apply the workfunction decorator
Expand All @@ -597,6 +607,7 @@ def decorator(func):
return decorator

@staticmethod
@decorator_wrapper
def pythonjob(**kwargs: Any) -> Callable:
def decorator(func):
# first create a task from the function
Expand All @@ -612,17 +623,21 @@ def decorator(func):

return func

return decorator
return decorator

# Making decorator_task accessible as 'task'
task = decorator_task
#task = decorator_task()

# Making decorator_graph_builder accessible as 'graph_builder'
graph_builder = decorator_graph_builder

def __call__(self, *args, **kwargs):
# This allows using '@task' to directly apply the decorator_task functionality
return self.decorator_task(*args, **kwargs)
if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0:
return self.decorator_task()(args[0])
else:
return self.decorator_task(*args, **kwargs)


task = TaskDecoratorCollection()
22 changes: 15 additions & 7 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
import pytest
from aiida_workgraph import WorkGraph
from typing import Callable
from aiida_workgraph import task


def test_args() -> None:
from aiida_workgraph import task

class Helper:
@staticmethod
@task.calcfunction()
def test_callable(a, b=1, **c):
print(a, b, c)

@staticmethod
@task.calcfunction
def test(a, b=1, **c):
print(a, b, c)

@pytest.mark.parametrize("test_calcfunction", [Helper.test_callable, Helper.test])
def test_args(test_calcfunction) -> None:


metadata_kwargs = set(
[
f"metadata.{key}"
for key in test.process_class.spec().inputs.ports["metadata"].ports.keys()
for key in test_calcfunction.process_class.spec().inputs.ports["metadata"].ports.keys()
]
)
kwargs = set(test.process_class.spec().inputs.ports.keys()).union(metadata_kwargs)
kwargs = set(test_calcfunction.process_class.spec().inputs.ports.keys()).union(metadata_kwargs)
kwargs.remove("a")
#
n = test.task()
n = test_calcfunction.task()
assert n.args == ["a"]
assert set(n.kwargs) == set(kwargs)
assert n.var_args is None
Expand Down

0 comments on commit f580bcd

Please sign in to comment.