Skip to content

Commit

Permalink
[Feature] MoE refactor; Integration with Mixtral (hpcaitech#5682)
Browse files Browse the repository at this point in the history
* cherry pick from refractor-moe branch

* tests passed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support ep + zero

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored and ver217 committed May 29, 2024
1 parent f1d4167 commit df6826d
Show file tree
Hide file tree
Showing 26 changed files with 1,979 additions and 1,296 deletions.
4 changes: 2 additions & 2 deletions applications/ColossalMoE/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import torch
import torch.distributed as dist
from mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.checkpoint import MoECheckpointIO
from colossalai.shardformer.policies.mixtral import MixtralForCausalLMPolicy


Expand Down Expand Up @@ -71,7 +71,7 @@ def main():
zero_stage=1,
precision=args.precision,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
checkpoint_io=MoECheckpointIO,
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
)
Expand Down
629 changes: 0 additions & 629 deletions applications/ColossalMoE/mixtral_checkpoint.py

This file was deleted.

11 changes: 7 additions & 4 deletions applications/ColossalMoE/tests/test_mixtral_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import colossalai
from colossalai.moe import MOE_MANAGER
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
from colossalai.testing.utils import spawn

Expand All @@ -19,8 +19,11 @@

def check_mixtral_moe_layer():
torch.cuda.set_device(dist.get_rank())
MOE_MANAGER.setup(
parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=dist.get_world_size(),
)
config = MixtralConfig(
hidden_size=hidden_size,
Expand All @@ -33,7 +36,7 @@ def check_mixtral_moe_layer():
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
model = EPMixtralSparseMoeBlock.from_native_module(model)
model = EPMixtralSparseMoeBlock.from_native_module(model, plugin.ep_group)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)
Expand Down
55 changes: 44 additions & 11 deletions applications/ColossalMoE/tests/test_moe_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import shutil
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from torch.optim import Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe import MoECheckpointIO
from colossalai.shardformer.policies.mixtral import MixtralForCausalLMPolicy
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4
Expand All @@ -20,8 +23,14 @@

def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for p1, p2 in zip(model1.parameters(), model2.parameters()):
assert torch.equal(p1.half(), p2.half())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
if not torch.equal(p1.half(), p2.half()):
# exit distributed
print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
raise AssertionError(f"Model parameter {name} is not equal")
# dist.destroy_process_group()
# exit(1)
# print(f"Passed: {name}")


def get_optimizer_snapshot(optim):
Expand All @@ -40,7 +49,7 @@ def get_optimizer_snapshot(optim):
}


def check_optimizer_snapshot_equal(snapshot1, snapshot2):
def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_group=None):
# check param_groups
assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
Expand All @@ -51,14 +60,26 @@ def check_optimizer_snapshot_equal(snapshot1, snapshot2):
assert set(snapshot1["state"].keys()) == set(
snapshot2["state"].keys()
), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"

passed = True
count = 0
for pid in snapshot1["state"].keys():
state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
assert set(state1.keys()) == set(state2.keys())
bug = False
for k in state1.keys():
if isinstance(state1[k], torch.Tensor):
assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
if not torch.equal(state1[k], state2[k]):
bug = True
count += 1
else:
assert state1[k] == state2[k]
if bug:
passed = False
print(f"rank {dist.get_rank()} optim mismatch: {param2name[pid]}")

if not passed:
raise AssertionError(f"A total of {count} optim states are not equal")


def check_mixtral_moe_layer():
Expand All @@ -77,10 +98,11 @@ def check_mixtral_moe_layer():
model = deepcopy(orig_model)
optimizer = Adam(model.parameters(), lr=1e-3)
plugin = MoeHybridParallelPlugin(
tp_size=1,
pp_size=2,
ep_size=2,
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
tp_size=1,
checkpoint_io=MoECheckpointIO,
custom_policy=MixtralForCausalLMPolicy(),
microbatch_size=1,
zero_stage=1,
)
Expand All @@ -103,9 +125,9 @@ def check_mixtral_moe_layer():
if dist.get_rank() == 0:
saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
check_model_equal(orig_model, saved_model)
# check_model_equal(model, saved_model)
saved_model.save_pretrained("mixtral_hf_model")
dist.barrier()

# check load model
new_model = MixtralForCausalLM(config).cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
Expand All @@ -120,25 +142,36 @@ def check_mixtral_moe_layer():
snapshot = get_optimizer_snapshot(optimizer.unwrap())
booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
dist.barrier()

working2master = optimizer.get_working_to_master_map()
param2name = {id(working2master[id(p)]): n for n, p in model.named_parameters()}
# reset optimizer state
for state in optimizer.unwrap().state.values():
for v in state.values():
if isinstance(v, torch.Tensor):
v.zero_()
booster.load_optimizer(optimizer, "mixtral_optim")
loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
check_optimizer_snapshot_equal(snapshot, loaded_snapshot, param2name, model)

# Clean up
dist.barrier()
if dist.get_rank() == 0:
shutil.rmtree("mixtral_model")
shutil.rmtree("mixtral_hf_model")
shutil.rmtree("mixtral_optim")


def run_dist(rank: int, world_size: int, port: int):
colossalai.launch(rank, world_size, "localhost", port)
check_mixtral_moe_layer()


@pytest.mark.parametrize("world_size", [4])
# Test EP + ZeRO + PP
@pytest.mark.parametrize("world_size", [8])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)


if __name__ == "__main__":
test_mixtral_moe_layer(4)
test_mixtral_moe_layer(8)
6 changes: 4 additions & 2 deletions applications/ColossalMoE/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import torch
import torch.distributed as dist
from mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
Expand All @@ -13,8 +12,10 @@
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.checkpoint import MoECheckpointIO
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.mixtral import MixtralForCausalLMPolicy
from colossalai.utils import get_current_device


Expand Down Expand Up @@ -154,11 +155,12 @@ def main():
pp_size=args.pp_size,
ep_size=args.ep_size,
microbatch_size=args.microbatch_size,
custom_policy=MixtralForCausalLMPolicy(),
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
precision=args.precision,
zero_stage=args.zero_stage,
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
checkpoint_io=MoECheckpointIO,
)

else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
print(resp) # super-heavyweight awesome-natured yawning Australian creature!
"""

import json
from typing import Any, Mapping

Expand Down
Loading

0 comments on commit df6826d

Please sign in to comment.