-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for the Phi-3 small model
- Loading branch information
Showing
6 changed files
with
334 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 6 additions & 0 deletions
6
deepspeed/inference/v2/model_implementations/phi3small/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
from .policy import Phi3SmallPolicy |
87 changes: 87 additions & 0 deletions
87
deepspeed/inference/v2/model_implementations/phi3small/containers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
# Create a container object to save model-specific tensors using the policy file above. | ||
|
||
from ..common_parameters import * | ||
from ..layer_container_base import LayerContainer | ||
''' | ||
# HF Phi-3 model looks like this: | ||
Phi3SmallForCausalLM( | ||
(model): Phi3Model( | ||
(embed_tokens): Embedding(32064, 3072) | ||
(embed_dropout): Dropout(p=0.0, inplace=False) | ||
(layers): ModuleList( | ||
(0-31): 32 x Phi3DecoderLayer( | ||
(self_attn): Phi3Attention( | ||
(o_proj): Linear(in_features=3072, out_features=3072, bias=False) | ||
(qkv_proj): Linear(in_features=3072, out_features=9216, bias=False) | ||
(rotary_emb): Phi3RotaryEmbedding() | ||
) | ||
(mlp): PhiMLP( | ||
(gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False) | ||
(down_proj): Linear(in_features=16384, out_features=3072, bias=False) | ||
(activation_fn): SiLU() | ||
) | ||
(input_layernorm): Phi3RMSNorm((3072,), eps=1e-05) | ||
(resid_attn_dropout): Dropout(p=0.0) | ||
(resid_mlp_dropout): Dropout(p=0.0) | ||
(post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05) | ||
) | ||
) | ||
(final_layernorm): Phi3RMSNorm((3072,), eps=1e-05) | ||
) | ||
(lm_head): Linear(in_features=3072, out_features=32064, bias=False) | ||
) | ||
''' | ||
|
||
|
||
class Phi3SmallTransformerContainer(LayerContainer): | ||
""" | ||
Transformer layer container for the Phi model. | ||
""" | ||
qkv_w: FusedQKVParameter | ||
qkv_b: FusedQKVParameter | ||
attn_out_w: AttentionOutputParameter | ||
attn_out_b: AttentionOutputParameter | ||
mlp_1_w: MLP1Parameter | ||
mlp_1_b: MLP1Parameter | ||
mlp_2_w: MLP2Parameter | ||
mlp_2_b: MLP2Parameter | ||
attn_norm_gamma: NormParameter | ||
attn_norm_beta: NormParameter | ||
mlp_norm_gamma: NormParameter | ||
mlp_norm_beta: NormParameter | ||
|
||
PARAM_MAPPING = { | ||
"self_attn.query_key_value.weight": "qkv_w.params", | ||
"self_attn.query_key_value.bias": "qkv_b.params", | ||
"self_attn.dense.weight": "attn_out_w.params", | ||
"self_attn.dense.bias": "attn_out_b.params", | ||
"mlp.up_proj.weight": "mlp_1_w.params", | ||
"mlp.up_proj.bias": "mlp_1_b.params", | ||
"mlp.down_proj.weight": "mlp_2_w.params", | ||
"mlp.down_proj.bias": "mlp_2_b.params", | ||
"input_layernorm.weight": "attn_norm_gamma.params", | ||
"input_layernorm.bias": "attn_norm_beta.params", | ||
"post_attention_layernorm.weight": "mlp_norm_gamma.params", | ||
"post_attention_layernorm.bias": "mlp_norm_beta.params", | ||
} | ||
|
||
|
||
class Phi3SmallNonTransformerContainer(LayerContainer): | ||
""" | ||
Non-Transformer layer container for the Phi model. | ||
""" | ||
word_emb: EmbeddingParameter | ||
final_norm_gamma: NormParameter | ||
final_norm_beta: NormParameter | ||
|
||
PARAM_MAPPING = { | ||
"model.embed_tokens.weight": "word_emb.params", | ||
"model.final_layernorm.weight": "final_norm_gamma.params", | ||
"model.final_layernorm.bias": "final_norm_beta.params", | ||
} |
207 changes: 207 additions & 0 deletions
207
deepspeed/inference/v2/model_implementations/phi3small/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
from typing import Iterable, Optional, Tuple | ||
|
||
import torch | ||
|
||
import deepspeed.comm as dist | ||
|
||
from ...allocator import empty_from | ||
from ...inference_utils import ActivationType, DtypeEnum | ||
from .. import * | ||
from ...modules.configs import * | ||
from ...modules.interfaces import * | ||
from ...ragged import RaggedBatchWrapper | ||
|
||
from .containers import Phi3SmallNonTransformerContainer, Phi3SmallTransformerContainer | ||
|
||
|
||
class Phi3SmallInferenceModel(DSTransformerModelBase): | ||
""" | ||
Inference model implementation for ragged batching for Llama-2 models. | ||
""" | ||
|
||
_non_transformer: Optional[Phi3SmallNonTransformerContainer] | ||
""" | ||
Embed + unembed container. Specializing the type annotation. | ||
""" | ||
|
||
_transformer: Optional[Iterable[Phi3SmallTransformerContainer]] | ||
""" | ||
Per-layer transformer container. Specializing the type annotation. | ||
""" | ||
""" | ||
Properties inherited from `DSInferenceModelBase` | ||
""" | ||
|
||
@property | ||
def max_sequence_length(self) -> int: | ||
return self._config.max_seq_length | ||
|
||
""" | ||
Properties inherited from `DSTransformerModelBase` | ||
""" | ||
|
||
@property | ||
def num_layers(self) -> int: | ||
return self._config.num_hidden_layers | ||
|
||
@property | ||
def model_dim(self) -> int: | ||
return self._config.hidden_size | ||
|
||
@property | ||
def vocab_size(self) -> int: | ||
return self._config.vocab_size | ||
|
||
@property | ||
def head_size(self) -> int: | ||
return self.model_dim // self.n_heads | ||
|
||
@property | ||
def n_heads(self) -> int: | ||
return self._config.num_attention_heads | ||
|
||
@property | ||
def intermediate_dim(self) -> int: | ||
return self._config.intermediate_size | ||
|
||
@property | ||
def n_heads_kv(self) -> int: | ||
return self._config.num_key_value_heads | ||
|
||
@property | ||
def activation_dtype(self) -> DtypeEnum: | ||
if self._config.torch_dtype == torch.float16: | ||
return DtypeEnum.fp16 | ||
elif self._config.torch_dtype == torch.bfloat16: | ||
return DtypeEnum.bf16 | ||
else: | ||
raise NotImplementedError("Only fp16 and bf16 are supported") | ||
|
||
@property | ||
def mlp_activation_fn(self) -> ActivationType: | ||
activation = self._config.hidden_act.lower() | ||
if activation == "gelu": | ||
return ActivationType.GEGLU | ||
elif activation == "relu": | ||
return ActivationType.ReGLU | ||
elif activation == "gegelu": | ||
return ActivationType.GEGLU | ||
elif activation == "silu": | ||
return ActivationType.SiGLU | ||
else: | ||
raise NotImplementedError(f"Activation {activation} not supported") | ||
|
||
@property | ||
def norm_type(self) -> NormTypeEnum: | ||
return NormTypeEnum.LayerNorm | ||
|
||
@property | ||
def positional_embedding_type(self) -> PositionalEmbeddingType: | ||
return PositionalEmbeddingType.rotate_half | ||
|
||
@property | ||
def positional_embedding_config(self) -> Optional[RotateHalfConfig]: | ||
return RotateHalfConfig(theta_base=self._config.rope_embedding_base) | ||
|
||
""" | ||
Forward implementations | ||
""" | ||
|
||
def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: | ||
""" | ||
Performs the embedding lookup prior to running the transformer of the model. | ||
Arguments: | ||
ragged_batch (RaggedBatchWrapper): The batch to embed. | ||
Returns: | ||
torch.Tensor: The embedded batch. | ||
""" | ||
embed = self.embed(ragged_batch, self._non_transformer.word_emb) | ||
|
||
if embed.shape[-1] != self.model_dim: | ||
raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") | ||
|
||
return embed | ||
|
||
def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, | ||
ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: | ||
""" | ||
Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead | ||
optimization to fuse the layer norm of the next layer into the current layer. | ||
Arguments: | ||
layer_idx (int): The index of the layer to execute. | ||
residual (torch.Tensor): The residual tensor from the previous layer. | ||
hidden_states (torch.Tensor): The hidden states from the previous layer. This is the | ||
hidden states after pre normalization. | ||
ragged_batch_info (RaggedBatchWrapper): The batch metadata. | ||
""" | ||
cur_params = self._transformer[layer_idx] | ||
kv_cache = self.state_manager.get_cache(layer_idx) | ||
|
||
hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b) | ||
hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) | ||
hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=cur_params.attn_out_b) | ||
|
||
if self.tp_size > 1: | ||
dist.all_reduce(hidden_states, group=self._base_mp_group) | ||
|
||
residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=cur_params.mlp_norm_beta) | ||
|
||
hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) | ||
hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) | ||
|
||
if self.tp_size > 1: | ||
dist.all_reduce(hidden_states, group=self._base_mp_group) | ||
|
||
if layer_idx != self.num_layers - 1: | ||
next_params = self._transformer[layer_idx + 1] | ||
residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=next_params.attn_norm_beta) | ||
else: | ||
# On last layer, we just need to perform the residual add. Adding into the residual | ||
# here is safe. | ||
residual.add_(hidden_states) | ||
|
||
return residual, hidden_states | ||
|
||
def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: | ||
""" | ||
Performs unembedding of the hidden states to logits. This will only sample the final | ||
token of each sequence. | ||
""" | ||
word_unembed = torch.empty(self.vocab_size, self.model_dim, dtype=hidden_states.dtype, device=hidden_states.device) | ||
torch.nn.init.xavier_uniform_(word_unembed) | ||
logits = self.unembed(hidden_states, | ||
word_unembed, | ||
ragged_batch_info, | ||
gamma=self._non_transformer.final_norm_gamma, | ||
beta=self._non_transformer.final_norm_beta) | ||
|
||
if self.tp_size > 1: | ||
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) | ||
full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) | ||
|
||
dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) | ||
|
||
full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) | ||
|
||
return full_logits | ||
else: | ||
return logits | ||
|
||
def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: | ||
residual = self._forward_embed(wrapped_batch) | ||
|
||
residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=self._transformer[0].attn_norm_beta) | ||
|
||
for layer_idx in range(self.num_layers): | ||
residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, | ||
wrapped_batch) | ||
|
||
return self._forward_unembed(residual, wrapped_batch) |
30 changes: 30 additions & 0 deletions
30
deepspeed/inference/v2/model_implementations/phi3small/policy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
from typing import Any | ||
|
||
from ...config_v2 import RaggedInferenceEngineConfig | ||
from ..inference_policy_base import ContainerMap, InferenceV2Policy | ||
from .containers import Phi3SmallNonTransformerContainer, Phi3SmallTransformerContainer | ||
from .model import Phi3SmallInferenceModel | ||
|
||
|
||
class Phi3SmallPolicy(InferenceV2Policy): | ||
|
||
def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Phi3SmallInferenceModel: | ||
return Phi3SmallInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) | ||
|
||
def build_container_map(self) -> ContainerMap: | ||
map = ContainerMap() | ||
|
||
transformer_containers = [Phi3SmallTransformerContainer(self.model) for _ in range(self.model.num_layers)] | ||
|
||
map.set_transformer_params(['model.layers'], transformer_containers) | ||
|
||
map.set_non_transformer_params(Phi3SmallNonTransformerContainer(self.model)) | ||
|
||
map.set_unmapped_params([]) | ||
|
||
return map |