diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 633bc5ca95bf..f072a26e8b74 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -90,13 +90,13 @@ steps:
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
-  - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
+  - TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
-  - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
+  - TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
diff --git a/examples/llava_example.py b/examples/llava_example.py
index 4c9eabd261e5..017faf7795d8 100644
--- a/examples/llava_example.py
+++ b/examples/llava_example.py
@@ -1,20 +1,36 @@
 from vllm import LLM
 from vllm.assets.image import ImageAsset
+from PIL import Image
+
+
+def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
+    """Rescale the dimensions of an image by a constant factor."""
+    new_width = int(image.width * size_factor)
+    new_height = int(image.height * size_factor)
+    return image.resize((new_width, new_height))
+
 
 
 def run_llava():
-    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
+    llm = LLM(
+        model="llava-hf/llava-v1.6-mistral-7b-hf")  # , tensor_parallel_size=2)
 
     prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
 
     image = ImageAsset("stop_sign").pil_image
 
-    outputs = llm.generate({
+    # Show images of different resolutions in a batch.
+    outputs = llm.generate([{
         "prompt": prompt,
         "multi_modal_data": {
             "image": image
-        },
-    })
+        }
+    }, {
+        "prompt": prompt,
+        "multi_modal_data": {
+            "image": rescale_image_size(image, 0.25)
+        }
+    }])
 
     for o in outputs:
         generated_text = o.outputs[0].text
diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py
index 8e0e8ecd675e..a99917f58694 100644
--- a/tests/distributed/test_multimodal_broadcast.py
+++ b/tests/distributed/test_multimodal_broadcast.py
@@ -19,10 +19,10 @@
 
 model = os.environ["TEST_DIST_MODEL"]
 
-if model.startswith("llava-hf/llava"):
+if model.startswith("llava-hf/llava-1.5"):
     from ..models.test_llava import models, run_test
-elif model.startswith("microsoft/Phi-3-vision"):
-    from ..models.test_phi3v import models, run_test
+elif model.startswith("llava-hf/llava-v1.6"):
+    from ..models.test_llava_next import models, run_test
 else:
     raise NotImplementedError(f"Unsupported model: {model}")
 
@@ -45,7 +45,8 @@ def test_models(hf_runner, vllm_runner, image_assets,
         vllm_runner,
         image_assets,
         model=models[0],
-        size_factors=[1.0],
+        # So that LLaVA-NeXT processor may return nested list
+        size_factors=[0.25, 0.5, 1.0],
         dtype=dtype,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
diff --git a/tests/distributed/test_parallel_state.py b/tests/distributed/test_parallel_state.py
index 3adcf6b61046..29444eac7e74 100644
--- a/tests/distributed/test_parallel_state.py
+++ b/tests/distributed/test_parallel_state.py
@@ -9,20 +9,28 @@
 
 def test_split_tensor_dict():
     test_dict = {
-        "key_a": "a",
-        "key_b": torch.arange(8, dtype=torch.float32),
+        "key_a":
+        "a",
+        "key_b":
+        torch.arange(8, dtype=torch.float32),
         "key_c": {
             "key_1": torch.arange(5, dtype=torch.float32),
             "key_2": torch.tensor([], dtype=torch.float32),
             "key_3": 123,
         },
         "key_d": {},
+        "key_e": [
+            torch.arange(11, dtype=torch.float32),
+            torch.arange(13, dtype=torch.float32)
+        ]
     }
     metadata_list, tensor_list = _split_tensor_dict(test_dict)
-    assert len(metadata_list) == 6
+    assert len(metadata_list) == 7
     assert torch.allclose(tensor_list[0], test_dict["key_b"])
     assert torch.allclose(tensor_list[1], test_dict["key_c"]["key_1"])
     assert torch.allclose(tensor_list[2], test_dict["key_c"]["key_2"])
+    assert torch.allclose(tensor_list[3], test_dict["key_e"][0])
+    assert torch.allclose(tensor_list[4], test_dict["key_e"][1])
 
 
 def test_split_tensor_dict_invalid_key():
diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py
index 2f200c13ea00..9c64f39eb6d0 100644
--- a/tests/models/test_llava_next.py
+++ b/tests/models/test_llava_next.py
@@ -1,14 +1,12 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Type
 
 import pytest
 from transformers import AutoConfig, AutoTokenizer
 
-from vllm.model_executor.models.llava_next import (
-    get_llava_next_image_feature_size)
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
 
-from ..conftest import IMAGE_ASSETS
+from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
 from .utils import check_logprobs_close
 
 pytestmark = pytest.mark.vlm
@@ -27,6 +25,8 @@
 
 IMAGE_TOKEN_ID = 32000
 
+models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]
+
 
 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                          Optional[SampleLogprobs]],
