Move input-layernorm inside the decoder layer
adk9 committed Jun 11, 2024
1 parent: 702bad7 · commit: cb5714e
Showing 1 changed file with 3 additions and 3 deletions.
deepspeed/inference/v2/model_implementations/phi3/model.py (6 changes: 3 additions & 3 deletions)

@@ -145,6 +145,8 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid
         cur_params = self._transformer[layer_idx]
         kv_cache = self.state_manager.get_cache(layer_idx)

+        residual, hidden_states = self.norm(residual, hidden_states, gamma=cur_params.attn_norm_gamma, beta=None)
+
         hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None)
         hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
         hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None)
@@ -195,10 +197,8 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge
     def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
         residual = self._forward_embed(wrapped_batch)

-        residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=None)
-
         for layer_idx in range(self.num_layers):
-            residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
+            residual, _ = self._forward_transformer_layer(layer_idx, residual, None,
                                                                       wrapped_batch)

         return self._forward_unembed(residual, wrapped_batch)
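For context, below is a minimal, self-contained sketch of the structure this commit moves to: each decoder layer applies its own input layernorm, so the top-level forward no longer pre-normalizes the hidden states before the layer loop. The module names (`ToyDecoderLayer`, `ToyModel`) and the use of `nn.LayerNorm`/`nn.Linear` are illustrative assumptions, not DeepSpeed's actual Phi-3 implementation or its fused kernels.

```python
# Illustrative sketch only: toy modules standing in for DeepSpeed's Phi-3 layer.
import torch
import torch.nn as nn


class ToyDecoderLayer(nn.Module):
    """Pre-norm decoder layer: the input layernorm lives inside the layer."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.input_norm = nn.LayerNorm(hidden_size)      # plays the role of attn_norm_gamma
        self.attn = nn.Linear(hidden_size, hidden_size)  # stand-in for qkv / attn / attn_out
        self.mlp = nn.Linear(hidden_size, hidden_size)   # stand-in for the MLP block

    def forward(self, residual: torch.Tensor) -> torch.Tensor:
        # Normalize inside the layer (what the commit moves), then add back to the residual stream.
        hidden_states = self.input_norm(residual)
        residual = residual + self.attn(hidden_states)
        residual = residual + self.mlp(residual)
        return residual


class ToyModel(nn.Module):
    def __init__(self, num_layers: int, hidden_size: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList([ToyDecoderLayer(hidden_size) for _ in range(num_layers)])

    def forward(self, residual: torch.Tensor) -> torch.Tensor:
        # No pre-loop normalization: each layer normalizes its own input.
        for layer in self.layers:
            residual = layer(residual)
        return residual


if __name__ == "__main__":
    model = ToyModel(num_layers=2, hidden_size=16)
    out = model(torch.randn(4, 16))
    print(out.shape)  # torch.Size([4, 16])
```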
