[VLM][Bugfix] Make sure that multi_modal_kwargs is broadcasted properly #5880

Merged (1 commit) on Jun 27, 2024
4 changes: 3 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -24,7 +24,9 @@ steps:
 
 - label: Core Test
   mirror_hardwares: [amd]
-  command: pytest -v -s core
+  commands:
+  - pytest -v -s core
+  - pytest -v -s distributed/test_parallel_state.py
 
 - label: Distributed Comm Ops Test
   #mirror_hardwares: [amd]
49 changes: 49 additions & 0 deletions tests/distributed/test_parallel_state.py
@@ -0,0 +1,49 @@
from typing import Any, Dict

import torch

from vllm.distributed.parallel_state import (_split_tensor_dict,
                                             _update_nested_dict)


def test_split_tensor_dict():
    test_dict = {
        "key_a": "a",
        "key_b": torch.arange(8, dtype=torch.float32),
        "key_c": {
            "key_1": torch.arange(5, dtype=torch.float32),
            "key_2": torch.tensor([], dtype=torch.float32),
            "key_3": 123,
        },
        "key_d": {},
    }
    metadata_list, tensor_list = _split_tensor_dict(test_dict)
    assert len(metadata_list) == 6
    assert torch.allclose(tensor_list[0], test_dict["key_b"])
    assert torch.allclose(tensor_list[1], test_dict["key_c"]["key_1"])
    assert torch.allclose(tensor_list[2], test_dict["key_c"]["key_2"])


def test_update_nested_dict():
    flattened_keys_values = [("key1%key2%key3", "value1"),
                             ("key1%key2%key4", "value2"),
                             ("key1%key5", "value3"), ("key6%key7", "value4"),
                             ("key8", "value5")]
    res: Dict[str, Any] = {}

    # Update the nested dictionary with each flattened key-value pair.
    for flat_key, value in flattened_keys_values:
        _update_nested_dict(res, flat_key, value)
    assert res == {
        "key1": {
            "key2": {
                "key3": "value1",
                "key4": "value2"
            },
            "key5": "value3"
        },
        "key6": {
            "key7": "value4"
        },
        "key8": "value5"
    }
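
For reference, a short sketch of what the first test is checking: `_split_tensor_dict` (a private helper, so this illustration follows the diff below rather than a public API) replaces every tensor with a `TensorMetadata` entry and flattens nested keys with the `%` separator, which is why `test_dict` yields exactly six metadata pairs and three tensors. The sample dict here is a made-up example:

# Illustrative sketch only; _split_tensor_dict and its "%" key scheme are
# taken from this PR's diff, while the sample dict is hypothetical.
import torch

from vllm.distributed.parallel_state import _split_tensor_dict

nested = {
    "pixel_values": torch.zeros(2, 3),      # top-level tensor
    "image_meta": {"size": torch.ones(2)},  # tensor nested one level down
    "note": "hello",                        # plain value, kept inline
}
metadata_list, tensor_list = _split_tensor_dict(nested)

# Tensors are replaced by TensorMetadata(device, dtype, size) entries, and
# nested keys are joined with "%":
#   ("pixel_values",    TensorMetadata("cpu", torch.float32, (2, 3)))
#   ("image_meta%size", TensorMetadata("cpu", torch.float32, (2,)))
#   ("note",            "hello")
assert [k for k, _ in metadata_list] == [
    "pixel_values", "image_meta%size", "note"
]
assert len(tensor_list) == 2  # the two tensors, in traversal order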
37 changes: 29 additions & 8 deletions vllm/distributed/parallel_state.py
@@ -43,14 +43,17 @@ class GraphCaptureContext:
 
 
 def _split_tensor_dict(
-    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
-) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+    prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
        by its metadata.
     2. A list of tensors.
+
+    If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
+    metadata will be "key1%key2".
     """
-    metadata_list = []
+    metadata_list: List[Tuple[str, Any]] = []
     tensor_list = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
@@ -60,13 +63,31 @@ def _split_tensor_dict(
             # receiving side will set the device index.
             device = value.device.type
             metadata_list.append(
-                (key, TensorMetadata(device, value.dtype, value.size())))
+                (prefix + key, TensorMetadata(device, value.dtype,
+                                              value.size())))
             tensor_list.append(value)
+        elif isinstance(value, dict):
+            if len(value) == 0:
+                metadata_list.append((prefix + key, value))
+            inner_metadata_list, inner_tensor_list = _split_tensor_dict(
+                value, prefix + key + "%")
+            metadata_list.extend(inner_metadata_list)
+            tensor_list.extend(inner_tensor_list)
         else:
-            metadata_list.append((key, value))
+            metadata_list.append((prefix + key, value))
     return metadata_list, tensor_list
 
 
+def _update_nested_dict(nested_dict, flattened_key, value):
+    key_splits = flattened_key.split("%")
+    cur_dict = nested_dict
+    for k in key_splits[:-1]:
+        if k not in cur_dict:
+            cur_dict[k] = {}
+        cur_dict = cur_dict[k]
+    cur_dict[key_splits[-1]] = value
+
+
 class GroupCoordinator:
     """
     PyTorch ProcessGroup wrapper for a group of processes.
@@ -410,7 +431,7 @@ def broadcast_tensor_dict(
                                          device=value.device)
                     if tensor.numel() == 0:
                         # Skip broadcasting empty tensors.
-                        tensor_dict[key] = tensor
+                        _update_nested_dict(tensor_dict, key, tensor)
                         continue
                     if tensor.is_cpu:
                         # use metadata_group for CPU tensors
@@ -426,9 +447,9 @@ def broadcast_tensor_dict(
                             group=group,
                             async_op=True)
                         async_handles.append(handle)
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                 else:
-                    tensor_dict[key] = value
+                    _update_nested_dict(tensor_dict, key, value)
             for async_handle in async_handles:
                 async_handle.wait()
             return tensor_dict
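
Taken together, the two helpers give `broadcast_tensor_dict` a round trip that survives nesting: the source rank flattens the dict before sending, and each receiver rebuilds the structure from the `%`-joined keys via `_update_nested_dict`, which is what `multi_modal_kwargs` (a nested dict of tensors) needed. Below is a minimal single-process sketch of that round trip, with the actual collective calls elided, a hypothetical payload standing in for `multi_modal_kwargs`, and `TensorMetadata` assumed importable from the same module the diff edits:

# Single-process sketch of the flatten/rebuild round trip; the real path
# moves metadata via an object broadcast and each tensor via
# torch.distributed.broadcast, both elided here.
import torch

from vllm.distributed.parallel_state import (TensorMetadata,
                                             _split_tensor_dict,
                                             _update_nested_dict)

# Hypothetical stand-in for multi_modal_kwargs.
sent = {"multi_modal_kwargs": {"pixel_values": torch.randn(1, 3, 8, 8)}}

# Sender side: picklable metadata plus a flat list of tensors.
metadata_list, tensor_list = _split_tensor_dict(sent)

# Receiver side: rebuild the nesting key by key, as the updated
# broadcast_tensor_dict now does.
received: dict = {}
tensors = iter(tensor_list)
for key, value in metadata_list:
    if isinstance(value, TensorMetadata):
        # The real code allocates with torch.empty and fills it from the
        # wire; this sketch simply reuses the sender's tensor.
        _update_nested_dict(received, key, next(tensors))
    else:
        _update_nested_dict(received, key, value)

assert torch.equal(received["multi_modal_kwargs"]["pixel_values"],
                   sent["multi_modal_kwargs"]["pixel_values"])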