Commit

add missing file
blankde committed Oct 25, 2024
1 parent 556ab85 commit f57f970
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions internlm/model/moe/gshard_layer.py
@@ -526,7 +526,7 @@ def __init__(
         use_fused_gating: bool = True,
         enable_token_rearrange_opt: bool = True,
         use_tutel: bool = True,
-        moe_grouped_mlp: bool = True,
+        use_grouped_mlp: bool = True,
     ) -> None:
         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
             "Unsupported noisy_gate_policy: " + noisy_gate_policy
@@ -540,10 +540,22 @@ def __init__(
             use_fused_gating or top_k > 2
         ), "enable_token_rearrange_opt only can be used when use_fused_gating or top_k>2"
 
-        self.moe_grouped_mlp = moe_grouped_mlp
-
-        if moe_grouped_mlp:
-            assert False, "not support yet"
+        if use_grouped_mlp:
+            experts = new_feed_forward(
+                in_features,
+                hidden_features,
+                out_features,
+                bias=False,
+                device=device,
+                dtype=dtype,
+                mlp_layer_fusion=mlp_layer_fusion,
+                multiple_of=multiple_of,
+                activation_type=activation_type,
+                is_expert=True,
+                use_grouped_mlp=True,
+                num_groups=num_experts // ep_size,
+                backend="bmm",
+            )
         else:
             experts = torch.nn.ModuleList(
                 [
@@ -583,6 +595,8 @@ def __init__(
             num_experts // ep_size,
         )
 
+        self.use_grouped_mlp = use_grouped_mlp
+
         self.time_falltoall = 0.0
         self.time_salltoall = 0.0
         self.time_moe = 0.0
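
A minimal sketch (not code from this commit or repository) of the idea behind the grouped / "bmm" expert path enabled above: each local expert's weights are stacked along a leading expert dimension so that one batched matmul per layer covers all local experts, instead of looping over a torch.nn.ModuleList. The two-layer ReLU MLP and all sizes below are illustrative assumptions, not the new_feed_forward implementation.

import torch

# Illustrative sizes; in the layer above, the expert dimension corresponds to
# num_experts // ep_size and m to d_model.
num_local_experts, tokens_per_expert, d_model, d_hidden = 4, 8, 16, 32

# Each expert's weights stacked along a leading expert dimension.
w1 = torch.randn(num_local_experts, d_model, d_hidden)
w2 = torch.randn(num_local_experts, d_hidden, d_model)

# Input already grouped per expert: (e, g*c, m), as forward() arranges it.
x = torch.randn(num_local_experts, tokens_per_expert, d_model)

# Grouped path: one batched matmul per layer covers every local expert at once.
grouped_out = torch.bmm(torch.relu(torch.bmm(x, w1)), w2)

# ModuleList-style path: run the experts one at a time and stack the results.
looped_out = torch.stack(
    [torch.relu(x[e] @ w1[e]) @ w2[e] for e in range(num_local_experts)]
)

assert torch.allclose(grouped_out, looped_out, atol=1e-4)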
@@ -634,15 +648,15 @@ def forward(self, *inputs: Tensor) -> Tensor:
         # Re-shape after all-to-all: ecm -> gecm
         dispatched_inputs = dispatched_inputs.reshape(self.ep_size, self.num_local_experts, -1, d_model)
 
-        if self.moe_grouped_mlp:
+        if self.use_grouped_mlp:
             # (g,e,c,m) -> (e, g*c, m)
             dispatched_inputs = (
                 dispatched_inputs.transpose(0, 1).reshape(self.num_local_experts, -1, d_model).contiguous()
             )
 
         expert_output = self.experts(dispatched_inputs, split_dim=1)
 
-        if self.moe_grouped_mlp:
+        if self.use_grouped_mlp:
             # (e, g*c, m) -> (e, g, c, m) -> (g, e, c, m)
             expert_output = (
                 expert_output.reshape(self.num_local_experts, self.ep_size, -1, d_model).transpose(0, 1).contiguous()
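
A shape check (again, not part of the commit) for the reshape round-trip that forward() performs around the grouped experts. Here g, e, c, m stand for ep_size, num_local_experts, capacity, and d_model; the concrete values are arbitrary.

import torch

ep_size, num_local_experts, capacity, d_model = 2, 4, 8, 16

# What arrives after all-to-all and the first reshape: (g, e, c, m).
dispatched = torch.randn(ep_size, num_local_experts, capacity, d_model)

# (g, e, c, m) -> (e, g, c, m) -> (e, g*c, m), as done before self.experts(...).
grouped = dispatched.transpose(0, 1).reshape(num_local_experts, -1, d_model).contiguous()

# ... the grouped experts would run here on the (e, g*c, m) tensor ...

# (e, g*c, m) -> (e, g, c, m) -> (g, e, c, m), as done after self.experts(...).
restored = grouped.reshape(num_local_experts, ep_size, -1, d_model).transpose(0, 1).contiguous()

assert restored.shape == dispatched.shape
assert torch.equal(restored, dispatched)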
