Commit 542103f
- Updated sinkhorn initialization and added a max_iter argument.
- Removed mp assertion for moe
- Removed mlp_type checks in moe code
- Added Bf16 conversion to dmoe_gather
aurelion-source committed Dec 10, 2024
1 parent fb68c07 commit 542103f
Showing 6 changed files with 18 additions and 36 deletions.
19 changes: 5 additions & 14 deletions megatron/model/moe.py
@@ -53,20 +53,11 @@ def __init__(
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)

# decide which parallel grouped MLP implementation to use
if neox_args.mlp_type == "regular":
self.mlp = ParallelGroupedMLP(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
)
elif neox_args.mlp_type == "llama":
self.mlp = ParallelGroupedLLaMAMLP(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
)
else:
raise KeyError(neox_args.mlp_type)
self.mlp = ParallelGroupedMLP(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
)

def indices_and_bins(self, top_expert: torch.Tensor):
# Sort the expert ids to produce the scatter/gather
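The `indices_and_bins` context above hints at the usual sort-and-histogram routing step. As a rough, generic sketch only (the helper name and the use of `torch.bincount` are assumptions, not this repository's implementation), the pattern looks like:

import torch

def indices_and_bins_sketch(top_expert: torch.Tensor, num_experts: int):
    # top_expert holds one expert id per routed token.
    # Sorting makes tokens bound for the same expert contiguous.
    sorted_experts, indices = torch.sort(top_expert)
    # Count tokens per expert, then take the cumulative sum to get bin boundaries.
    tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return indices, bins, tokens_per_expert

# Example: 6 tokens routed across 4 experts.
idx, bins, counts = indices_and_bins_sketch(torch.tensor([2, 0, 3, 2, 1, 0]), 4)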
4 changes: 2 additions & 2 deletions megatron/model/moe_mlp.py
@@ -145,7 +145,7 @@ def __init__(
"""
super(ParallelGroupedMLP, self).__init__()

self.activation_func = get_activation(neox_args)
self.activation_func, self.activation_fn_is_gated = get_activation(neox_args)
self.activation_type = neox_args.activation

self.multiple_of = multiple_of
@@ -334,7 +334,7 @@ def __init__(
"""
super(ParallelGroupedLLaMAMLP, self).__init__()

self.activation_func = get_activation(neox_args)
self.activation_func, self.activation_fn_is_gated = get_activation(neox_args)
self.activation_type = neox_args.activation

self.multiple_of = multiple_of
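Both hunks above reflect `get_activation` now returning a `(callable, is_gated)` pair. A minimal sketch of how that flag might be consumed downstream (illustrative only; `F.silu` and the `apply_activation` helper are assumptions, not the repository's code):

import torch
import torch.nn.functional as F

# Stand-in for get_activation(neox_args): a SwiGLU-style config might yield (F.silu, True).
activation_func, activation_fn_is_gated = F.silu, True

def apply_activation(hidden: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    if activation_fn_is_gated:
        # Gated path: the activated gate projection scales the hidden projection.
        return activation_func(gate) * hidden
    # Non-gated path: apply the activation to the hidden projection directly.
    return activation_func(hidden)

out = apply_activation(torch.randn(2, 8), torch.randn(2, 8))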
8 changes: 5 additions & 3 deletions megatron/model/router.py
@@ -60,20 +60,22 @@ def __init__(
)
init_method(self.layer.weight)

def sinkhorn(self, cost: torch.Tensor, tol: float = 0.0001):
def sinkhorn(self, cost: torch.Tensor, tol: float = 0.0001, max_iter=3):
"""Sinkhorn based MoE routing function"""
cost = torch.exp(cost)
d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
d1 = 1 / (cost.size(1) * torch.sum(cost, 0))

eps = 0.00000001
error = 1e9
d1_old = d1
while error > tol:
for iteration in range(max_iter):
d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
error = torch.mean(torch.abs(d1_old - d1))
d1_old = d1
if error <= tol:
    break
return d1 * cost * d0.unsqueeze(1)

def sinkhorn_load_balancing(self, logits: torch.Tensor):
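Read together, the hunk initializes `d1` from the scaled inverse column sums instead of ones and bounds the update loop by `max_iter`. A standalone sketch of the patched routine (assuming the early exit is meant to fire once the update drops below `tol`, matching the old `while error > tol` loop):

import torch

def sinkhorn(cost: torch.Tensor, tol: float = 0.0001, max_iter: int = 3) -> torch.Tensor:
    """Sinkhorn-based MoE routing: approximately balance tokens across experts."""
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    # New initialization: scaled inverse column sums rather than all ones.
    d1 = 1 / (cost.size(1) * torch.sum(cost, 0))

    eps = 0.00000001
    error = 1e9
    d1_old = d1
    for _ in range(max_iter):
        d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
        if error <= tol:
            break
    return d1 * cost * d0.unsqueeze(1)

# Example: routing scores for 8 tokens over 4 experts.
scores = sinkhorn(torch.randn(8, 4))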
2 changes: 1 addition & 1 deletion megatron/model/transformer.py
@@ -1006,7 +1006,7 @@ def _get_bias_dropout(self):
def forward(self, x, attention_mask, layer_past=None):
layer_past = layer_past if layer_past is not None else self.layer_past
bias_dropout_fn = self._get_bias_dropout()

# x: [b, s, h]
if self.gpt_j_residual:
# pseudocode:
10 changes: 5 additions & 5 deletions megatron/mpu/mappings.py
@@ -86,11 +86,7 @@ def _gather(input_):
torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())

# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim)

# Bf16 convert
if dt == torch.bfloat16 and get_fp32_allreduce():
output = output.bfloat16()
output = torch.cat(tensor_list, dim=last_dim).contiguous()

return output

@@ -180,6 +176,10 @@ def _dmoe_gather(input_: torch.Tensor, tokens_per_expert: torch.Tensor):
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=gather_dim)

# Bf16 convert
if dt == torch.bfloat16 and get_fp32_allreduce():
output = output.bfloat16()

return output


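The conversion dropped from `_gather` reappears in `_dmoe_gather`: when `get_fp32_allreduce()` is on, communication presumably happens in fp32 (the upcast sits outside the shown hunk) and the concatenated result is cast back to bf16. A self-contained, single-process sketch of that dtype round-trip (the function name and the list-of-shards stand-in for the process group are assumptions):

import torch

def gather_with_bf16_roundtrip(shards, fp32_comms: bool = True) -> torch.Tensor:
    """Concatenate gathered shards, mimicking the fp32-communication / bf16-output pattern."""
    dt = shards[0].dtype
    if dt == torch.bfloat16 and fp32_comms:
        # Upcast for communication (in the real code this happens before the all-gather).
        shards = [s.float() for s in shards]
    output = torch.cat(shards, dim=-1)
    if dt == torch.bfloat16 and fp32_comms:
        # Bf16 convert: hand the result back in the original dtype.
        output = output.bfloat16()
    return output

# Example with two fake per-rank shards.
shards = [torch.randn(2, 4, dtype=torch.bfloat16) for _ in range(2)]
out = gather_with_bf16_roundtrip(shards)
assert out.dtype == torch.bfloat16 and out.shape == (2, 8)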
11 changes: 0 additions & 11 deletions megatron/neox_arguments/arguments.py
@@ -1078,17 +1078,6 @@ def calculate_derived(self):
# the sequential model without the PipelineModule wrapper to avoid the overhead it incurs
self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1)

# Do MoE checks
if self.moe_num_experts > 1:
assert not (
self.is_pipe_parallel or self.pipe_parallel_size > 1
), "MoE not supported with pipeline parallelism"
assert self.zero_optimization["stage"] != 3, "MoE not compatible with zero3"

assert (
self.sequence_parallel is False
), "MoE not compatible with Sequence Parallel"

# Attention config
if self.attention_config is None:
self.update_value("attention_config", [[["global"], self.num_layers]])
