Commit

fixes #2389
truncating expert param storage for checkpointing
azzhipa committed Oct 10, 2022
1 parent f4a92a1 commit b5c4f58
Showing 3 changed files with 39 additions and 15 deletions.
5 changes: 3 additions & 2 deletions deepspeed/runtime/engine.py
@@ -2989,8 +2989,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
                         num_local_experts + int(local_expert_id)
                     expert_key = key.replace(f'{moe_str_prefix}{local_expert_id}',
                                              f'{moe_str_prefix}{global_expert_id}')
-                    experts_state_dict[str(
-                        global_expert_id)][expert_key] = moe_state_dict.pop(key)
+                    # truncating extra tensor (shared) storage
+                    truncated = moe_state_dict.pop(key).clone().detach()
+                    experts_state_dict[str(global_expert_id)][expert_key] = truncated

                 # let save the moe parameters
                 for global_expert_id, expert_state_dict in experts_state_dict.items():
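Note on the fix: if a tensor in moe_state_dict shares its underlying storage with other tensors (for example as a view into a larger buffer), torch.save serializes the entire storage rather than just the elements the tensor covers, so each per-expert checkpoint can balloon to the size of the shared buffer. Calling .clone().detach() copies only the tensor's own elements into fresh storage before saving. A minimal standalone sketch of the effect in plain PyTorch (file names are illustrative only):

import torch

# A 10-element view into a 1024-element buffer, standing in for an expert
# parameter that shares storage with other tensors (illustrative only).
buffer = torch.zeros(1024)
view = buffer[:10]

torch.save({"w": view}, "shared.pt")                       # serializes all 1024 elements
torch.save({"w": view.clone().detach()}, "truncated.pt")   # serializes only 10 elements

print(torch.load("shared.pt")["w"].storage().size())       # 1024
print(torch.load("truncated.pt")["w"].storage().size())    # 10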
34 changes: 25 additions & 9 deletions tests/unit/checkpoint/common.py
@@ -115,16 +115,19 @@ def compare_lr_scheduler_states(saved_model, loaded_model):
         assert state0 == state1


-def create_deepspeed_model(config_dict, model, base_optimizer):
-    if base_optimizer is None:
-        ds_model, _, _, _ = deepspeed.initialize(config=config_dict,
-                                                 model=model,
-                                                 model_parameters=model.parameters())
-    else:
-        ds_model, _, _, _ = deepspeed.initialize(config=config_dict,
-                                                 model=model,
-                                                 optimizer=base_optimizer)
+# following mixture-of-experts.md
+def create_moe_param_groups(model):
+    from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
+
+    parameters = {'params': [p for p in model.parameters()], 'name': 'parameters'}
+    return split_params_into_different_moe_groups_for_optimizer(parameters)
+
+
+def create_deepspeed_model(config_dict, model, base_optimizer):
+    ds_model, _, _, _ = deepspeed.initialize(config=config_dict,
+                                             model=model,
+                                             model_parameters=create_moe_param_groups(model),
+                                             optimizer=base_optimizer)
     return ds_model


@@ -178,6 +181,19 @@ def checkpoint_correctness_verification(config_dict,

     dist.barrier()

+    for root, _, files in os.walk(save_folder):
+        for f in files:
+            if "_expert_" in f and "_model_states" in f:
+                expert = torch.load(os.path.join(root, f))
+                needed, storages = 0, {}
+                for name, tensor in expert.items():
+                    needed += tensor.size().numel()
+                    storage = tensor.storage()
+                    # some storage can be shared within an expert's checkpoint
+                    storages[storage.data_ptr()] = storage.size()
+                stored = sum(v for _, v in storages.items())
+                assert needed == stored, f"MoE expert checkpoint uses more storage than required: {f}"
+
     loaded_model = create_deepspeed_model(config_dict=config_dict,
                                           model=models[1],
                                           base_optimizer=base_optimizers[1])
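Note on the test: the new check walks the saved expert files and asserts that the element count the tensors logically need equals the element count actually held by their (deduplicated) storages. The same accounting can be applied to any checkpoint of tensors; a small helper sketch, assuming a locally saved state dict of tensors (the file name below is hypothetical):

import torch

def storage_overhead(state_dict):
    # Elements the tensors logically need vs. elements their storages hold;
    # storages shared between tensors are counted once via data_ptr().
    needed, storages = 0, {}
    for tensor in state_dict.values():
        needed += tensor.numel()
        storage = tensor.storage()
        storages[storage.data_ptr()] = storage.size()
    return needed, sum(storages.values())

# Hypothetical usage on one expert checkpoint file:
# needed, stored = storage_overhead(torch.load("expert_0_model_states.pt"))
# assert needed == stored, "checkpoint stores more elements than its tensors need"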
15 changes: 11 additions & 4 deletions tests/unit/simple_model.py
@@ -43,19 +43,26 @@ class SimpleMoEModel(torch.nn.Module):
     def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
         super(SimpleMoEModel, self).__init__()
         self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
-        linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
+        expert = torch.nn.Linear(hidden_dim, hidden_dim)
+        # using two MoE layers to check implications of sharing a single storage
         self.linear2 = MoE(hidden_size=hidden_dim,
-                           expert=linear2,
+                           expert=expert,
                            ep_size=ep_size,
                            use_residual=use_residual,
                            num_experts=num_experts,
                            k=1)
+        self.linear3 = MoE(hidden_size=hidden_dim,
+                           expert=expert,
+                           ep_size=ep_size,
+                           use_residual=use_residual,
+                           num_experts=num_experts,
+                           k=1)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

     def forward(self, x, y):
-        hidden_dim = x
-        hidden_dim = self.linear(hidden_dim)
+        hidden_dim = self.linear(x)
         output, _, _ = self.linear2(hidden_dim)
+        output, _, _ = self.linear3(output)
         hidden_dim = hidden_dim + output
         sentence_embed = hidden_dim.mean(1)
         return self.cross_entropy_loss(sentence_embed, y)
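Note on the test model: the comment in the diff points at state dict entries that share a single storage. In plain PyTorch the same situation arises whenever one module instance is registered under two names, since state_dict() emits one entry per name but the entries reference the same underlying storage. A hedged sketch of that situation (a simplified analogue, not DeepSpeed's actual MoE wiring):

import torch

# One Linear registered under two names; both state dict entries share storage.
shared = torch.nn.Linear(4, 4)
model = torch.nn.ModuleDict({"layer2": shared, "layer3": shared})

sd = model.state_dict()
same_storage = (sd["layer2.weight"].storage().data_ptr()
                == sd["layer3.weight"].storage().data_ptr())
print(same_storage)  # True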
