diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index 3f7731b7..8a933161 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -768,18 +768,22 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di if the former convert them to a list of `dict`s with `name` as the key. :param inout_list: The input/output list to be validated. - :param list_type: "input" or "output" to indicate what is to be validated. - :raises TypeError: If a list of mixed or wrong types is provided to the task + :param list_type: "input" or "output" to indicate what is to be validated for better error message. + :raises TypeError: If wrong types are provided to the task :return: Processed `inputs`/`outputs` list. """ - if all(isinstance(item, str) for item in inout_list): - return [{"name": item} for item in inout_list] - elif all(isinstance(item, dict) for item in inout_list): - return inout_list - elif not all(isinstance(item, dict) for item in inout_list): + if not all(isinstance(item, (dict, str)) for item in inout_list): raise TypeError( - f"Provide either a list of `str` or `dict` as `{list_type}`, not mixed types." + f"Wrong type provided in the `{list_type}` list to the task, must be either `str` or `dict`." ) - else: - raise TypeError(f"Wrong type provided in the `{list_type}` list to the task.") + + processed_inout_list = [] + + for item in inout_list: + if isinstance(item, str): + processed_inout_list.append({"name": item}) + elif isinstance(item, dict): + processed_inout_list.append(item) + + return processed_inout_list diff --git a/tests/test_utils.py b/tests/test_utils.py index 8e42db34..8ecaaab6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,66 +3,49 @@ from aiida_workgraph.utils import validate_task_inout +def test_validate_task_inout_empty_list(): + """Test validation with a list of strings.""" + input_list = [] + result = validate_task_inout(input_list, "inputs") + assert result == [] + + def test_validate_task_inout_str_list(): """Test validation with a list of strings.""" input_list = ["task1", "task2"] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == [{"name": "task1"}, {"name": "task2"}] def test_validate_task_inout_dict_list(): """Test validation with a list of dictionaries.""" input_list = [{"name": "task1"}, {"name": "task2"}] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == input_list -@pytest.mark.parametrize( - "input_list, list_type, expected_error", - [ - # Mixed types error cases - ( - ["task1", {"name": "task2"}], - "input", - "Provide either a list of `str` or `dict` as `input`, not mixed types.", - ), - ( - [{"name": "task1"}, "task2"], - "output", - "Provide either a list of `str` or `dict` as `output`, not mixed types.", - ), - # Empty list cases - ([], "input", None), - ([], "output", None), - ], -) -def test_validate_task_inout_mixed_types(input_list, list_type, expected_error): - """Test error handling for mixed type lists.""" - if expected_error: - with pytest.raises(TypeError) as excinfo: - validate_task_inout(input_list, list_type) - assert str(excinfo.value) == expected_error - else: - # For empty lists, no error should be raised - result = validate_task_inout(input_list, list_type) - assert result == [] +def test_validate_task_inout_mixed_list(): + """Test validation with a list of dictionaries.""" + input_list = ["task1", {"name": "task2"}] + result = validate_task_inout(input_list, "inputs") + assert result == [{"name": "task1"}, {"name": "task2"}] @pytest.mark.parametrize( "input_list, list_type", [ # Invalid type cases - ([1, 2, 3], "input"), - ([None, None], "output"), - ([True, False], "input"), - (["task", 123], "output"), + ([1, 2, 3], "inputs"), + ([None, None], "outputs"), + ([True, False], "inputs"), + (["task", 123], "outputs"), ], ) def test_validate_task_inout_invalid_types(input_list, list_type): """Test error handling for completely invalid type lists.""" with pytest.raises(TypeError) as excinfo: validate_task_inout(input_list, list_type) - assert "Provide either a list of" in str(excinfo.value) + assert "Wrong type provided" in str(excinfo.value) def test_validate_task_inout_dict_with_extra_keys(): @@ -71,5 +54,5 @@ def test_validate_task_inout_dict_with_extra_keys(): {"name": "task1", "description": "first task"}, {"name": "task2", "priority": "high"}, ] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == input_list