diff --git a/deepspeed/inference/v2/model_implementations/phi3/model.py b/deepspeed/inference/v2/model_implementations/phi3/model.py
index 507bb4fc9af1a..5371a0021e5c1 100644
--- a/deepspeed/inference/v2/model_implementations/phi3/model.py
+++ b/deepspeed/inference/v2/model_implementations/phi3/model.py
@@ -145,6 +145,8 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid
         cur_params = self._transformer[layer_idx]
         kv_cache = self.state_manager.get_cache(layer_idx)
 
+        residual, hidden_states = self.norm(residual, hidden_states, gamma=cur_params.attn_norm_gamma, beta=None)
+
         hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None)
         hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
         hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None)
@@ -195,10 +197,8 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge
     def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
         residual = self._forward_embed(wrapped_batch)
 
-        residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=None)
-
         for layer_idx in range(self.num_layers):
-            residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
+            residual, _ = self._forward_transformer_layer(layer_idx, residual, None,
                                                                       wrapped_batch)
 
         return self._forward_unembed(residual, wrapped_batch)
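For context, here is a minimal sketch of the pattern the patch moves to: each transformer layer consumes only the running residual and applies its own pre-attention norm, so `forward` no longer primes `hidden_states` with the first layer's gamma before the loop. The names below (`rms_norm`, `SketchLayer`, the linear `mixer` standing in for attention/MLP) are hypothetical stand-ins for illustration, not the DeepSpeed inference modules.

```python
# Minimal sketch, assuming a plain RMSNorm-style pre-attention norm.
import torch


def rms_norm(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale by the reciprocal root-mean-square of the last dim, then by gamma.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * gamma


class SketchLayer(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.attn_norm_gamma = torch.nn.Parameter(torch.ones(dim))
        # Stand-in for the attention + MLP path of a real layer.
        self.mixer = torch.nn.Linear(dim, dim, bias=False)

    def forward(self, residual: torch.Tensor) -> torch.Tensor:
        # Each layer normalizes its own input from the running residual,
        # mirroring the per-layer self.norm(...) call added by this patch.
        hidden = rms_norm(residual, self.attn_norm_gamma)
        return residual + self.mixer(hidden)


if __name__ == "__main__":
    layers = torch.nn.ModuleList(SketchLayer(16) for _ in range(4))
    residual = torch.randn(2, 16)  # embedding output
    for layer in layers:
        residual = layer(residual)  # no hidden_states threaded between layers
    print(residual.shape)
```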