fix(moe): dropless moe loss #348

Merged
5 commits merged on Oct 18, 2024
Changes from 3 commits
37 changes: 29 additions & 8 deletions internlm/model/moe/dropless_layer.py
@@ -152,6 +152,7 @@ def __init__(
moe_grouped_mlp: bool = True,
enable_fused_permute: bool = True,
token_dispatch_policy: str = "alltoall",
deterministic_mode: bool = False,
) -> None:
assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
"Unsupported noisy_gate_policy: " + noisy_gate_policy
@@ -199,6 +200,7 @@ def __init__(
assert len(self.local_expert_indices) > 0, "Expected at least one local expert index"
self.topk = top_k
self.moe_grouped_mlp = moe_grouped_mlp
self.deterministic_mode = deterministic_mode

self.drop_and_pad = drop_and_pad
self.capacity_factor = capacity_factor
@@ -256,9 +258,11 @@ def forward(self, *inputs: Tensor) -> Tensor:
reshaped_inputs = inputs[0].reshape(-1, d_model)

self.gates = self.gate(reshaped_inputs)
expert_weights, indices = self.topk_softmax_with_capacity(self.gates)
expert_weights, indices, tokens_per_expert_before_capacity = self.topk_softmax_with_capacity(self.gates)

(dispatched_input, tokens_per_expert) = self.token_permutation_func(reshaped_inputs, expert_weights, indices)
(dispatched_input, tokens_per_expert) = self.token_permutation_func(
reshaped_inputs, expert_weights, indices, tokens_per_expert_before_capacity
)
if self.moe_grouped_mlp:
expert_output = self.experts(dispatched_input, batch_sizes=tokens_per_expert)
else:
@@ -272,11 +276,18 @@ def forward(self, *inputs: Tensor) -> Tensor:
def topk_softmax_with_capacity(self, gates):
expert_weights, indices = torch.topk(gates, self.topk, dim=1)
expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
# we compute num_local_tokens_per_expert here. Without drop-and-pad, num_local_tokens_per_expert is
# already the final value; otherwise we recompute it in self.preprocess(.)
# histc(.) can be faster than bincount(.), but causes non-deterministic behavior
if self.deterministic_mode:
num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
else:
num_local_tokens_per_expert = torch.histc(indices, bins=self.num_experts, min=0, max=self.num_experts)

# without capacity
if self.capacity_factor is None:
# shape: [num_token, topk]
return expert_weights, indices
return expert_weights, indices, num_local_tokens_per_expert

# with capacity
expert_capacity = get_capacity(
@@ -311,7 +322,9 @@ def topk_softmax_with_capacity(self, gates):
final_expert_weights = expert_weights * torch.logical_not(exceed_mask)
final_indices = indices.clone().masked_fill_(exceed_mask, torch.iinfo(torch.long).max)

return final_expert_weights, final_indices
tokens_per_expert_before_capacity = topk_mask.sum(dim=0)

return final_expert_weights, final_indices, tokens_per_expert_before_capacity
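
A note on the deterministic_mode branch above: for integer expert ids in [0, num_experts), bincount(.) and histc(.) compute the same counts, so the flag only trades some speed for run-to-run reproducibility on CUDA. A minimal standalone sketch of that equivalence (the tensors are illustrative, not taken from this PR):

import torch

num_experts = 4
indices = torch.tensor([[0, 2], [1, 2], [3, 3]])  # [num_tokens, topk] expert ids

counts_bincount = torch.bincount(indices.view(-1), minlength=num_experts)
# histc needs a floating-point input on CPU, hence the cast
counts_histc = torch.histc(indices.float(), bins=num_experts, min=0, max=num_experts)

assert torch.equal(counts_bincount, counts_histc.long())  # both yield [1, 1, 2, 2]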

def _gather_along_first_dim_expert_parallel(self, input_):
"""Gather tensors and concatenate along the first dimension."""
@@ -329,7 +342,7 @@ def _gather_along_first_dim_expert_parallel(self, input_):

return output
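
The helper above concatenates each expert-parallel rank's tensor along the first dimension. A hedged standalone sketch using public torch.distributed APIs (the real method uses the expert-parallel process group; the default group and the function name here are assumptions):

import torch
import torch.distributed as dist

def gather_along_first_dim_sketch(x: torch.Tensor) -> torch.Tensor:
    # Identity fallback so the sketch also runs without a distributed setup
    if not dist.is_initialized() or dist.get_world_size() == 1:
        return x
    world_size = dist.get_world_size()
    output = torch.empty((world_size * x.shape[0], *x.shape[1:]), dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(output, x.contiguous())
    return output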

def preprocess(self, indices, expert_weight) -> torch.Tensor:
def preprocess(self, indices, expert_weight, tokens_per_expert_before_capacity) -> torch.Tensor:
"""
Preprocess token indices for AlltoAll communication and token permutation. This method computes
the number of tokens assigned to each expert based on the input indices.
@@ -342,7 +355,13 @@ def preprocess(self, indices, expert_weight, tokens_per_expert_before_capacity) -> torch.Tensor:
"""
# NOTE: bincount seems slower than histc
# num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
num_local_tokens_per_expert = torch.histc(indices, bins=self.num_experts, min=0, max=self.num_experts)
if self.capacity_factor is not None:
if self.deterministic_mode:
num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
else:
num_local_tokens_per_expert = torch.histc(indices, bins=self.num_experts, min=0, max=self.num_experts)
else:
num_local_tokens_per_expert = tokens_per_expert_before_capacity
# num_local_tokens_per_expert: [num_experts]

if self.drop_and_pad:
@@ -415,7 +434,7 @@ def preprocess(self, indices, expert_weight, tokens_per_expert_before_capacity) -> torch.Tensor:
-1, self.num_local_experts
).to(torch.device("cpu"), non_blocking=True)

self.l_aux = self.load_balancing_loss(num_local_tokens_per_expert, self.gates)
self.l_aux = self.load_balancing_loss(tokens_per_expert_before_capacity, self.gates)

return num_tokens_per_local_expert
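
The l_aux change above is the core of the fix: when a capacity factor is set, num_local_tokens_per_expert holds post-drop counts, so a loss computed from it would report an overloaded expert as balanced once its overflow tokens are dropped. Feeding tokens_per_expert_before_capacity keeps dropped tokens in the balancing signal. A hedged sketch of a Switch-style auxiliary loss (the actual load_balancing_loss in this file may scale differently; the function name here is illustrative):

import torch

def load_balancing_loss_sketch(tokens_per_expert: torch.Tensor, gates: torch.Tensor) -> torch.Tensor:
    # tokens_per_expert: [num_experts], counted BEFORE capacity truncation
    # gates: [num_tokens, num_experts], softmax outputs of the router
    num_experts = gates.shape[1]
    f = tokens_per_expert.float() / tokens_per_expert.sum()  # fraction of tokens routed to each expert
    p = gates.mean(dim=0)  # mean router probability per expert
    return num_experts * torch.sum(f * p)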

@@ -570,6 +589,7 @@ def token_permutation_by_alltoall(
reshaped_inputs: torch.Tensor,
expert_weights: torch.Tensor,
indices: torch.Tensor,
tokens_per_expert_before_capacity: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
@@ -585,7 +605,7 @@ def token_permutation_by_alltoall(
# Preprocess: Get the metadata for communication, permutation and computation operations.
assert expert_weights.dim() == 2, "Expected 2D tensor for expert weights"
assert indices.dim() == 2, "Expected 2D tensor for indices"
tokens_per_expert = self.preprocess(indices, expert_weights)
tokens_per_expert = self.preprocess(indices, expert_weights, tokens_per_expert_before_capacity)

# Permutation 1: input to AlltoAll input
self.hiddden_shape_before_permute = reshaped_inputs.shape
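
For context, the dispatch that follows the preprocess call groups tokens by destination expert and then exchanges variable-sized splits across expert-parallel ranks. A hedged sketch of that idea for topk=1 (the split sizes would come from the counts computed in preprocess; all names here are illustrative, not this file's API):

import torch
import torch.distributed as dist

def alltoall_dispatch_sketch(tokens, expert_ids, input_splits, output_splits):
    # tokens: [num_tokens, d_model]; expert_ids: [num_tokens], destination expert per token
    order = torch.argsort(expert_ids)  # permute so same-expert tokens are contiguous
    permuted = tokens[order]
    received = permuted.new_empty((sum(output_splits), permuted.shape[1]))
    dist.all_to_all_single(received, permuted, output_split_sizes=output_splits, input_split_sizes=input_splits)
    return received, order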
@@ -674,6 +694,7 @@ def token_permutation_by_all_gather(
reshaped_inputs: torch.Tensor,
expert_weights: torch.Tensor,
indices: torch.Tensor,
tokens_per_expert_before_capacity: torch.Tensor, # pylint: disable=W0613
):
"""Dispatch tokens to local experts. It's composed of two stages:
(1) Permute the tokens across the expert parallel devices. After this stage,