[Model] Pipeline Parallel Support for DeepSeek v2 (vllm-project#6519)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
tjohnson31415 authored and kylesayrs committed Aug 17, 2024
1 parent cec589c commit 0ec9a60
Showing 2 changed files with 115 additions and 39 deletions.
1 change: 1 addition & 0 deletions vllm/config.py
@@ -31,6 +31,7 @@
_PP_SUPPORTED_MODELS = [
"AquilaModel",
"AquilaForCausalLM",
"DeepseekV2ForCausalLM",
"InternLMForCausalLM",
"LlamaForCausalLM",
"LLaMAForCausalLM",
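With DeepseekV2ForCausalLM added to the _PP_SUPPORTED_MODELS allow-list, vLLM no longer rejects a pipeline-parallel launch for this architecture. A minimal sketch of how such a launch might be configured is shown below; the model id, parallel sizes, and executor backend are illustrative assumptions rather than part of this change, and the exact entrypoint requirements for pipeline parallelism depend on the vLLM version.

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Illustrative settings only: two pipeline stages, no tensor parallelism.
# The allow-list check relaxed by this config.py change is what previously
# blocked pipeline_parallel_size > 1 for DeepSeek-V2.
engine_args = AsyncEngineArgs(
    model="deepseek-ai/DeepSeek-V2-Lite",   # assumed model id
    trust_remote_code=True,
    tensor_parallel_size=1,
    pipeline_parallel_size=2,
    distributed_executor_backend="ray",     # assumed backend for PP
)
engine = AsyncLLMEngine.from_engine_args(engine_args)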
153 changes: 114 additions & 39 deletions vllm/model_executor/models/deepseek_v2.py
@@ -29,7 +29,8 @@

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -49,6 +50,8 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput

from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers


class DeepseekV2MLP(nn.Module):

@@ -59,17 +62,20 @@ def __init__(
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results)
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -88,6 +94,7 @@ def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -112,12 +119,14 @@ def __init__(
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group)
topk_group=config.topk_group,
prefix=f"{prefix}.experts")

self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts,
bias=False,
quant_config=None)
quant_config=None,
prefix=f"{prefix}.gate")
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
@@ -172,10 +181,9 @@ def __init__(
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
layer_idx=None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
@@ -195,38 +203,44 @@ def __init__(
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj")
else:
self.q_proj = ColumnParallelLinear(self.hidden_size,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.q_proj")

self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
self.kv_lora_rank +
self.qk_rope_head_dim,
bias=False,
quant_config=quant_config)
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj")
# O projection.
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling['type'] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
@@ -308,7 +322,7 @@ class DeepseekV2DecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
prefix: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
@@ -318,6 +332,9 @@ def __init__(
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.self_attn = DeepseekV2Attention(
config=config,
hidden_size=self.hidden_size,
@@ -333,18 +350,23 @@ def __init__(
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
layer_idx=layer_idx,
prefix=f"{prefix}.self_attn",
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
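The layer index is no longer passed explicitly; the decoder layer recovers it from the dotted prefix that make_layers supplies. A standalone illustration (the prefix value here is hypothetical):

# make_layers names each child "<parent prefix>.layers.<idx>"; splitting on '.'
# and taking the last component yields the integer layer index.
prefix = "model.layers.7"                  # example prefix, not from the diff
layer_idx = int(prefix.split(sep='.')[-1])
assert layer_idx == 7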
@@ -389,38 +411,63 @@ def __init__(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
DeepseekV2DecoderLayer(config,
layer_idx,
cache_config=cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
else:
self.embed_tokens = PPMissingLayer()

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(
config,
prefix,
cache_config=cache_config,
quant_config=quant_config,
),
prefix=f"{prefix}.layers")

if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata,
residual)
kv_caches[i - self.start_layer],
attn_metadata, residual)

if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})

hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
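
The model now builds only the decoder layers owned by the local pipeline rank (via make_layers) and replaces off-rank modules such as embed_tokens and norm with PPMissingLayer placeholders. The helper below is a simplified, illustrative stand-in for that partitioning, not vLLM's actual make_layers; the layer count and rank values are made up for the example.

import torch.nn as nn

def split_layers(num_layers: int, pp_rank: int, pp_size: int):
    # Contiguous, roughly even split of [0, num_layers) across pipeline ranks;
    # the last rank absorbs any remainder. vLLM's make_layers additionally
    # threads a "model.layers.<idx>" prefix into each constructed layer.
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end

class MissingLayer(nn.Module):
    # Placeholder for modules that live on another pipeline rank, in the
    # spirit of PPMissingLayer: it owns no parameters, so weight loading
    # can simply skip it.
    def forward(self, *args, **kwargs):
        raise RuntimeError("module is owned by a different pipeline rank")

# With 27 decoder layers and two pipeline stages, the last rank holds 13..26.
print(split_layers(num_layers=27, pp_rank=1, pp_size=2))  # (13, 27)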

@@ -436,7 +483,10 @@ def __init__(
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = DeepseekV2Model(config, cache_config, quant_config)
self.model = DeepseekV2Model(config,
cache_config,
quant_config,
prefix="model")
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
@@ -452,7 +502,7 @@ def forward(
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(self, hidden_states: torch.Tensor,
@@ -469,6 +519,20 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -504,6 +568,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@@ -514,6 +582,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
@@ -527,6 +599,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if name.endswith(".bias") and name not in params_dict:
continue

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
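load_weights now drops checkpoint tensors that belong to layers outside this rank's [start_layer, end_layer) range. Below is a rough, standalone illustration of that filter; the function name, regex, and partition are assumptions for the sketch, not vLLM's actual is_pp_missing_parameter, which also accounts for the embedding and final-norm placeholders.

import re

def is_off_rank(param_name: str, start_layer: int, end_layer: int) -> bool:
    # Non-layer parameters (embeddings, final norm, lm_head) are kept here
    # for simplicity; decoder-layer parameters are kept only when their
    # index falls inside the local range.
    match = re.match(r"model\.layers\.(\d+)\.", param_name)
    if match is None:
        return False
    layer_idx = int(match.group(1))
    return not (start_layer <= layer_idx < end_layer)

# Example: a rank owning layers [13, 27) skips a layer-3 weight.
print(is_off_rank("model.layers.3.mlp.gate.weight", 13, 27))            # True
print(is_off_rank("model.layers.20.self_attn.o_proj.weight", 13, 27))   # False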
