Add fp16 support for Qwen1.5-MoE models (A2.7B) to DeepSpeed-FastGen
ZonePG committed Apr 12, 2024
1 parent a8b8215 commit b458df5
Showing 9 changed files with 508 additions and 1 deletion.
2 changes: 2 additions & 0 deletions blogs/deepspeed-fastgen/README.md
@@ -232,6 +232,8 @@ We currently support the following model architectures in this alpha release of
* [Mixtral](https://huggingface.co/models?other=mixtral)
* [Phi-2](https://huggingface.co/models?other=phi-msft)
* [Qwen](https://huggingface.co/models?other=qwen)
* [Qwen2](https://huggingface.co/models?other=qwen2)
* [Qwen2-MoE](https://huggingface.co/models?other=qwen2_moe)

All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.

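As a usage illustration (not part of this diff), the newly listed checkpoints can be exercised through DeepSpeed-MII's non-persistent pipeline; the checkpoint name `Qwen/Qwen1.5-MoE-A2.7B`, the prompts, and the generation settings below are assumptions for the sketch:

```python
import mii

# Load the Qwen1.5-MoE checkpoint from the HuggingFace Hub via DeepSpeed-FastGen.
pipe = mii.pipeline("Qwen/Qwen1.5-MoE-A2.7B")

# Generate short completions; the prompts and max_new_tokens are illustrative values.
responses = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=128)
print(responses)
```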
3 changes: 3 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
@@ -22,6 +22,7 @@
PhiPolicy,
QwenPolicy,
Qwen2Policy,
Qwen2MoePolicy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata
@@ -123,6 +124,8 @@ def build_hf_engine(path: str,
policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "qwen2":
policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "qwen2_moe":
policy = Qwen2MoePolicy(model_config, checkpoint_engine=checkpoint_engine)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")

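The new branch keys off the HuggingFace config's `model_type` field, so a Qwen1.5-MoE checkpoint now resolves to `Qwen2MoePolicy`. A quick way to check which branch a given checkpoint will take (the checkpoint name below is illustrative):

```python
from transformers import AutoConfig

# build_hf_engine dispatches on config.model_type; Qwen1.5-MoE checkpoints
# report "qwen2_moe", which the new branch maps to Qwen2MoePolicy.
config = AutoConfig.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")
print(config.model_type)  # expected: "qwen2_moe"
```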
@@ -11,5 +11,8 @@
} else if (2 == N_TOP_K) { \
constexpr int CONST_TOP_K = 2; \
__VA_ARGS__(); \
} else if (4 == N_TOP_K) { \
constexpr int CONST_TOP_K = 4; \
__VA_ARGS__(); \
} \
}()
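The added `CONST_TOP_K = 4` branch in the top-k gating dispatch above is needed because Qwen1.5-MoE routes each token to 4 of its 60 routed experts, whereas the previously supported models only exercised the top-1 and top-2 paths. One way to confirm the routing hyperparameters from the HuggingFace config (the checkpoint name is illustrative):

```python
from transformers import AutoConfig

# Qwen1.5-MoE gating picks 4 of 60 routed experts per token, which is why the
# compile-time top-k dispatch gains a CONST_TOP_K = 4 branch.
config = AutoConfig.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")
print(config.num_experts)          # 60 routed experts
print(config.num_experts_per_tok)  # top-4 routing
```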
1 change: 1 addition & 0 deletions deepspeed/inference/v2/model_implementations/__init__.py
@@ -17,3 +17,4 @@
from .phi import *
from .qwen import *
from .qwen_v2 import *
from .qwen_v2_moe import *
6 changes: 6 additions & 0 deletions deepspeed/inference/v2/model_implementations/qwen_v2_moe/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .policy import Qwen2MoePolicy
103 changes: 103 additions & 0 deletions deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py
@@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Create a container object to save model-specific tensors using the policy file above.

from ..common_parameters import *
from ..layer_container_base import LayerContainer
'''
# HF Qwen1.5-MoE-A2.7B model looks like this:
Qwen2MoeForCausalLM(
(model): Qwen2MoeModel(
(embed_tokens): Embedding(151936, 2048)
(layers): ModuleList(
(0-23): 24 x Qwen2MoeDecoderLayer(
(self_attn): Qwen2MoeSdpaAttention(
(q_proj): Linear(in_features=2048, out_features=2048, bias=True)
(k_proj): Linear(in_features=2048, out_features=2048, bias=True)
(v_proj): Linear(in_features=2048, out_features=2048, bias=True)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(rotary_emb): Qwen2MoeRotaryEmbedding()
)
(mlp): Qwen2MoeSparseMoeBlock(
(gate): Linear(in_features=2048, out_features=60, bias=False)
(experts): ModuleList(
(0-59): 60 x Qwen2MoeMLP(
(gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
(up_proj): Linear(in_features=2048, out_features=1408, bias=False)
(down_proj): Linear(in_features=1408, out_features=2048, bias=False)
(act_fn): SiLU()
)
)
(shared_expert): Qwen2MoeMLP(
(gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
(up_proj): Linear(in_features=2048, out_features=5632, bias=False)
(down_proj): Linear(in_features=5632, out_features=2048, bias=False)
(act_fn): SiLU()
)
(shared_expert_gate): Linear(in_features=2048, out_features=1, bias=False)
)
(input_layernorm): Qwen2MoeRMSNorm()
(post_attention_layernorm): Qwen2MoeRMSNorm()
)
)
(norm): Qwen2MoeRMSNorm()
)
(lm_head): Linear(in_features=2048, out_features=151936, bias=False)
)
'''


class Qwen2MoeTransformerContainer(LayerContainer):
"""
Transformer layer container for the Qwen2Moe model.
"""
qkv_w: UnfusedQKVParameter
qkv_b: UnfusedQKVParameter
attn_out_w: AttentionOutputParameter
moe_gate: MoEGatingWeightParameter
moe_mlp_1: UnfusedMoEGatedMLPParameter
moe_mlp_2: UnfusedMoEMLP2Parameter
shared_moe_mlp_1: GatedMLPParameter
shared_moe_mlp_2: MLP2Parameter
shared_moe_gate: MoEGatingWeightParameter
attn_norm_gamma: NormParameter
mlp_norm_gamma: NormParameter

PARAM_MAPPING = {
"self_attn.q_proj.weight": "qkv_w.q_params",
"self_attn.k_proj.weight": "qkv_w.k_params",
"self_attn.v_proj.weight": "qkv_w.v_params",
"self_attn.q_proj.bias": "qkv_b.q_params",
"self_attn.k_proj.bias": "qkv_b.k_params",
"self_attn.v_proj.bias": "qkv_b.v_params",
"self_attn.o_proj.weight": "attn_out_w.params",
"mlp.gate.weight": "moe_gate.params",
"mlp.experts.*.gate_proj.weight": "moe_mlp_1.gating_experts",
"mlp.experts.*.up_proj.weight": "moe_mlp_1.up_experts",
"mlp.experts.*.down_proj.weight": "moe_mlp_2.experts",
"mlp.shared_expert.gate_proj.weight": "shared_moe_mlp_1.gate_params",
"mlp.shared_expert.up_proj.weight": "shared_moe_mlp_1.up_params",
"mlp.shared_expert.down_proj.weight": "shared_moe_mlp_2.params",
"mlp.shared_expert_gate.weight": "shared_moe_gate.params",
"input_layernorm.weight": "attn_norm_gamma.params",
"post_attention_layernorm.weight": "mlp_norm_gamma.params",
}


class Qwen2MoeNonTransformerContainer(LayerContainer):
"""
Non-Transformer layer container for the Qwen2Moe model.
"""
word_emb: EmbeddingParameter
word_unembed: UnembedParameter
final_norm: NormParameter

PARAM_MAPPING = {
"model.embed_tokens.weight": "word_emb.params",
"model.norm.weight": "final_norm.params",
"lm_head.weight": "word_unembed.params",
}
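A note on the wildcard entries in `PARAM_MAPPING`: `mlp.experts.*.gate_proj.weight` matches one checkpoint tensor per expert index, and each match is gathered into the corresponding slot of the fused MoE parameter. The snippet below is a standalone illustration of that name matching, not the actual `LayerContainer` machinery:

```python
import re

# Hypothetical per-layer checkpoint names; the container maps the part after
# "model.layers.<i>." for each decoder layer.
names = [
    "mlp.experts.0.gate_proj.weight",
    "mlp.experts.17.gate_proj.weight",
    "mlp.shared_expert.gate_proj.weight",
]

# Treat the "*" wildcard as a capture group over the expert index.
pattern = re.compile(r"mlp\.experts\.(\d+)\.gate_proj\.weight")

for name in names:
    match = pattern.fullmatch(name)
    if match:
        print(f"{name} -> moe_mlp_1.gating_experts (expert {match.group(1)})")
    else:
        print(f"{name} -> covered by a non-wildcard mapping (shared expert)")
```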