add moe_router_device_choice_method argument to choose method … #1381

Open · wants to merge 5 commits into base: main
54 changes: 47 additions & 7 deletions megatron/core/transformer/moe/moe_utils.py
@@ -333,6 +333,7 @@ def device_limited_topk(
num_tokens: int,
num_experts: int,
moe_router_topk_limited_devices: int,
moe_router_device_choice_method: str,
):
"""Perform top-k routing on a subset of expert parallel ranks.

@@ -346,6 +347,7 @@
num_experts (int): The number of experts.
moe_router_topk_limited_devices (int): Number of expert parallel ranks to consider for
each token during routing. None means no device limitation.
moe_router_device_choice_method (str): The method used to select the top-k devices ("max" or "top2-sum").

Returns:
Tuple[torch.Tensor, torch.Tensor]: Probs and indices tensor.
@@ -355,7 +357,15 @@
num_group = (
parallel_state.get_expert_model_parallel_world_size()
) # num_group equals to expert parallel size
group_scores = scores.view(num_tokens, num_group, -1).max(dim=-1).values

group_scores = scores.view(num_tokens, num_group, -1)
if moe_router_device_choice_method == "max":
    group_scores = group_scores.max(dim=-1).values
elif moe_router_device_choice_method == "top2-sum":
    group_scores = group_scores.topk(2, dim=-1).values.sum(dim=-1)
else:
    raise ValueError(f"Invalid moe_router_device_choice_method: {moe_router_device_choice_method}")

group_idx = torch.topk(group_scores, k=moe_router_topk_limited_devices, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
@@ -381,6 +391,7 @@ def topk_softmax_with_capacity(
drop_policy: str = "probs",
use_pre_softmax: bool = False,
moe_router_topk_limited_devices: int = None,
moe_router_device_choice_method: str = "max",
moe_router_topk_scaling_factor: Optional[float] = None,
deterministic_mode: bool = False,
score_function: str = "softmax",
@@ -399,6 +410,8 @@
use_pre_softmax (bool): Whether to apply softmax before top-k selection.
moe_router_topk_limited_devices (int): Number of expert parallel ranks to consider for
each token during routing. None means no device limitation.
moe_router_device_choice_method (str): The method used to select the top-k devices
("max" or "top2-sum"). Only takes effect when moe_router_topk_limited_devices is not None.
moe_router_topk_scaling_factor (float): Scaling factor for routing score in top-k
selection, only works when use_pre_softmax enabled.
deterministic_mode (bool): Deprecated.
@@ -415,27 +428,54 @@
assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
num_tokens, num_experts = logits.shape

def compute_topk(scores, topk, limited_devices=None):
def compute_topk(scores, topk, limited_devices=None, device_choice_method='max'):
if limited_devices:
return device_limited_topk(scores, topk, num_tokens, num_experts, limited_devices)
return device_limited_topk(
scores,
topk,
num_tokens,
num_experts,
limited_devices,
device_choice_method,
)
else:
return torch.topk(scores, k=topk, dim=1)

if score_function == "softmax":
if use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
probs, top_indices = compute_topk(scores, topk, moe_router_topk_limited_devices)
probs, top_indices = compute_topk(
scores,
topk,
moe_router_topk_limited_devices,
moe_router_device_choice_method,
)
else:
scores, top_indices = compute_topk(logits, topk, moe_router_topk_limited_devices)
scores, top_indices = compute_topk(
    logits,
    topk,
    moe_router_topk_limited_devices,
    moe_router_device_choice_method,
)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits)
if expert_bias is not None:
scores_for_routing = scores + expert_bias
_, top_indices = compute_topk(scores_for_routing, topk, moe_router_topk_limited_devices)
_, top_indices = compute_topk(
scores_for_routing,
topk,
moe_router_topk_limited_devices,
moe_router_device_choice_method,
)
scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
else:
scores, top_indices = compute_topk(scores, topk, moe_router_topk_limited_devices)
scores, top_indices = compute_topk(
    scores,
    topk,
    moe_router_topk_limited_devices,
    moe_router_device_choice_method,
)
probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
else:
raise ValueError(f"Invalid score_function: {score_function}")
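Reviewer note: the snippet below is a minimal standalone sketch in plain PyTorch (not Megatron code) of what the two device-choice methods compute; the tensor sizes are made up for illustration.

```python
import torch

# Standalone illustration of the two device-choice methods (hypothetical sizes):
# 4 tokens, 8 experts spread over 4 expert-parallel ranks (2 experts per rank).
num_tokens, num_experts, num_group = 4, 8, 4
scores = torch.rand(num_tokens, num_experts)
grouped = scores.view(num_tokens, num_group, -1)  # [tokens, ranks, experts_per_rank]

# "max" (DeepSeekV2): a rank is scored by its single best expert.
max_scores = grouped.max(dim=-1).values  # [tokens, ranks]

# "top2-sum" (DeepSeekV3): a rank is scored by the sum of its two best experts.
top2_sum_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)  # [tokens, ranks]

# Either score is then used to pick the limited set of ranks per token,
# e.g. with moe_router_topk_limited_devices = 2:
group_idx = torch.topk(top2_sum_scores, k=2, dim=-1, sorted=False)[1]
print(group_idx.shape)  # torch.Size([4, 2])
```

One consequence worth noting: "top2-sum" assumes at least two experts per expert-parallel rank, i.e. num_experts >= 2 * expert parallel size.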
2 changes: 2 additions & 0 deletions megatron/core/transformer/moe/router.py
@@ -170,6 +170,7 @@ def aux_loss_load_balancing(self, logits: torch.Tensor):
drop_policy=self.config.moe_token_drop_policy,
use_pre_softmax=self.config.moe_router_pre_softmax,
moe_router_topk_limited_devices=self.config.moe_router_topk_limited_devices,
moe_router_device_choice_method=self.config.moe_router_device_choice_method,
moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.score_function,
@@ -201,6 +202,7 @@ def seq_aux_loss_load_balancing(self, logits: torch.Tensor, bsz: int, seq_length
drop_policy=self.config.moe_token_drop_policy,
use_pre_softmax=self.config.moe_router_pre_softmax,
moe_router_topk_limited_devices=self.config.moe_router_topk_limited_devices,
moe_router_device_choice_method=self.config.moe_router_device_choice_method,
moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.score_function,
5 changes: 5 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -290,6 +290,11 @@ class TransformerConfig(ModelParallelConfig):
routing on a subset of expert parallel ranks by first selecting N ranks for each token, then
conducting top-k selection among experts on these devices. None means no device limitation."""

moe_router_device_choice_method: str = 'max'
"""The method to select the top-k devices, only works when --moe-router-topk-limited-devices
is not None. "max" is used in DeepSeekV2 and "top2-sum" is used in DeepSeekV3. The default is
"max"."""

moe_router_pre_softmax: bool = False
"""Enable pre-softmax routing for MoE, which means softmax is before the top-k selection.
By default, softmax is done after top-k."""
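For context, here is a hypothetical, trimmed-down stand-in for these two config fields (MiniRouterConfig and its validation are illustrative only; the real TransformerConfig has many more fields), showing how the new option pairs with moe_router_topk_limited_devices.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniRouterConfig:
    # Hypothetical stand-in for the two fields touched here, not the real TransformerConfig.
    moe_router_topk_limited_devices: Optional[int] = None
    moe_router_device_choice_method: str = 'max'

    def __post_init__(self):
        # Mirrors the runtime check added in moe_utils.py: only "max" and "top2-sum" are valid.
        if self.moe_router_device_choice_method not in ('max', 'top2-sum'):
            raise ValueError(
                f"Invalid moe_router_device_choice_method: {self.moe_router_device_choice_method}"
            )

# DeepSeekV3-style setup: limit each token to 4 ranks, score ranks by top2-sum.
cfg = MiniRouterConfig(moe_router_topk_limited_devices=4,
                       moe_router_device_choice_method='top2-sum')
```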
4 changes: 4 additions & 0 deletions megatron/training/arguments.py
@@ -2199,6 +2199,10 @@ def _add_moe_args(parser):
help='Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.')
group.add_argument('--moe-router-topk-limited-devices', type=int, default=None,
help='Number of expert parallel ranks to consider for each token during routing. Perform top-k routing on a subset of expert parallel ranks by first selecting N ranks for each token, then conducting top-k selection among experts on these devices. Default is None, which means no limited devices.')
group.add_argument('--moe-router-device-choice-method', type=str, default='max',
choices=['max', 'top2-sum'],
help='The method used to select the top-k devices. Only takes effect when --moe-router-topk-limited-devices is not None. "max" is used in DeepSeekV2 and "top2-sum" is used in DeepSeekV3. Defaults to "max".')
group.add_argument('--moe-router-topk-scaling-factor', type=float, default=None,
help='Scaling factor for routing score in top-k selection, only works when --moe-router-pre-softmax enabled. Defaults to None, which means no scaling.')
group.add_argument('--moe-router-enable-expert-bias', action='store_true',
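A quick way to sanity-check the CLI surface is a toy parser that mirrors just the two flags involved; this is not Megatron's real parser, only a sketch of how the `choices` constraint behaves.

```python
import argparse

# Toy parser mirroring the two routing flags touched by this PR.
parser = argparse.ArgumentParser()
parser.add_argument('--moe-router-topk-limited-devices', type=int, default=None)
parser.add_argument('--moe-router-device-choice-method', type=str, default='max',
                    choices=['max', 'top2-sum'])

args = parser.parse_args(['--moe-router-topk-limited-devices', '4',
                          '--moe-router-device-choice-method', 'top2-sum'])
print(args.moe_router_device_choice_method)  # top2-sum

# Anything outside the declared choices is rejected before training starts:
# parser.parse_args(['--moe-router-device-choice-method', 'sum'])  -> "invalid choice" error
```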