fix(moe): dropless moe loss (#348)
blankde authored Oct 18, 2024
1 parent c3dfe0f commit 0f99777
Showing 1 changed file with 31 additions and 10 deletions.
internlm/model/moe/dropless_layer.py: 31 additions & 10 deletions
@@ -152,6 +152,7 @@ def __init__(
         moe_grouped_mlp: bool = True,
         enable_fused_permute: bool = True,
         token_dispatch_policy: str = "alltoall",
+        deterministic_mode: bool = False,
     ) -> None:
         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
             "Unsupported noisy_gate_policy: " + noisy_gate_policy
@@ -175,6 +176,7 @@ def __init__(
                     mlp_layer_fusion=mlp_layer_fusion,
                     multiple_of=multiple_of,
                     activation_type=activation_type,
+                    is_expert=True,
                 )
                 for _ in range(num_experts // ep_size)
             ]
@@ -199,6 +201,7 @@ def __init__(
         assert len(self.local_expert_indices) > 0, "Expected at least one local expert index"
         self.topk = top_k
         self.moe_grouped_mlp = moe_grouped_mlp
+        self.deterministic_mode = deterministic_mode

         self.drop_and_pad = drop_and_pad
         self.capacity_factor = capacity_factor
@@ -255,10 +258,13 @@ def forward(self, *inputs: Tensor) -> Tensor:
         # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
         reshaped_inputs = inputs[0].reshape(-1, d_model)

-        self.gates = self.gate(reshaped_inputs)
-        expert_weights, indices = self.topk_softmax_with_capacity(self.gates)
+        gates = self.gate(reshaped_inputs)
+        expert_weights, indices, tokens_per_expert_before_capacity = self.topk_softmax_with_capacity(gates)
+        self.l_aux = self.load_balancing_loss(tokens_per_expert_before_capacity, gates)

-        (dispatched_input, tokens_per_expert) = self.token_permutation_func(reshaped_inputs, expert_weights, indices)
+        (dispatched_input, tokens_per_expert) = self.token_permutation_func(
+            reshaped_inputs, expert_weights, indices, tokens_per_expert_before_capacity
+        )
         if self.moe_grouped_mlp:
             expert_output = self.experts(dispatched_input, batch_sizes=tokens_per_expert)
         else:
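
Note: load_balancing_loss itself is not shown in this diff. A minimal sketch of the Switch-Transformer-style auxiliary loss the new call presumably computes, assuming gates holds post-softmax router probabilities; the function name and shapes below are illustrative, not InternLM's actual implementation:

import torch

def load_balancing_loss_sketch(tokens_per_expert, gates, topk):
    # tokens_per_expert: [num_experts], assignment counts BEFORE capacity dropping
    # gates: [num_tokens, num_experts], softmax router probabilities
    num_tokens, num_experts = gates.shape
    fraction_routed = tokens_per_expert.float() / (num_tokens * topk)  # f_i
    mean_gate_prob = gates.mean(dim=0)                                 # P_i
    return num_experts * torch.sum(fraction_routed * mean_gate_prob)

Because the counts are taken before any capacity-based dropping, tokens that later overflow an expert still contribute to the balancing signal.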
@@ -272,11 +278,18 @@ def forward(self, *inputs: Tensor) -> Tensor:
     def topk_softmax_with_capacity(self, gates):
         expert_weights, indices = torch.topk(gates, self.topk, dim=1)
         expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+        # we compute num_local_tokens_per_expert here. If no drop and padding, num_local_tokens_per_expert should be
+        # the final value, otherwise we recompute it in self.preprocess(.)
+        # histc(.) can be faster than bincount(.), but will cause non-deterministic behavior
+        if self.deterministic_mode:
+            num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
+        else:
+            num_local_tokens_per_expert = torch.histc(indices, bins=self.num_experts, min=0, max=self.num_experts)

         # without capacity
         if self.capacity_factor is None:
             # shape: [num_token, topk]
-            return expert_weights, indices
+            return expert_weights, indices, num_local_tokens_per_expert

         # with capacity
         expert_capacity = get_capacity(
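
Note on the deterministic_mode branch above: bincount and histc produce identical counts for integer expert ids; histc is simply a histogram kernel that PyTorch documents as non-deterministic on CUDA. A toy CPU check of the equivalence (values are illustrative; note histc requires floating-point input):

import torch

num_experts = 4
indices = torch.tensor([[0, 1], [2, 1], [3, 0], [1, 1], [2, 3], [0, 2]])  # [num_tokens, topk]

# deterministic path: integer bincount over the flattened assignments
counts_bincount = torch.bincount(indices.view(-1), minlength=num_experts)

# histc path: same bins [0, num_experts), cast to float for the histogram
counts_histc = torch.histc(indices.float(), bins=num_experts, min=0, max=num_experts)

assert torch.equal(counts_bincount.float(), counts_histc)
print(counts_bincount)  # tensor([3, 4, 3, 2])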
@@ -311,7 +324,9 @@ def topk_softmax_with_capacity(self, gates):
         final_expert_weights = expert_weights * torch.logical_not(exceed_mask)
         final_indices = indices.clone().masked_fill_(exceed_mask, torch.iinfo(torch.long).max)

-        return final_expert_weights, final_indices
+        tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
+
+        return final_expert_weights, final_indices, tokens_per_expert_before_capacity

     def _gather_along_first_dim_expert_parallel(self, input_):
         """Gather tensors and concatenate along the first dimension."""
@@ -329,7 +344,7 @@ def _gather_along_first_dim_expert_parallel(self, input_):

         return output

-    def preprocess(self, indices, expert_weight) -> torch.Tensor:
+    def preprocess(self, indices, expert_weight, tokens_per_expert_before_capacity) -> torch.Tensor:
         """
         Preprocess token indices for AlltoAll communication and token permutation. This method computes
         the number of tokens assigned to each expert based on the input indices.
@@ -342,7 +357,13 @@
         """
         # NOTE: bincount seems slower than histc
         # num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
-        num_local_tokens_per_expert = torch.histc(indices, bins=self.num_experts, min=0, max=self.num_experts)
+        if self.capacity_factor is not None:
+            if self.deterministic_mode:
+                num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
+            else:
+                num_local_tokens_per_expert = torch.histc(indices, bins=self.num_experts, min=0, max=self.num_experts)
+        else:
+            num_local_tokens_per_expert = tokens_per_expert_before_capacity
         # num_local_tokens_per_expert: [num_experts]

         if self.drop_and_pad:
@@ -415,8 +436,6 @@ def preprocess(self, indices, expert_weight, tokens_per_expert_before_capacity) -> torch.Tensor:
                 -1, self.num_local_experts
             ).to(torch.device("cpu"), non_blocking=True)

-        self.l_aux = self.load_balancing_loss(num_local_tokens_per_expert, self.gates)
-
         return num_tokens_per_local_expert

     def permute_with_padded_tokens(self, tokens, indices):
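
The two deleted lines above appear to be the heart of the fix: the auxiliary loss used to be computed inside preprocess, from counts that (with a capacity factor) reflect tokens remaining after overflow dropping, and from gates stashed on self. A toy illustration, not InternLM's actual drop logic, of how pre- and post-capacity counts diverge once an expert overflows:

import torch
import torch.nn.functional as F

num_experts, capacity = 2, 2
indices = torch.tensor([[0], [0], [0], [1]])  # 4 tokens, top-1: expert 0 overloaded

one_hot = F.one_hot(indices.view(-1), num_classes=num_experts)
before = one_hot.sum(dim=0)  # tensor([3, 1]): counts the loss sees after this commit

# rank each assignment within its expert and drop anything past capacity
position = one_hot.cumsum(dim=0) * one_hot
after = (one_hot * (position <= capacity)).sum(dim=0)  # tensor([2, 1]): the old, post-drop signal

A loss fed the post-drop counts saturates at the capacity ceiling and under-penalizes overloaded experts; computing it in forward from the pre-capacity counts preserves the full routing pressure.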
@@ -570,6 +589,7 @@ def token_permutation_by_alltoall(
         reshaped_inputs: torch.Tensor,
         expert_weights: torch.Tensor,
         indices: torch.Tensor,
+        tokens_per_expert_before_capacity: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Dispatch tokens to local experts using AlltoAll communication.
@@ -585,7 +605,7 @@ def token_permutation_by_alltoall(
         # Preprocess: Get the metadata for communication, permutation and computation operations.
         assert expert_weights.dim() == 2, "Expected 2D tensor for expert weights"
         assert indices.dim() == 2, "Expected 2D tensor for indices"
-        tokens_per_expert = self.preprocess(indices, expert_weights)
+        tokens_per_expert = self.preprocess(indices, expert_weights, tokens_per_expert_before_capacity)

         # Permutation 1: input to AlltoAll input
         self.hiddden_shape_before_permute = reshaped_inputs.shape
@@ -674,6 +694,7 @@ def token_permutation_by_all_gather(
         reshaped_inputs: torch.Tensor,
         expert_weights: torch.Tensor,
         indices: torch.Tensor,
+        tokens_per_expert_before_capacity: torch.Tensor,  # pylint: disable=W0613
     ):
         """Dispatch tokens to local experts. It's composed of two stages:
         (1) Permute the tokens across the expert parallel devices. After this stage,
