diff --git a/tests/distributed/test_parallel_state.py b/tests/distributed/test_parallel_state.py index 5d293b2c16c4..3adcf6b61046 100644 --- a/tests/distributed/test_parallel_state.py +++ b/tests/distributed/test_parallel_state.py @@ -1,5 +1,6 @@ from typing import Any, Dict +import pytest import torch from vllm.distributed.parallel_state import (_split_tensor_dict, @@ -24,6 +25,14 @@ def test_split_tensor_dict(): assert torch.allclose(tensor_list[2], test_dict["key_c"]["key_2"]) +def test_split_tensor_dict_invalid_key(): + test_dict = { + "a%b": "a", + } + with pytest.raises(AssertionError): + _split_tensor_dict(test_dict) + + def test_update_nested_dict(): flattened_keys_values = [("key1%key2%key3", "value1"), ("key1%key2%key4", "value2"), @@ -31,7 +40,6 @@ def test_update_nested_dict(): ("key8", "value5")] res: Dict[str, Any] = {} - # Update the nested dictionary with each flattened key-value pair for flat_key, value in flattened_keys_values: _update_nested_dict(res, flat_key, value) assert res == { diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 51616cb0fdb4..0c4ee0eb2c04 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -58,6 +58,9 @@ def _split_tensor_dict( metadata_list: List[Tuple[str, Any]] = [] tensor_list = [] for key, value in tensor_dict.items(): + assert "%" not in key, ( + "Avoid having '%' in key " + "as it is used as a separator for nested entries.") if isinstance(value, torch.Tensor): # Note: we cannot use `value.device` here, # because it contains not only the device type but also the device