From 542d2a2ea78d296dd13024a2907b09e6c136f471 Mon Sep 17 00:00:00 2001
From: Abhishek Kulkarni
Date: Tue, 11 Jun 2024 23:00:02 +0000
Subject: [PATCH] Move input-layernorm inside decoder layer

---
 .../model_implementations/phi3small/model.py  | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/deepspeed/inference/v2/model_implementations/phi3small/model.py b/deepspeed/inference/v2/model_implementations/phi3small/model.py
index 1f1e853fc167c..55d7dbabde591 100644
--- a/deepspeed/inference/v2/model_implementations/phi3small/model.py
+++ b/deepspeed/inference/v2/model_implementations/phi3small/model.py
@@ -112,7 +112,6 @@ def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
     def mup_embedding_multiplier(self) -> float:
         return 10.0
 
-
     """
     Forward implementations
     """
@@ -153,6 +152,11 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid
         cur_params = self._transformer[layer_idx]
         kv_cache = self.state_manager.get_cache(layer_idx)
 
+        residual, hidden_states = self.norm(residual,
+                                            hidden_states,
+                                            gamma=cur_params.attn_norm_gamma,
+                                            beta=cur_params.attn_norm_beta)
+
         hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b)
         hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
         hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=cur_params.attn_out_b)
@@ -160,7 +164,10 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid
         if self.tp_size > 1:
             dist.all_reduce(hidden_states, group=self._base_mp_group)
 
-        residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=cur_params.mlp_norm_beta)
+        residual, hidden_states = self.norm(residual,
+                                            hidden_states,
+                                            cur_params.mlp_norm_gamma,
+                                            beta=cur_params.mlp_norm_beta)
 
         hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None)
         hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None)
@@ -170,7 +177,10 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid
 
         if layer_idx != self.num_layers - 1:
             next_params = self._transformer[layer_idx + 1]
-            residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=next_params.attn_norm_beta)
+            residual, hidden_states = self.norm(residual,
+                                                hidden_states,
+                                                next_params.attn_norm_gamma,
+                                                beta=next_params.attn_norm_beta)
         else:
             # On last layer, we just need to perform the residual add. Adding into the residual
             # here is safe.
@@ -205,10 +215,7 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge
     def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
         residual = self._forward_embed(wrapped_batch)
 
-        residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=self._transformer[0].attn_norm_beta)
-
         for layer_idx in range(self.num_layers):
-            residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
-                                                                      wrapped_batch)
+            residual, _ = self._forward_transformer_layer(layer_idx, residual, None, wrapped_batch)
 
         return self._forward_unembed(residual, wrapped_batch)
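
Note for reviewers: the sketch below is not the DeepSpeed v2 code touched by this patch. It is a minimal, hypothetical PyTorch illustration (all names, e.g. DecoderLayerSketch, and the use of nn.LayerNorm / nn.MultiheadAttention in place of DeepSpeed's fused norm, attention, and ragged-batch kernels, are assumptions) of the pre-norm structure the patch moves to: each decoder layer applies its own input layernorm, so the model-level forward() no longer has to normalize with the first layer's attention-norm parameters before entering the layer loop.

# Minimal sketch only -- not the DeepSpeed implementation. Illustrates a decoder
# layer that owns its input layernorm, so the caller just chains layers.
import torch
import torch.nn as nn


class DecoderLayerSketch(nn.Module):
    """Hypothetical pre-norm decoder layer: the input layernorm lives inside the layer."""

    def __init__(self, hidden_size: int, num_heads: int = 4):
        super().__init__()
        self.attn_norm = nn.LayerNorm(hidden_size)   # stands in for attn_norm_gamma/beta
        self.mlp_norm = nn.LayerNorm(hidden_size)    # stands in for mlp_norm_gamma/beta
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
                                 nn.Linear(4 * hidden_size, hidden_size))

    def forward(self, residual: torch.Tensor) -> torch.Tensor:
        # Attention block: normalize the running residual inside the layer (the step
        # this patch moves out of the model-level forward()).
        hidden = self.attn_norm(residual)
        attn_out, _ = self.attn(hidden, hidden, hidden, need_weights=False)
        residual = residual + attn_out

        # MLP block with its own pre-norm, mirroring the mlp_norm_* usage in the diff.
        hidden = self.mlp_norm(residual)
        return residual + self.mlp(hidden)


if __name__ == "__main__":
    layers = nn.ModuleList(DecoderLayerSketch(64) for _ in range(2))
    residual = torch.randn(1, 8, 64)          # stand-in for the embedding output
    for layer in layers:
        residual = layer(residual)            # no layernorm needed at the call site
    print(residual.shape)                     # torch.Size([1, 8, 64])

The design point mirrors the diff: once the attention-norm sits at the top of the layer, the top-level loop can simply pass None for hidden_states and discard the per-layer hidden output, as the new forward() does.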