Skip to content

Commit

Permalink
[Bugfix] Fix InternVL2 vision embeddings process with pipeline parall…
Browse files Browse the repository at this point in the history
  • Loading branch information
Isotr0py authored and Jeffwan committed Sep 19, 2024
1 parent 6570f8f commit bafac13
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
10 changes: 8 additions & 2 deletions tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
(1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "ray"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "ray"),
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "ray"),
],
)
@fork_new_process_for_each_test
Expand All @@ -46,6 +48,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"8192",
"--pipeline-parallel-size",
str(PP_SIZE),
"--tensor-parallel-size",
Expand All @@ -62,7 +66,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
tp_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"float16",
"--max-model-len",
"8192",
"--tensor-parallel-size",
str(max(TP_SIZE, 2)), # We only use 2 GPUs in the CI.
"--distributed-executor-backend",
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
Expand Down Expand Up @@ -480,7 +481,7 @@ def forward(
**kwargs: object,
) -> SamplerOutput:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
if image_input is not None and get_pp_group().is_first_rank:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
Expand Down

0 comments on commit bafac13

Please sign in to comment.