@@ -50,34 +50,19 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
     return hf_output_ids, hf_output_str, out_logprobs
 
 
-@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-vicuna-7b-hf"])
-@pytest.mark.parametrize(
-    "size_factors",
-    [
-        # No image
-        [],
-        # Single-scale
-        [1.0],
-        # Single-scale, batched
-        [1.0, 1.0, 1.0],
-        # Multi-scale
-        [0.25, 0.5, 1.0],
-    ],
-)
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
-                dtype, max_tokens, num_logprobs) -> None:
-    """Inference result should be the same between hf and vllm.
-
-    All the image fixtures for the test is under tests/images.
-    For huggingface runner, we provide the PIL images as input.
-    For vllm runner, we provide MultiModalDataDict objects
-    and corresponding vision language config as input.
-    Note, the text input is also adjusted to abide by vllm contract.
-    The text output is sanitized to be able to compare with hf.
-    """
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
     images = [asset.pil_image for asset in image_assets]
 
     inputs_per_image = [(
@@ -89,6 +74,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     with vllm_runner(model,
                      dtype=dtype,
                      max_model_len=4096,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=True) as vllm_model:
         vllm_outputs_per_image = [
             vllm_model.generate_greedy_logprobs(prompts,
@@ -122,9 +109,54 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     )
 
 
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
+                dtype, max_tokens, num_logprobs) -> None:
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test are under tests/images.
+    For the huggingface runner, we provide the PIL images as input.
+    For the vllm runner, we provide MultiModalDataDict objects
+    and the corresponding vision language config as input.
+    Note, the text input is also adjusted to abide by the vllm contract.
+    The text output is sanitized so that it can be compared with hf.
+ """ + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors=size_factors, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + @pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144), (183, 488, 776)]) def test_image_feature_size(height_and_width_and_result): + # Avoid initializing CUDA too early in distributed tests + from vllm.model_executor.models.llava_next import ( + get_llava_next_image_feature_size) + height, width, result = height_and_width_and_result config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") assert get_llava_next_image_feature_size(config, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e9c6fc3a255e..11e91cacbbe5 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -74,10 +74,29 @@ def _split_tensor_dict( elif isinstance(value, dict): if len(value) == 0: metadata_list.append((prefix + key, value)) - inner_metadata_list, inner_tensor_list = _split_tensor_dict( - value, prefix + key + "%") - metadata_list.extend(inner_metadata_list) - tensor_list.extend(inner_tensor_list) + else: + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%") + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + elif isinstance(value, list): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + elif isinstance(value[0], torch.Tensor): + # should all be Tensors + metadata_list_value = [] + for v in value: + assert isinstance(v, torch.Tensor) + metadata_list_value.append( + TensorMetadata(v.device.type, v.dtype, v.size())) + tensor_list.append(v) + metadata_list.append((prefix + key, metadata_list_value)) + else: + # no nested nested list, only primitive types allowed if not Tensor + assert not any( + isinstance(v, list) or isinstance(v, torch.Tensor) + for v in value) + metadata_list.append((prefix + key, value)) else: metadata_list.append((prefix + key, value)) return metadata_list, tensor_list @@ -561,6 +580,26 @@ def broadcast_tensor_dict( async_op=True) async_handles.append(handle) _update_nested_dict(tensor_dict, key, tensor) + elif isinstance(value, list) and len(value) > 0 and isinstance( + value[0], TensorMetadata): + tensor_list = [] + for t in value: + tensor = torch.empty(t.size, + dtype=t.dtype, + device=t.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_list.append(tensor) + else: + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group + if tensor.is_cpu else group, + async_op=True) + async_handles.append(handle) + tensor_list.append(tensor) + _update_nested_dict(tensor_dict, key, tensor_list) else: _update_nested_dict(tensor_dict, key, value) for async_handle in async_handles: @@ -651,6 +690,21 @@ def recv_tensor_dict( src=self.ranks[src], group=group) _update_nested_dict(tensor_dict, key, tensor) + elif isinstance(value, list) and len(value) > 0 and isinstance( + value[0], TensorMetadata): + tensor_list = [] + for t in value: + tensor = torch.empty(t.size, + dtype=t.dtype, + device=t.device) + if tensor.numel() == 0: + tensor_list.append(tensor) + else: + torch.distributed.recv( + tensor, + src=self.ranks[src], + group=metadata_group if tensor.is_cpu else group) + _update_nested_dict(tensor_dict, key, tensor_list) else: _update_nested_dict(tensor_dict, key, value) return tensor_dict