From 699a4f6d1f603a908870446aab655e1217d37353 Mon Sep 17 00:00:00 2001
From: shidongxing
Date: Wed, 4 Dec 2024 17:45:42 +0800
Subject: [PATCH 1/3] fix moe act release

---
 internlm/model/moe/base_layer.py               | 1 -
 internlm/model/moe/dropless_layer.py           | 7 ++++++-
 internlm/model/moe/gshard_layer.py             | 4 ++--
 internlm/model/moe/megablocks/megablock_moe.py | 4 ++--
 internlm/model/moe/moe.py                      | 8 ++++----
 5 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/internlm/model/moe/base_layer.py b/internlm/model/moe/base_layer.py
index fa02f145..2649cbd0 100644
--- a/internlm/model/moe/base_layer.py
+++ b/internlm/model/moe/base_layer.py
@@ -32,7 +32,6 @@ def __init__(
         self.ep_group = ep_group
         self.ep_size = ep_size
         self.num_local_experts = num_local_experts
-        self.l_aux = torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype"))
         self.exp_counts = None
 
         for _, param in self.gate.named_parameters():
diff --git a/internlm/model/moe/dropless_layer.py b/internlm/model/moe/dropless_layer.py
index f5881dfb..031c2306 100644
--- a/internlm/model/moe/dropless_layer.py
+++ b/internlm/model/moe/dropless_layer.py
@@ -288,7 +288,12 @@ def forward(self, *inputs: Tensor) -> Tensor:
 
         # Reshape the output tensor
         output = output.view(self.hidden_shape)
-        return output
+
+        # Note: 1. we need to release self.l_aux and its compute graph; 2. we need self.l_aux to simplify code,
+        # so we first use self.l_aux and then reset it.
+        l_aux = self.l_aux
+        self.l_aux = None
+        return output, l_aux
 
     def topk_softmax_with_capacity(self, gates):
         expert_weights, indices = torch.topk(gates, self.topk, dim=1)
diff --git a/internlm/model/moe/gshard_layer.py b/internlm/model/moe/gshard_layer.py
index 3aba8d1a..a102b8c9 100644
--- a/internlm/model/moe/gshard_layer.py
+++ b/internlm/model/moe/gshard_layer.py
@@ -555,7 +555,7 @@ def forward(self, *inputs: Tensor) -> Tensor:
         # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
         reshaped_inputs = inputs[0].reshape(-1, d_model)
 
-        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
+        l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
         dispatched_inputs = einsum(
             "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
         )  # TODO: heavy memory usage due to long sequence length
@@ -608,4 +608,4 @@ def forward(self, *inputs: Tensor) -> Tensor:
         timer("moe").stop()
         self.time_moe = timer("moe").elapsed(reset=False)
 
-        return out
+        return out, l_aux
diff --git a/internlm/model/moe/megablocks/megablock_moe.py b/internlm/model/moe/megablocks/megablock_moe.py
index 82fa3062..257585da 100644
--- a/internlm/model/moe/megablocks/megablock_moe.py
+++ b/internlm/model/moe/megablocks/megablock_moe.py
@@ -303,6 +303,6 @@ def forward(self, *inputs) -> torch.Tensor:
 
         x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
 
-        self.l_aux = self.load_balancing_loss(tokens_per_expert, all_probs)
+        l_aux = self.load_balancing_loss(tokens_per_expert, all_probs)
 
-        return x.view(*input_shape)
+        return x.view(*input_shape), l_aux
diff --git a/internlm/model/moe/moe.py b/internlm/model/moe/moe.py
index 0bd35e5b..67fc40b5 100644
--- a/internlm/model/moe/moe.py
+++ b/internlm/model/moe/moe.py
@@ -181,7 +181,7 @@ def forward(self, hidden_states, used_token=None):
             * exp_counts (int): expert count
         """
 
-        output = self.moe_layer(hidden_states, used_token)
+        output, l_aux = self.moe_layer(hidden_states, used_token)
         if self.num_shared_experts > 0:
             # Residual MoE
             output_mlp = self.residual_mlp(hidden_states)
@@ -190,7 +190,7 @@ def forward(self, hidden_states, used_token=None):
             coef = self.coefficient(hidden_states)
             coef = torch.nn.functional.softmax(coef, dim=-1)
             output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
-        return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
+        return output, l_aux, self.moe_layer.exp_counts
 
 
 class Qwen2MoE(MoEBase):
@@ -264,7 +264,7 @@ def forward(self, hidden_states, used_token=None):
             * exp_counts (int): expert count
         """
 
-        output = self.moe_layer(hidden_states, used_token)
+        output, l_aux = self.moe_layer(hidden_states, used_token)
         if self.num_shared_experts > 0:
             # Residual MoE
             output_mlp = self.residual_mlp(hidden_states)
@@ -273,4 +273,4 @@ def forward(self, hidden_states, used_token=None):
             coef = self.coefficient(hidden_states)
             output_mlp = F.sigmoid(coef) * output_mlp
             output = output + output_mlp
-        return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
+        return output, l_aux, self.moe_layer.exp_counts

From 3094978af5310cedc17c971cbbaca8b5c60912d5 Mon Sep 17 00:00:00 2001
From: shidongxing
Date: Wed, 4 Dec 2024 18:57:28 +0800
Subject: [PATCH 2/3] fix sp gmm bug

---
 internlm/model/modules/linear.py | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py
index 856e6ba0..28489ce5 100644
--- a/internlm/model/modules/linear.py
+++ b/internlm/model/modules/linear.py
@@ -338,6 +338,7 @@ def forward(
             raise NotImplementedError(f"Invalid backend: {backend}")
 
         input_numel = x.numel()
+        ctx.input_numel = input_numel
         if input_numel == 0:
             backend = "bmm"
 
@@ -354,13 +355,15 @@ def forward(
         if input_numel == 0:
             # if inp is empty, reshape to make grad flow.
             # inp shape: (0, hdim)
-            weight = weight.view(x.shape[-1], -1)
-
-        output = torch.matmul(x, weight)
+            output = torch.matmul(x, weight.view(x.shape[-1], -1))
+        else:
+            output = torch.matmul(x, weight)
 
         saved_x = None if ctx.compute_weight_gradient is False else x
         ctx.save_for_backward(saved_x, weight, batch_sizes)
 
+        assert len(output.shape) == len(x.shape)
+
         return output
 
     @staticmethod
@@ -372,23 +375,36 @@ def backward(ctx, grad_output):
         x, weight, batch_sizes = ctx.saved_tensors
         grad_input, grad_weight = None, None
 
+        if grad_output.numel() == 0:
+            if ctx.needs_input_grad[1]:
+                grad_weight = torch.zeros_like(weight)
+            if ctx.needs_input_grad[0]:
+                grad_input = torch.zeros_like(x)
+
+            return grad_input, grad_weight, None, None, None, None, None
+
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
             if backend == "gmm":
                 grad_input, grad_weight = gmm_backward_op(x, grad_output, batch_sizes, input_weight=weight)
             else:
-                grad_weight = torch.matmul(x.transpose(-1, -2), grad_output)
+                if ctx.input_numel == 0:
+                    grad_weight = torch.zeros_like(weight)
+                else:
+                    grad_weight = torch.matmul(x.transpose(-1, -2), grad_output)
 
         if ctx.needs_input_grad[0]:
             if backend == "gmm":
                 if grad_input is None:
                     grad_input, _ = gmm_backward_op(grad_output, weight, batch_sizes, is_grad_input=True)
             else:
-                grad_input = torch.matmul(grad_output, weight.transpose(-1, -2))
+                if ctx.input_numel == 0:
+                    grad_input = torch.zeros_like(x)
+                else:
+                    grad_input = torch.matmul(grad_output, weight.transpose(-1, -2))
 
         return grad_input, grad_weight, None, None, None, None, None
 
-
 class GroupedGemmWPFusedDenseFunc(torch.autograd.Function):
     "Grouped Gemm FusedDenseFunc for weigth parallel."
 
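The sketch below is illustrative only and is not part of the patch series: it shows the general pattern the second patch relies on, short-circuiting the backward pass of a custom autograd function when an expert received no tokens, so that gradients still flow as zero tensors instead of depending on matmuls or grouped-gemm kernels over empty inputs. The class name, shapes, and two-argument signature are simplified stand-ins, not the InternLM implementation.

```python
import torch


class SafeExpertLinear(torch.autograd.Function):
    """Toy linear op that tolerates an expert receiving zero tokens."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.input_numel = x.numel()
        ctx.save_for_backward(x, weight)
        return torch.matmul(x, weight)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        if grad_output.numel() == 0 or ctx.input_numel == 0:
            # No tokens were routed here: hand back zero gradients instead of
            # multiplying empty tensors (which specialized kernels may reject).
            return torch.zeros_like(x), torch.zeros_like(weight)
        grad_input = torch.matmul(grad_output, weight.transpose(-1, -2))
        grad_weight = torch.matmul(x.transpose(-1, -2), grad_output)
        return grad_input, grad_weight


x = torch.randn(0, 8, requires_grad=True)  # an expert that got no tokens
w = torch.randn(8, 8, requires_grad=True)
SafeExpertLinear.apply(x, w).sum().backward()
print(x.grad.shape, w.grad.shape)  # torch.Size([0, 8]) torch.Size([8, 8])
```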
From 4a7dc466a36c74ee3eea984b92c2325e6919eda4 Mon Sep 17 00:00:00 2001
From: shidongxing
Date: Wed, 4 Dec 2024 19:05:38 +0800
Subject: [PATCH 3/3] lint

---
 internlm/model/modules/linear.py | 1 +
 internlm/model/moe/base_layer.py | 2 --
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py
index 28489ce5..7659df59 100644
--- a/internlm/model/modules/linear.py
+++ b/internlm/model/modules/linear.py
@@ -405,6 +405,7 @@ def backward(ctx, grad_output):
 
         return grad_input, grad_weight, None, None, None, None, None
 
+
 class GroupedGemmWPFusedDenseFunc(torch.autograd.Function):
     "Grouped Gemm FusedDenseFunc for weigth parallel."
 
diff --git a/internlm/model/moe/base_layer.py b/internlm/model/moe/base_layer.py
index 2649cbd0..7811e056 100644
--- a/internlm/model/moe/base_layer.py
+++ b/internlm/model/moe/base_layer.py
@@ -1,12 +1,10 @@
 from typing import TYPE_CHECKING, Union
 
-import torch
 from torch import Tensor
 from torch.nn import Module, ModuleList
 
 from internlm.core.context import global_context as gpc
 from internlm.model.moe.experts import Experts
-from internlm.utils.common import get_current_device
 
 if TYPE_CHECKING:
     Base = Module[Tensor]
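For context on the first patch, here is a minimal, hypothetical sketch (simplified names, not the InternLM classes) of the "use then reset" pattern it adopts: the MoE layer returns the auxiliary load-balancing loss to its caller instead of leaving it on `self`, and clears its own reference so the loss and its compute graph can be released after each step.

```python
import torch


class ToyMoELayer(torch.nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        self.l_aux = None  # set during forward, handed off before returning

    def forward(self, x):
        out = self.proj(x)
        # Stand-in for the router's load-balancing loss.
        self.l_aux = out.var()
        # Use the stored value once, then drop the reference so the module
        # does not keep the loss's autograd graph alive between steps.
        l_aux = self.l_aux
        self.l_aux = None
        return out, l_aux


layer = ToyMoELayer()
output, l_aux = layer(torch.randn(4, 8))
(output.sum() + 0.01 * l_aux).backward()
```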