From b29c1dfbcf6cf402fa03469189ce41e358571dbf Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Tue, 4 Jun 2024 18:19:12 +0000 Subject: [PATCH] Add support for the Phi-3 small model --- deepspeed/inference/v2/engine_factory.py | 3 + .../v2/model_implementations/__init__.py | 1 + .../phi3small/__init__.py | 6 + .../phi3small/containers.py | 87 ++++++++ .../model_implementations/phi3small/model.py | 207 ++++++++++++++++++ .../model_implementations/phi3small/policy.py | 30 +++ 6 files changed, 334 insertions(+) create mode 100644 deepspeed/inference/v2/model_implementations/phi3small/__init__.py create mode 100644 deepspeed/inference/v2/model_implementations/phi3small/containers.py create mode 100644 deepspeed/inference/v2/model_implementations/phi3small/model.py create mode 100644 deepspeed/inference/v2/model_implementations/phi3small/policy.py diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py index c21affb9a0de..895d8bbfae15 100644 --- a/deepspeed/inference/v2/engine_factory.py +++ b/deepspeed/inference/v2/engine_factory.py @@ -21,6 +21,7 @@ FalconPolicy, PhiPolicy, Phi3Policy, + Phi3SmallPolicy, QwenPolicy, Qwen2Policy, ) @@ -122,6 +123,8 @@ def build_hf_engine(path: str, policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "phi3": policy = Phi3Policy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "phi3small": + policy = Phi3SmallPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "qwen": policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "qwen2": diff --git a/deepspeed/inference/v2/model_implementations/__init__.py b/deepspeed/inference/v2/model_implementations/__init__.py index e4160ab94949..cd27bf495d94 100644 --- a/deepspeed/inference/v2/model_implementations/__init__.py +++ b/deepspeed/inference/v2/model_implementations/__init__.py @@ -16,5 +16,6 @@ from .falcon import * from .phi import * from .phi3 import * +from .phi3small import * from .qwen import * from .qwen_v2 import * diff --git a/deepspeed/inference/v2/model_implementations/phi3small/__init__.py b/deepspeed/inference/v2/model_implementations/phi3small/__init__.py new file mode 100644 index 000000000000..71df721cf135 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi3small/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import Phi3SmallPolicy diff --git a/deepspeed/inference/v2/model_implementations/phi3small/containers.py b/deepspeed/inference/v2/model_implementations/phi3small/containers.py new file mode 100644 index 000000000000..deb31d311627 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi3small/containers.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ..common_parameters import * +from ..layer_container_base import LayerContainer +''' + # HF Phi-3 model looks like this: + +Phi3SmallForCausalLM( + (model): Phi3Model( + (embed_tokens): Embedding(32064, 3072) + (embed_dropout): Dropout(p=0.0, inplace=False) + (layers): ModuleList( + (0-31): 32 x Phi3DecoderLayer( + (self_attn): Phi3Attention( + (o_proj): Linear(in_features=3072, out_features=3072, bias=False) + (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False) + (rotary_emb): Phi3RotaryEmbedding() + ) + (mlp): PhiMLP( + (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False) + (down_proj): Linear(in_features=16384, out_features=3072, bias=False) + (activation_fn): SiLU() + ) + (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05) + (resid_attn_dropout): Dropout(p=0.0) + (resid_mlp_dropout): Dropout(p=0.0) + (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05) + ) + ) + (final_layernorm): Phi3RMSNorm((3072,), eps=1e-05) + ) + (lm_head): Linear(in_features=3072, out_features=32064, bias=False) +) +''' + + +class Phi3SmallTransformerContainer(LayerContainer): + """ + Transformer layer container for the Phi model. + """ + qkv_w: FusedQKVParameter + qkv_b: FusedQKVParameter + attn_out_w: AttentionOutputParameter + attn_out_b: AttentionOutputParameter + mlp_1_w: MLP1Parameter + mlp_1_b: MLP1Parameter + mlp_2_w: MLP2Parameter + mlp_2_b: MLP2Parameter + attn_norm_gamma: NormParameter + attn_norm_beta: NormParameter + mlp_norm_gamma: NormParameter + mlp_norm_beta: NormParameter + + PARAM_MAPPING = { + "self_attn.query_key_value.weight": "qkv_w.params", + "self_attn.query_key_value.bias": "qkv_b.params", + "self_attn.dense.weight": "attn_out_w.params", + "self_attn.dense.bias": "attn_out_b.params", + "mlp.up_proj.weight": "mlp_1_w.params", + "mlp.up_proj.bias": "mlp_1_b.params", + "mlp.down_proj.weight": "mlp_2_w.params", + "mlp.down_proj.bias": "mlp_2_b.params", + "input_layernorm.weight": "attn_norm_gamma.params", + "input_layernorm.bias": "attn_norm_beta.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + "post_attention_layernorm.bias": "mlp_norm_beta.params", + } + + +class Phi3SmallNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Phi model. + """ + word_emb: EmbeddingParameter + final_norm_gamma: NormParameter + final_norm_beta: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "model.final_layernorm.weight": "final_norm_gamma.params", + "model.final_layernorm.bias": "final_norm_beta.params", + } diff --git a/deepspeed/inference/v2/model_implementations/phi3small/model.py b/deepspeed/inference/v2/model_implementations/phi3small/model.py new file mode 100644 index 000000000000..e8c22a108611 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi3small/model.py @@ -0,0 +1,207 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from .. import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...ragged import RaggedBatchWrapper + +from .containers import Phi3SmallNonTransformerContainer, Phi3SmallTransformerContainer + + +class Phi3SmallInferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Llama-2 models. + """ + + _non_transformer: Optional[Phi3SmallNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[Phi3SmallTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties inherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties inherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + if activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + elif activation == "gegelu": + return ActivationType.GEGLU + elif activation == "silu": + return ActivationType.SiGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.LayerNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return RotateHalfConfig(theta_base=self._config.rope_embedding_base) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=cur_params.attn_out_b) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=cur_params.mlp_norm_beta) + + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=next_params.attn_norm_beta) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + word_unembed = torch.empty(self.vocab_size, self.model_dim, dtype=hidden_states.dtype, device=hidden_states.device) + torch.nn.init.xavier_uniform_(word_unembed) + logits = self.unembed(hidden_states, + word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm_gamma, + beta=self._non_transformer.final_norm_beta) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=self._transformer[0].attn_norm_beta) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/phi3small/policy.py b/deepspeed/inference/v2/model_implementations/phi3small/policy.py new file mode 100644 index 000000000000..235fa41ac608 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi3small/policy.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .containers import Phi3SmallNonTransformerContainer, Phi3SmallTransformerContainer +from .model import Phi3SmallInferenceModel + + +class Phi3SmallPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Phi3SmallInferenceModel: + return Phi3SmallInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [Phi3SmallTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(Phi3SmallNonTransformerContainer(self.model)) + + map.set_unmapped_params([]) + + return map