Move input-layernorm inside decoder layer
adk9 committed Jun 11, 2024
1 parent 3d24681 commit 542d2a2
Showing 1 changed file with 14 additions and 7 deletions.
deepspeed/inference/v2/model_implementations/phi3small/model.py (21 changes: 14 additions & 7 deletions)
@@ -112,7 +112,6 @@ def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
     def mup_embedding_multiplier(self) -> float:
         return 10.0
 
-
     """
     Forward implementations
     """
@@ -153,14 +152,22 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid
         cur_params = self._transformer[layer_idx]
         kv_cache = self.state_manager.get_cache(layer_idx)
 
+        residual, hidden_states = self.norm(residual,
+                                            hidden_states,
+                                            gamma=cur_params.attn_norm_gamma,
+                                            beta=cur_params.attn_norm_beta)
+
         hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b)
         hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
         hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=cur_params.attn_out_b)
 
         if self.tp_size > 1:
             dist.all_reduce(hidden_states, group=self._base_mp_group)
 
-        residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=cur_params.mlp_norm_beta)
+        residual, hidden_states = self.norm(residual,
+                                            hidden_states,
+                                            cur_params.mlp_norm_gamma,
+                                            beta=cur_params.mlp_norm_beta)
 
         hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None)
         hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None)
@@ -170,7 +177,10 @@
 
         if layer_idx != self.num_layers - 1:
             next_params = self._transformer[layer_idx + 1]
-            residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=next_params.attn_norm_beta)
+            residual, hidden_states = self.norm(residual,
+                                                hidden_states,
+                                                next_params.attn_norm_gamma,
+                                                beta=next_params.attn_norm_beta)
         else:
             # On last layer, we just need to perform the residual add. Adding into the residual
             # here is safe.
@@ -205,10 +215,7 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge
     def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
         residual = self._forward_embed(wrapped_batch)
 
-        residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=self._transformer[0].attn_norm_beta)
-
         for layer_idx in range(self.num_layers):
-            residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
-                                                                      wrapped_batch)
+            residual, _ = self._forward_transformer_layer(layer_idx, residual, None, wrapped_batch)
 
         return self._forward_unembed(residual, wrapped_batch)
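A note for readers skimming the diff: the change hinges on self.norm behaving as a fused residual-add + LayerNorm, i.e. it first accumulates hidden_states into the residual stream (when one is passed) and then returns the normalized activations alongside the updated residual; the in-diff comment "On last layer, we just need to perform the residual add" points at that behavior. The snippet below is a minimal, hypothetical sketch of this calling pattern, not DeepSpeed's actual kernel; fused_residual_norm and the toy driver loop are stand-ins introduced only for illustration, and their names, shapes, and eps value are assumptions.

from typing import Optional

import torch
import torch.nn.functional as F


def fused_residual_norm(residual: torch.Tensor,
                        hidden_states: Optional[torch.Tensor],
                        gamma: torch.Tensor,
                        beta: torch.Tensor,
                        eps: float = 1e-5):
    # Illustrative stand-in (an assumption, not the DeepSpeed op): optionally add
    # the incoming hidden_states into the residual stream, then LayerNorm the
    # updated residual, mirroring how self.norm is called in the diff above.
    if hidden_states is not None:
        residual = residual + hidden_states
    normalized = F.layer_norm(residual, residual.shape[-1:], weight=gamma, bias=beta, eps=eps)
    return residual, normalized


if __name__ == "__main__":
    hidden_dim, num_layers = 8, 2
    gamma, beta = torch.ones(hidden_dim), torch.zeros(hidden_dim)
    residual = torch.randn(4, hidden_dim)

    # Flow sketched by this commit: the driver loop threads only the residual;
    # each layer begins by normalizing its own input (hidden_states enters as
    # None) instead of forward() pre-normalizing before the loop.
    for _ in range(num_layers):
        residual, hidden_states = fused_residual_norm(residual, None, gamma, beta)
        layer_out = hidden_states * 0.5  # placeholder for the attention/MLP work
        residual, _ = fused_residual_norm(residual, layer_out, gamma, beta)

    print(residual.shape)

Under that reading, the tail-of-layer norm call with next_params is still needed for its residual add, while its normalized output is discarded by forward and simply recomputed by the next layer's own input norm.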
