Inclusion of InternVLChatModel In PP_SUPPORTED_MODELS (Pipeline Parallelism) (vllm-project#7860)

Manikandan-Thangaraj-ZS0321 authored and dtrifiro committed Sep 12, 2024
1 parent d3b718f · commit b7bf797
Showing 6 changed files with 90 additions and 35 deletions.
tests/distributed/test_pipeline_parallel.py (22 additions, 16 deletions)
@@ -18,23 +18,26 @@
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"


@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
"MODEL_NAME, DIST_BACKEND"),
[
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
])
+@pytest.mark.parametrize(
+    ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
+     "MODEL_NAME, DIST_BACKEND"),
+    [
+        (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"),
+    ],
+)
@fork_new_process_for_each_test
-def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
-                    DIST_BACKEND):
+def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
+                    TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND):
if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
@@ -71,6 +74,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
if EAGER_MODE:
pp_args.append("--enforce-eager")
tp_args.append("--enforce-eager")
+    if TRUST_REMOTE_CODE:
+        pp_args.append("--trust-remote-code")
+        tp_args.append("--trust-remote-code")
pp_env = None
if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
and CHUNKED_PREFILL):
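Note: the new TRUST_REMOTE_CODE column in the test matrix is threaded through to the server command line. Below is a minimal sketch (not the committed test code) of how the flags for the InternLM2 case added above are composed; the flag spellings are vLLM's standard CLI options, and the argument lists built by the real test contain additional options.

```python
# Hedged sketch mirroring the new test case (TP=2, PP=2, eager mode,
# chunked prefill, trust-remote-code, ray backend).
TP_SIZE, PP_SIZE = 2, 2
EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE = 1, 1, 1
MODEL_NAME, DIST_BACKEND = "internlm/internlm2_5-7b-chat", "ray"

common_args = [
    "--tensor-parallel-size", str(TP_SIZE),
    "--pipeline-parallel-size", str(PP_SIZE),
    "--distributed-executor-backend", DIST_BACKEND,
]
if CHUNKED_PREFILL:
    common_args.append("--enable-chunked-prefill")
if EAGER_MODE:
    common_args.append("--enforce-eager")
if TRUST_REMOTE_CODE:
    common_args.append("--trust-remote-code")  # the flag this commit threads through
```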
tests/utils.py (6 additions, 1 deletion)
@@ -178,7 +178,12 @@ def compare_two_settings(model: str,
env2: The second set of environment variables to pass to the API server.
"""

-    tokenizer = AutoTokenizer.from_pretrained(model)
+    trust_remote_code = "--trust-remote-code"
+    if trust_remote_code in arg1 or trust_remote_code in arg2:
+        tokenizer = AutoTokenizer.from_pretrained(model,
+                                                  trust_remote_code=True)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model)

prompt = "Hello, my name is"
token_ids = tokenizer(prompt)["input_ids"]
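The branch added above can be read as a single expression; a sketch only, with model, arg1, and arg2 standing in for the parameters of compare_two_settings:

```python
# Equivalent, more compact form of the tokenizer setup (sketch, not the
# committed code): trust remote code iff either server arg list requests it.
from transformers import AutoTokenizer


def load_tokenizer(model: str, arg1: list, arg2: list):
    trust = "--trust-remote-code" in arg1 or "--trust-remote-code" in arg2
    return AutoTokenizer.from_pretrained(model, trust_remote_code=trust)
```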
vllm/config.py (5 additions, 3 deletions)
@@ -35,18 +35,20 @@
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096

_PP_SUPPORTED_MODELS = [
"AquilaModel",
"AquilaForCausalLM",
"AquilaModel",
"DeepseekV2ForCausalLM",
"GPT2LMHeadModel",
"InternLM2ForCausalLM",
"InternLMForCausalLM",
"InternVLChatModel",
"JAISLMHeadModel",
"LlamaForCausalLM",
"LLaMAForCausalLM",
"MistralForCausalLM",
"Phi3ForCausalLM",
"GPT2LMHeadModel",
"MixtralForCausalLM",
"NemotronForCausalLM",
"Phi3ForCausalLM",
"Qwen2ForCausalLM",
"Qwen2MoeForCausalLM",
"QWenLMHeadModel",
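For context, vllm/config.py gates pipeline parallelism on this allow-list. The sketch below paraphrases that guard (the function name and argument names are illustrative, not the committed code); with "InternVLChatModel" in the list, InternVL checkpoints now pass the gate when pipeline_parallel_size > 1.

```python
# Paraphrased sketch of the architecture check that consults
# _PP_SUPPORTED_MODELS (assumed to be the list shown above).
def verify_pp_support(architectures: list, pipeline_parallel_size: int) -> None:
    if pipeline_parallel_size > 1 and not all(
            arch in _PP_SUPPORTED_MODELS for arch in architectures):
        raise NotImplementedError(
            f"Pipeline parallelism is not supported for {architectures}; "
            f"supported architectures: {_PP_SUPPORTED_MODELS}")
```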
vllm/model_executor/models/internlm2.py (38 additions, 14 deletions)
@@ -1,14 +1,14 @@
# -*- coding: utf-8 -*-
from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
@@ -28,6 +28,9 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)
+

class InternLM2MLP(nn.Module):

@@ -234,6 +237,7 @@ def __init__(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -243,11 +247,15 @@ def __init__(
config.vocab_size,
config.hidden_size,
)
-        self.layers = nn.ModuleList([
-            InternLMDecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: InternLMDecoderLayer(config, cache_config,
+                                                quant_config),
+            prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids)
@@ -260,21 +268,31 @@ def forward(
attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors = None,
inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.tok_embeddings(input_ids)
+            residual = None
else:
-            hidden_states = self.tok_embeddings(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

@@ -298,6 +316,8 @@ def __init__(
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)

def forward(
self,
@@ -308,7 +328,7 @@ def forward(
intermediate_tensors: IntermediateTensors,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(
@@ -345,6 +365,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
+                if is_pp_missing_parameter(name, self):
+                    continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@@ -353,6 +375,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
+                if is_pp_missing_parameter(name, self):
+                    continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
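The core of the internlm2.py change is the switch from a full nn.ModuleList to make_layers, which gives each pipeline rank a contiguous slice [start_layer, end_layer) of the decoder stack; because the per-rank kv_caches list covers only the layers that rank owns, the loop indexes it with kv_caches[i - self.start_layer]. A simplified sketch of that partitioning (assumes the layer count divides evenly; the real helper is more general):

```python
def split_layers(num_hidden_layers: int, pp_rank: int, pp_size: int):
    # Simplified sketch of the per-rank layer slice that make_layers sets up.
    per_rank = num_hidden_layers // pp_size
    start_layer = pp_rank * per_rank
    return start_layer, start_layer + per_rank


# e.g. 32 InternLM2 decoder layers on 4 pipeline ranks: rank 2 owns [16, 24)
assert split_layers(32, 2, 4) == (16, 24)
```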
vllm/model_executor/models/internvl.py (3 additions, 1 deletion)
@@ -341,6 +341,8 @@ def __init__(self,
nn.Linear(llm_hidden_size, llm_hidden_size))

self.img_context_token_id = None
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors)

def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
@@ -461,7 +463,7 @@ def forward(
positions,
kv_caches,
attn_metadata,
-                                                  None,
+                                                  intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states

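The internvl.py change is small but load-bearing: the multimodal wrapper reuses the language model's make_empty_intermediate_tensors and forwards the received intermediate_tensors instead of a hard-coded None, so non-first pipeline ranks resume from the hidden states sent over the pipeline. A schematic sketch (class and attribute names simplified, not the committed InternVLChatModel):

```python
class MultimodalPPWrapperSketch:
    """Schematic only: the wrapper delegates PP plumbing to its language model."""

    def __init__(self, language_model):
        self.language_model = language_model
        self.make_empty_intermediate_tensors = (
            language_model.make_empty_intermediate_tensors)

    def forward(self, input_ids, positions, kv_caches, attn_metadata,
                intermediate_tensors, inputs_embeds=None):
        # First rank: inputs_embeds carries the merged text+vision embeddings.
        # Later ranks: the language model reads from intermediate_tensors.
        return self.language_model.model(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds=inputs_embeds)
```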
vllm/model_executor/models/utils.py (16 additions, 0 deletions)
@@ -12,6 +12,7 @@
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors
+from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available


@@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
if name.startswith(missing_layer_name):
return True
return False


+def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
+
+    def make_empty_intermediate_tensors(
+            batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            key: torch.zeros((batch_size, hidden_size),
+                             dtype=dtype,
+                             device=device)
+            for key in keys
+        })
+
+    return make_empty_intermediate_tensors
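A brief usage sketch of the new helper: a model wires the factory up once in __init__ (as internlm2.py now does), and the returned callable pre-allocates the tensors a non-first pipeline rank expects to receive.

```python
import torch
from vllm.model_executor.models.utils import (
    make_empty_intermediate_tensors_factory)

# Usage sketch: placeholder pipeline buffers for a hidden size of 4096.
make_empty = make_empty_intermediate_tensors_factory(
    ["hidden_states", "residual"], 4096)
tensors = make_empty(batch_size=8, dtype=torch.float16,
                     device=torch.device("cpu"))
assert tensors["hidden_states"].shape == (8, 4096)
```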
