fixes #2389 #2411

Merged 3 commits on Oct 14, 2022
5 changes: 3 additions & 2 deletions deepspeed/runtime/engine.py
@@ -2989,8 +2989,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
num_local_experts + int(local_expert_id)
expert_key = key.replace(f'{moe_str_prefix}{local_expert_id}',
                         f'{moe_str_prefix}{global_expert_id}')
-experts_state_dict[str(
-    global_expert_id)][expert_key] = moe_state_dict.pop(key)
+# truncating extra tensor (shared) storage
+truncated = moe_state_dict.pop(key).clone().detach()
+experts_state_dict[str(global_expert_id)][expert_key] = truncated

# let save the moe parameters
for global_expert_id, expert_state_dict in experts_state_dict.items():
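
Note for reviewers: torch.save serializes a tensor's entire underlying storage, so a state-dict entry that is only a view into a larger shared buffer drags the whole buffer into the checkpoint; the clone().detach() above copies just the elements the expert actually needs. A minimal PyTorch sketch (not part of this PR) of the effect:

import io
import torch

big = torch.zeros(1_000_000)   # one large underlying storage (~4 MB of float32)
view = big[:10]                # a 10-element view into that storage

def saved_bytes(obj):
    buf = io.BytesIO()
    torch.save(obj, buf)
    return buf.getbuffer().nbytes

print(saved_bytes(view))                   # ~4 MB: the whole shared storage is written out
print(saved_bytes(view.clone().detach()))  # tiny: only the 10 cloned elements plus format overhead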
34 changes: 25 additions & 9 deletions tests/unit/checkpoint/common.py
@@ -115,16 +115,19 @@ def compare_lr_scheduler_states(saved_model, loaded_model):
    assert state0 == state1


-def create_deepspeed_model(config_dict, model, base_optimizer):
-    if base_optimizer is None:
-        ds_model, _, _, _ = deepspeed.initialize(config=config_dict,
-                                                 model=model,
-                                                 model_parameters=model.parameters())
-    else:
-        ds_model, _, _, _ = deepspeed.initialize(config=config_dict,
-                                                 model=model,
-                                                 optimizer=base_optimizer)
+# following mixture-of-experts.md
+def create_moe_param_groups(model):
+    from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
+
+    parameters = {'params': [p for p in model.parameters()], 'name': 'parameters'}
+    return split_params_into_different_moe_groups_for_optimizer(parameters)


+def create_deepspeed_model(config_dict, model, base_optimizer):
+    ds_model, _, _, _ = deepspeed.initialize(config=config_dict,
+                                             model=model,
+                                             model_parameters=create_moe_param_groups(model),
+                                             optimizer=base_optimizer)
    return ds_model
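
Side note on the new helper: create_moe_param_groups mirrors the pattern in the mixture-of-experts tutorial, and the groups it returns can also be fed to a plain torch optimizer before handing everything to deepspeed.initialize. A rough usage sketch only (model and config_dict are assumed to exist; the optimizer choice here is illustrative, not what the tests use):

import torch

# `model` is any torch.nn.Module containing DeepSpeed MoE layers; `config_dict` is a DeepSpeed config
param_groups = create_moe_param_groups(model)             # expert and non-expert parameter groups
base_optimizer = torch.optim.Adam(param_groups, lr=1e-3)  # a regular torch optimizer accepts the groups
ds_model = create_deepspeed_model(config_dict, model, base_optimizer)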


@@ -178,6 +181,19 @@ def checkpoint_correctness_verification(config_dict,

    dist.barrier()

+    for root, _, files in os.walk(save_folder):
+        for f in files:
+            if "_expert_" in f and "_model_states" in f:
+                expert = torch.load(os.path.join(root, f))
+                needed, storages = 0, {}
+                for name, tensor in expert.items():
+                    needed += tensor.size().numel()
+                    storage = tensor.storage()
+                    # some storage can be shared within an expert's checkpoint
+                    storages[storage.data_ptr()] = storage.size()
+                stored = sum(v for _, v in storages.items())
+                assert needed == stored, f"MoE expert checkpoint uses more storage than required: {f}"

    loaded_model = create_deepspeed_model(config_dict=config_dict,
                                          model=models[1],
                                          base_optimizer=base_optimizers[1])
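
The assertion above can also be packaged as a standalone check for inspecting an existing MoE checkpoint outside the test suite. A sketch only (the helper name is hypothetical; it simply mirrors the logic in the test):

import torch

def expert_checkpoint_is_compact(path):
    # load a single *_expert_*_model_states file onto CPU
    state_dict = torch.load(path, map_location="cpu")
    tensors = [t for t in state_dict.values() if torch.is_tensor(t)]
    needed = sum(t.numel() for t in tensors)
    # deduplicate shared storages by data pointer, exactly as the test does
    storages = {t.storage().data_ptr(): t.storage().size() for t in tensors}
    return needed == sum(storages.values())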
15 changes: 11 additions & 4 deletions tests/unit/simple_model.py
@@ -43,19 +43,26 @@ class SimpleMoEModel(torch.nn.Module):
    def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
        super(SimpleMoEModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
-        linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
+        expert = torch.nn.Linear(hidden_dim, hidden_dim)
+        # using two MoE layers to check implications of sharing a single storage
        self.linear2 = MoE(hidden_size=hidden_dim,
-                           expert=linear2,
+                           expert=expert,
                           ep_size=ep_size,
                           use_residual=use_residual,
                           num_experts=num_experts,
                           k=1)
+        self.linear3 = MoE(hidden_size=hidden_dim,
+                           expert=expert,
+                           ep_size=ep_size,
+                           use_residual=use_residual,
+                           num_experts=num_experts,
+                           k=1)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
-        hidden_dim = x
-        hidden_dim = self.linear(hidden_dim)
+        hidden_dim = self.linear(x)
        output, _, _ = self.linear2(hidden_dim)
+        output, _, _ = self.linear3(output)
        hidden_dim = hidden_dim + output
        sentence_embed = hidden_dim.mean(1)
        return self.cross_entropy_loss(sentence_embed, y)
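
For context on "sharing a single storage": the simplest PyTorch-level way for several state_dict entries to be backed by one storage is reusing a single module object, sketched below. The sharing the MoE checkpoint code actually hits may arise differently, so treat this purely as an illustration of what the new assertion in tests/unit/checkpoint/common.py guards against:

import torch

shared = torch.nn.Linear(8, 8)

class TiedPair(torch.nn.Module):
    # hypothetical toy module: both attributes reference the same Linear object
    def __init__(self):
        super().__init__()
        self.a = shared
        self.b = shared

state = TiedPair().state_dict()
# both entries alias one underlying storage, which is why the test counts unique data_ptrs
assert state["a.weight"].storage().data_ptr() == state["b.weight"].storage().data_ptr()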