Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(moe): fix moe act late release #387

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions internlm/model/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def forward(
raise NotImplementedError(f"Invalid backend: {backend}")

input_numel = x.numel()
ctx.input_numel = input_numel
if input_numel == 0:
backend = "bmm"

Expand All @@ -354,13 +355,15 @@ def forward(
if input_numel == 0:
# if inp is empty, reshape to make grad flow.
# inp shape: (0, hdim)
weight = weight.view(x.shape[-1], -1)

output = torch.matmul(x, weight)
output = torch.matmul(x, weight.view(x.shape[-1], -1))
else:
output = torch.matmul(x, weight)

saved_x = None if ctx.compute_weight_gradient is False else x
ctx.save_for_backward(saved_x, weight, batch_sizes)

assert len(output.shape) == len(x.shape)

return output

@staticmethod
Expand All @@ -372,19 +375,33 @@ def backward(ctx, grad_output):
x, weight, batch_sizes = ctx.saved_tensors
grad_input, grad_weight = None, None

if grad_output.numel() == 0:
if ctx.needs_input_grad[1]:
grad_weight = torch.zeros_like(weight)
if ctx.needs_input_grad[0]:
grad_input = torch.zeros_like(x)

return grad_input, grad_weight, None, None, None, None, None

if ctx.needs_input_grad[1]:
assert ctx.compute_weight_gradient
if backend == "gmm":
grad_input, grad_weight = gmm_backward_op(x, grad_output, batch_sizes, input_weight=weight)
else:
grad_weight = torch.matmul(x.transpose(-1, -2), grad_output)
if ctx.input_numel == 0:
grad_weight = torch.zeros_like(weight)
else:
grad_weight = torch.matmul(x.transpose(-1, -2), grad_output)

if ctx.needs_input_grad[0]:
if backend == "gmm":
if grad_input is None:
grad_input, _ = gmm_backward_op(grad_output, weight, batch_sizes, is_grad_input=True)
else:
grad_input = torch.matmul(grad_output, weight.transpose(-1, -2))
if ctx.input_numel == 0:
grad_input = torch.zeros_like(x)
else:
grad_input = torch.matmul(grad_output, weight.transpose(-1, -2))

return grad_input, grad_weight, None, None, None, None, None

Expand Down
3 changes: 0 additions & 3 deletions internlm/model/moe/base_layer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import TYPE_CHECKING, Union

import torch
from torch import Tensor
from torch.nn import Module, ModuleList

from internlm.core.context import global_context as gpc
from internlm.model.moe.experts import Experts
from internlm.utils.common import get_current_device

if TYPE_CHECKING:
Base = Module[Tensor]
Expand All @@ -32,7 +30,6 @@ def __init__(
self.ep_group = ep_group
self.ep_size = ep_size
self.num_local_experts = num_local_experts
self.l_aux = torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype"))
self.exp_counts = None

for _, param in self.gate.named_parameters():
Expand Down
7 changes: 6 additions & 1 deletion internlm/model/moe/dropless_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,12 @@ def forward(self, *inputs: Tensor) -> Tensor:

# Reshape the output tensor
output = output.view(self.hidden_shape)
return output

# Note: 1. we need to relase self.l_aux and its compute graph; 2. we need self.l_aux to simplify code
# so we first use self.l_aux and then reset it.
l_aux = self.l_aux
self.l_aux = None
return output, l_aux

def topk_softmax_with_capacity(self, gates):
expert_weights, indices = torch.topk(gates, self.topk, dim=1)
Expand Down
4 changes: 2 additions & 2 deletions internlm/model/moe/gshard_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def forward(self, *inputs: Tensor) -> Tensor:
# group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
reshaped_inputs = inputs[0].reshape(-1, d_model)

self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
dispatched_inputs = einsum(
"sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
) # TODO: heavy memory usage due to long sequence length
Expand Down Expand Up @@ -608,4 +608,4 @@ def forward(self, *inputs: Tensor) -> Tensor:
timer("moe").stop()
self.time_moe = timer("moe").elapsed(reset=False)

return out
return out, l_aux
4 changes: 2 additions & 2 deletions internlm/model/moe/megablocks/megablock_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,6 @@ def forward(self, *inputs) -> torch.Tensor:

x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)

self.l_aux = self.load_balancing_loss(tokens_per_expert, all_probs)
l_aux = self.load_balancing_loss(tokens_per_expert, all_probs)

return x.view(*input_shape)
return x.view(*input_shape), l_aux
8 changes: 4 additions & 4 deletions internlm/model/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def forward(self, hidden_states, used_token=None):

* exp_counts (int): expert count
"""
output = self.moe_layer(hidden_states, used_token)
output, l_aux = self.moe_layer(hidden_states, used_token)
if self.num_shared_experts > 0:
# Residual MoE
output_mlp = self.residual_mlp(hidden_states)
Expand All @@ -190,7 +190,7 @@ def forward(self, hidden_states, used_token=None):
coef = self.coefficient(hidden_states)
coef = torch.nn.functional.softmax(coef, dim=-1)
output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
return output, l_aux, self.moe_layer.exp_counts


class Qwen2MoE(MoEBase):
Expand Down Expand Up @@ -264,7 +264,7 @@ def forward(self, hidden_states, used_token=None):

* exp_counts (int): expert count
"""
output = self.moe_layer(hidden_states, used_token)
output, l_aux = self.moe_layer(hidden_states, used_token)
if self.num_shared_experts > 0:
# Residual MoE
output_mlp = self.residual_mlp(hidden_states)
Expand All @@ -273,4 +273,4 @@ def forward(self, hidden_states, used_token=None):
coef = self.coefficient(hidden_states)
output_mlp = F.sigmoid(coef) * output_mlp
output = output + output_mlp
return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
return output, l_aux, self.moe_layer.exp_counts
Loading