Skip to content

Commit

Permalink
Allow mixed str/dict inputs/outputs to tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
GeigerJ2 committed Nov 19, 2024
1 parent 2559c21 commit f11b84e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 47 deletions.
24 changes: 14 additions & 10 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,18 +768,22 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di
if the former convert them to a list of `dict`s with `name` as the key.
:param inout_list: The input/output list to be validated.
:param list_type: "input" or "output" to indicate what is to be validated.
:raises TypeError: If a list of mixed or wrong types is provided to the task
:param list_type: "input" or "output" to indicate what is to be validated for better error message.
:raises TypeError: If wrong types are provided to the task
:return: Processed `inputs`/`outputs` list.
"""

if all(isinstance(item, str) for item in inout_list):
return [{"name": item} for item in inout_list]
elif all(isinstance(item, dict) for item in inout_list):
return inout_list
elif not all(isinstance(item, dict) for item in inout_list):
if not all(isinstance(item, (dict, str)) for item in inout_list):
raise TypeError(
f"Provide either a list of `str` or `dict` as `{list_type}`, not mixed types."
f"Wrong type provided in the `{list_type}` list to the task, must be either `str` or `dict`."
)
else:
raise TypeError(f"Wrong type provided in the `{list_type}` list to the task.")

processed_inout_list = []

for item in inout_list:
if isinstance(item, str):
processed_inout_list.append({"name": item})
elif isinstance(item, dict):
processed_inout_list.append(item)

return processed_inout_list
57 changes: 20 additions & 37 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,66 +3,49 @@
from aiida_workgraph.utils import validate_task_inout


def test_validate_task_inout_empty_list():
"""Test validation with a list of strings."""
input_list = []
result = validate_task_inout(input_list, "inputs")
assert result == []


def test_validate_task_inout_str_list():
"""Test validation with a list of strings."""
input_list = ["task1", "task2"]
result = validate_task_inout(input_list, "input")
result = validate_task_inout(input_list, "inputs")
assert result == [{"name": "task1"}, {"name": "task2"}]


def test_validate_task_inout_dict_list():
"""Test validation with a list of dictionaries."""
input_list = [{"name": "task1"}, {"name": "task2"}]
result = validate_task_inout(input_list, "input")
result = validate_task_inout(input_list, "inputs")
assert result == input_list


@pytest.mark.parametrize(
"input_list, list_type, expected_error",
[
# Mixed types error cases
(
["task1", {"name": "task2"}],
"input",
"Provide either a list of `str` or `dict` as `input`, not mixed types.",
),
(
[{"name": "task1"}, "task2"],
"output",
"Provide either a list of `str` or `dict` as `output`, not mixed types.",
),
# Empty list cases
([], "input", None),
([], "output", None),
],
)
def test_validate_task_inout_mixed_types(input_list, list_type, expected_error):
"""Test error handling for mixed type lists."""
if expected_error:
with pytest.raises(TypeError) as excinfo:
validate_task_inout(input_list, list_type)
assert str(excinfo.value) == expected_error
else:
# For empty lists, no error should be raised
result = validate_task_inout(input_list, list_type)
assert result == []
def test_validate_task_inout_mixed_list():
"""Test validation with a list of dictionaries."""
input_list = ["task1", {"name": "task2"}]
result = validate_task_inout(input_list, "inputs")
assert result == [{"name": "task1"}, {"name": "task2"}]


@pytest.mark.parametrize(
"input_list, list_type",
[
# Invalid type cases
([1, 2, 3], "input"),
([None, None], "output"),
([True, False], "input"),
(["task", 123], "output"),
([1, 2, 3], "inputs"),
([None, None], "outputs"),
([True, False], "inputs"),
(["task", 123], "outputs"),
],
)
def test_validate_task_inout_invalid_types(input_list, list_type):
"""Test error handling for completely invalid type lists."""
with pytest.raises(TypeError) as excinfo:
validate_task_inout(input_list, list_type)
assert "Provide either a list of" in str(excinfo.value)
assert "Wrong type provided" in str(excinfo.value)


def test_validate_task_inout_dict_with_extra_keys():
Expand All @@ -71,5 +54,5 @@ def test_validate_task_inout_dict_with_extra_keys():
{"name": "task1", "description": "first task"},
{"name": "task2", "priority": "high"},
]
result = validate_task_inout(input_list, "input")
result = validate_task_inout(input_list, "inputs")
assert result == input_list

0 comments on commit f11b84e

Please sign in to comment.