Merge branch 'xiny/fix_dist_ckpt_for_fp8_padding' into 'main'
Fix distributed checkpointing for fp8 padding/unpadding

See merge request ADLR/megatron-lm!2529
ko3n1g committed Feb 3, 2025
2 parents eedb2fe + 68589ec commit 3366815
Showing 2 changed files with 30 additions and 12 deletions.
11 changes: 6 additions & 5 deletions megatron/core/transformer/moe/experts.py
@@ -33,7 +33,10 @@
 from megatron.core.transformer.moe import grouped_gemm_util as gg
 from megatron.core.transformer.spec_utils import build_module
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.utils import make_sharded_object_for_checkpoint
+from megatron.core.transformer.utils import (
+    make_sharded_object_for_checkpoint,
+    sharded_state_dict_default,
+)
 
 try:

@@ -369,6 +372,7 @@ def sh_ten_build_fn(
 v_tensors = []
 w_lens = []
 v_lens = []
+expert_global_idx = local_expert_indices_offset + local_expert_idx
 for input_dim_idx in range(self.config.hidden_size):
     for glu_idx in range(2):
         local_idx = (
@@ -399,9 +403,6 @@ def sh_ten_build_fn(
     == local_flattened_range.stop - local_flattened_range.start
 )
 start_pos += len(local_tensor)
-expert_global_idx = (
-    local_expert_indices_offset + local_expert_idx
-)
 if glu_idx == 0:
     w_tensors.append(local_tensor)
     w_lens.append(len(local_tensor))
@@ -719,7 +720,7 @@ def sharded_state_dict(
         """
         sharded_state_dict = {}
         for name, module in self._modules.items():
-            sub_sd = module.sharded_state_dict(f'{name}.', sharded_offsets, metadata)
+            sub_sd = sharded_state_dict_default(module, f'{name}.', sharded_offsets, metadata)
             if name == 'linear_fc1' and self.config.gated_linear_unit:
                 num_global_experts = (
                     parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
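
Context for the change above: switching from module.sharded_state_dict(...) to sharded_state_dict_default(...) lets submodules that do not implement sharded_state_dict themselves (for example plain torch or Transformer Engine modules whose fp8 metadata lives only in _extra_state) still contribute entries to the distributed checkpoint. Below is a minimal sketch of that dispatch pattern, assuming the fallback wraps the module's regular state dict; the real helper lives in megatron/core/transformer/utils.py and may differ in detail, and the fallback call shown is an assumption.

# Illustrative sketch only, not the Megatron-LM implementation. The fallback call to
# make_sharded_tensors_for_checkpoint and its argument order are assumptions.
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint


def sharded_state_dict_default(module, prefix='', sharded_offsets=(), metadata=None):
    # Name matches the real helper; the body here only approximates its behaviour.
    if hasattr(module, 'sharded_state_dict'):
        # Megatron-aware modules build their own sharded mapping directly.
        return module.sharded_state_dict(
            prefix=prefix, sharded_offsets=sharded_offsets, metadata=metadata
        )
    # Plain torch / Transformer Engine modules fall back to wrapping their ordinary
    # state dict, which is where fp8 _extra_state objects get picked up.
    return make_sharded_tensors_for_checkpoint(
        module.state_dict(prefix='', keep_vars=True), prefix, {}, sharded_offsets
    )
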
31 changes: 24 additions & 7 deletions tests/unit_tests/dist_checkpointing/models/test_moe_experts.py
@@ -43,6 +43,7 @@ def initialize_expert_layer(seed, glu=True, expert_type='sequential', fp8=False,
         num_moe_experts=num_moe_experts,
         use_cpu_initialization=True,
         gated_linear_unit=glu,
+        fp8="hybrid" if fp8 else None,
     )
     default_config_kwargs.update(**config_kwargs)
     transformer_config = TransformerConfig(**default_config_kwargs)
@@ -66,8 +67,20 @@ def initialize_expert_layer(seed, glu=True, expert_type='sequential', fp8=False,
             transformer_config,
             transformer_layer_spec.submodules.mlp.submodules.experts.submodules,
         )
+    elif expert_type == 'te_sequential':
+        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+            num_experts=num_moe_experts, moe_grouped_gemm=False
+        )
+        model = SequentialMLP(
+            num_local_experts,
+            transformer_config,
+            transformer_layer_spec.submodules.mlp.submodules.experts.submodules,
+        )
     else:
-        raise ValueError('expert_type can only be one of ["sequential", "grouped", "te_grouped"]')
+        raise ValueError(
+            'expert_type can only be one of ["sequential", "te_sequential", "grouped",'
+            ' "te_grouped"]'
+        )
     return model


@@ -79,10 +92,14 @@ def get_pp_offsets():
 
 expert_type = ['sequential', 'grouped']
 src_dest_expert_type = [('sequential', 'grouped'), ('grouped', 'sequential')]
+if is_te_min_version("1.7.0.dev0"):
+    expert_type.append('te_sequential')
+    src_dest_expert_type.append(('sequential', 'te_sequential'))
+    src_dest_expert_type.append(('te_sequential', 'sequential'))
 if is_te_min_version("1.9.0.dev0"):
     expert_type.append('te_grouped')
-    src_dest_expert_type.append(('sequential', 'te_grouped'))
-    src_dest_expert_type.append(('te_grouped', 'sequential'))
+    src_dest_expert_type.append(('te_sequential', 'te_grouped'))
+    src_dest_expert_type.append(('te_grouped', 'te_sequential'))
 
 
 class TestExpertLayerReconfiguration:
@@ -283,10 +300,10 @@ def test_sequential_grouped_mlp_interchangeable(
         "src_module,dst_module,src_tp_pp_exp,dest_tp_pp_exp",
         [
             # Changing tp/pp/dp doesn't affect _extra_state
-            ('sequential', 'te_grouped', (1, 1, 1), (1, 1, 4)),
-            ('sequential', 'te_grouped', (1, 1, 4), (1, 1, 1)),
-            ('te_grouped', 'sequential', (1, 1, 1), (1, 1, 4)),
-            ('te_grouped', 'sequential', (1, 1, 4), (1, 1, 1)),
+            ('te_sequential', 'te_grouped', (1, 1, 1), (1, 1, 4)),
+            ('te_sequential', 'te_grouped', (1, 1, 4), (1, 1, 1)),
+            ('te_grouped', 'te_sequential', (1, 1, 1), (1, 1, 4)),
+            ('te_grouped', 'te_sequential', (1, 1, 4), (1, 1, 1)),
         ],
     )
     def test_sequential_grouped_mlp_extra_state(
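
For orientation, the test changes gate the new 'te_sequential' expert type on Transformer Engine 1.7.0.dev0 and plumb fp8 through to TransformerConfig as fp8="hybrid". A hedged usage sketch of the updated test helper follows; the helper and its signature come from the hunks above, while the is_te_min_version import path and the wrapper function are assumptions added only for illustration.

# Sketch only: shows how the new parametrization exercises the fix. Assumes
# pytest-style skipping; the is_te_min_version import path is an assumption.
import pytest

from megatron.core.utils import is_te_min_version  # assumed location of the helper
from tests.unit_tests.dist_checkpointing.models.test_moe_experts import (
    initialize_expert_layer,  # test-local helper shown in the diff above
)


def build_fp8_te_sequential_experts(seed=123):
    # Hypothetical wrapper, not part of the test suite.
    if not is_te_min_version("1.7.0.dev0"):
        pytest.skip("te_sequential experts need Transformer Engine >= 1.7.0.dev0")
    # fp8=True now maps to TransformerConfig(fp8="hybrid"), so the resulting TE modules
    # carry fp8 _extra_state that distributed checkpointing must pad/unpad consistently.
    return initialize_expert_layer(seed, glu=True, expert_type='te_sequential', fp8=True)
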
