diff --git a/applications/ColossalMoE/colossal_moe/__init__.py b/applications/ColossalMoE/colossal_moe/__init__.py
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/applications/ColossalMoE/colossal_moe/models/__init__.py b/applications/ColossalMoE/colossal_moe/models/__init__.py
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py
index 543c434d2a99..1b07496e53f5 100644
--- a/applications/ColossalMoE/infer.py
+++ b/applications/ColossalMoE/infer.py
@@ -2,8 +2,7 @@
 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 
@@ -11,6 +10,7 @@
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.shardformer.policies.mixtral import MixtralForCausalLMPolicy
 
 
 def parse_args():
diff --git a/applications/ColossalMoE/infer.sh b/applications/ColossalMoE/infer.sh
index 0487fe9c1562..ba4362d7444d 100644
--- a/applications/ColossalMoE/infer.sh
+++ b/applications/ColossalMoE/infer.sh
@@ -1,5 +1,6 @@
 NUM_GPU=2
-MODEL="mistralai/Mixtral-8x7B-v0.1"
+# MODEL="mistralai/Mixtral-8x7B-v0.1"
+MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
 
 # ep
 torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/mixtral_checkpoint.py
similarity index 100%
rename from applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
rename to applications/ColossalMoE/mixtral_checkpoint.py
diff --git a/applications/ColossalMoE/tests/test_mixtral_layer.py b/applications/ColossalMoE/tests/test_mixtral_layer.py
index cbb70f195258..c21f608feae7 100644
--- a/applications/ColossalMoE/tests/test_mixtral_layer.py
+++ b/applications/ColossalMoE/tests/test_mixtral_layer.py
@@ -3,13 +3,13 @@
 import pytest
 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
 from torch.testing import assert_close
 from transformers.models.mixtral.configuration_mixtral import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
 import colossalai
 from colossalai.moe import MOE_MANAGER
+from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
 from colossalai.testing.utils import spawn
 
 tokens, n_experts = 7, 4
diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py
index 074dbf835fa6..c1b6be317a05 100644
--- a/applications/ColossalMoE/tests/test_moe_checkpoint.py
+++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py
@@ -3,8 +3,7 @@
 import pytest
 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
 from torch.optim import Adam
 from transformers.models.mixtral.configuration_mixtral import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
@@ -81,7 +80,6 @@ def check_mixtral_moe_layer():
         tp_size=1,
         pp_size=2,
         ep_size=2,
-        custom_policy=MixtralForCausalLMPolicy(),
         checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
         microbatch_size=1,
         zero_stage=1,
diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
index d2789d644ca5..76374db798e5 100644
--- a/applications/ColossalMoE/train.py
+++ b/applications/ColossalMoE/train.py
@@ -2,13 +2,12 @@
 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
+from mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
 from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralForCausalLM
+from utils import load_checkpoint, move_to_cuda, save_checkpoint
 
 import colossalai
 from colossalai.booster import Booster
@@ -155,7 +154,6 @@ def main():
         pp_size=args.pp_size,
         ep_size=args.ep_size,
         microbatch_size=args.microbatch_size,
-        custom_policy=MixtralForCausalLMPolicy(),
         enable_fused_normalization=args.use_layernorm_kernel,
         enable_jit_fused=args.use_kernel,
         precision=args.precision,
diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/utils.py
similarity index 100%
rename from applications/ColossalMoE/colossal_moe/utils.py
rename to applications/ColossalMoE/utils.py
diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py
index cc33c77f3eed..2708764d89bd 100644
--- a/colossalai/moe/__init__.py
+++ b/colossalai/moe/__init__.py
@@ -1,20 +1,7 @@
 from .checkpoint import MoECheckpointIO
-from .experts import MLPExperts
-from .layers import SparseMLP, apply_load_balance
 from .manager import MOE_MANAGER
-from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
-from .utils import NormalNoiseGenerator, UniformNoiseGenerator
 
 __all__ = [
-    "MLPExperts",
-    "MoeRouter",
-    "Top1Router",
-    "Top2Router",
-    "TopKRouter",
-    "NormalNoiseGenerator",
-    "UniformNoiseGenerator",
-    "SparseMLP",
     "MoECheckpointIO",
     "MOE_MANAGER",
-    "apply_load_balance",
 ]
diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py
deleted file mode 100644
index 8e6ea3884df4..000000000000
--- a/colossalai/moe/experts.py
+++ /dev/null
@@ -1,161 +0,0 @@
-import math
-from typing import Callable, Optional, Tuple
-
-import torch
-import torch.nn as nn
-
-from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_activation
-from colossalai.shardformer.layer.utils import Randomizer
-from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info
-
-if HAS_TRITON:
-    from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
-
-
-class MLPExperts(nn.Module):
-    """
-    SparseMLP is a multi-layer perceptron with sparse expert parallel layers.
-
-    Args:
-        num_experts (int): The number of experts
-        hidden_size (int): The hidden size of MLP
-        intermediate_size (int): The intermediate size of MLP
-        expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP.
- activation (optional): The activation function of MLP - drop_rate (float, optional): The drop rate of MLP - gated (bool, optional): Whether to use gated MLP - use_kernel (bool, optional): Whether to use kernel optimization - """ - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - expert_parallel: Optional[str] = None, - activation: Optional[Callable] = None, - drop_rate: Optional[float] = 0, - gated: Optional[bool] = False, - use_kernel: Optional[bool] = False, - ): - super().__init__() - assert expert_parallel in ["EP", "TP", None] - self.expert_parallel = expert_parallel - self.num_total_experts = num_experts - self.gated = gated - self.use_kernel = use_kernel - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - - # get expert parallel info - if expert_parallel is not None: - self.num_local_experts, self.moe_info = MOE_MANAGER.get_info( - num_experts, use_tp=True if expert_parallel == "TP" else False - ) - # get settings for different parallel - self.ep_size = get_ep_size(self) - if expert_parallel == "TP": - intermediate_size = intermediate_size // self.ep_size - num_experts = self.num_total_experts - else: - num_experts = self.num_local_experts - else: - self.num_local_experts = self.num_total_experts - self.ep_size = 1 - - if gated: - self.wi_gate = nn.Parameter( - torch.empty( - num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size - ) - ) - self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) - else: - self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) - self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size)) - - self.act_name = activation - self.act = get_activation(activation) - self.drop = nn.Dropout(p=drop_rate) - - if expert_parallel is not None: - for param in self.parameters(): - set_moe_tensor_info(param, self.moe_info) - - # init param - self.reset_parameters() - - @torch.no_grad() - def reset_parameters(self): - # expert param should be different - if self.expert_parallel is not None: - seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True) - else: - seed_ctx = Randomizer(42).fork_rng(enable_cpu=True) - with seed_ctx: - if self.gated: - torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size)) - torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size)) - else: - torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size)) - torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size)) - - def forward( - self, - x: torch.Tensor, - param_slice: Tuple[slice] = (slice(None),), - use_sparse: bool = True, - ) -> torch.Tensor: - """ - forward: hidden_size --> intermediate_size --> hidden_size - - Args: - x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size) - - Returns: - torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size) - """ - x = MoeInGradScaler.apply(x, self.ep_size) - - e = x.size(1) - h = x.size(-1) - - x = x.transpose(0, 1) - inshape = x.shape - x = x.reshape(e, -1, h) - - if self.use_kernel and use_sparse: - seq_len = x.shape[1] - with torch.no_grad(): - mask = x[:, :, 0] != 0.0 - mask = torch.sum(mask, dim=-1) - x_list = [] - for i in range(e): - x_list.append(x[i, : mask[i]]) - x = x_list - - if self.gated: - x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)] - x_up = [torch.mm(x[i], 
self.wi_up[param_slice][i]) for i in range(e)] - if self.use_kernel and HAS_TRITON and self.act_name == "swiglu": - x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)] - else: - x = [self.act(x_gate[i]) * x_up[i] for i in range(e)] - else: - x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)] - x = [self.act(x[i]) for i in range(e)] - x = [self.drop(x[i]) for i in range(e)] - x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)] - - if self.use_kernel and use_sparse: - for i in range(e): - x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0) - - x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0) - x = x.reshape(inshape) - x = x.transpose(0, 1).contiguous() - x = MoeOutGradScaler.apply(x, self.ep_size) - return x diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py deleted file mode 100644 index 2ac5b186d116..000000000000 --- a/colossalai/moe/layers.py +++ /dev/null @@ -1,400 +0,0 @@ -import dataclasses -import math -from typing import Any, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F - -from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter -from colossalai.moe.experts import MLPExperts -from colossalai.moe.load_balance import LoadBalancer -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.routers import MoeRouter, get_router_cls -from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator -from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size - - -class SparseMLP(nn.Module): - """A class for users to create MoE modules in their models. - - Args: - dim_model (int): Hidden dimension of training model - num_experts (int): The number experts - top_k (int, optional): The number of experts for dispatchment of each token - capacity_factor_train (float, optional): Capacity factor in routing during training - capacity_factor_eval (float, optional): Capacity factor in routing during evaluation - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'. - 'Jitter' can be found in `Switch Transformer paper`_. - 'Gaussian' can be found in `ViT-MoE paper`_. - drop_tks (bool, optional): Whether drops tokens in evaluation - use_residual (bool, optional): Makes this MoE layer a Residual MoE. - More information can be found in `Microsoft paper`_. - residual_instance (nn.Module, optional): The instance of residual module in Residual MoE - expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer - expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given - expert_args (optional): The args of expert when no instance is given - - .. _Switch Transformer paper: - https://arxiv.org/abs/2101.03961 - .. _ViT-MoE paper: - https://arxiv.org/abs/2106.05974 - .. 
_Microsoft paper: - https://arxiv.org/abs/2201.05596 - """ - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - router_top_k: int = 1, - router_loss: bool = True, - router_norm: bool = False, - router_capacity_factor_train: float = 1.25, - router_capacity_factor_eval: float = 2.0, - router_min_capacity: int = 4, - router_noisy_policy: Optional[str] = None, - router_drop_tks: bool = True, - mlp_activation: Optional[str] = None, - mlp_gated: bool = False, - enable_load_balance: bool = False, - load_balance_tolerance: float = 0.1, - load_balance_beam_width: int = 8, - load_balance_group_swap_factor: float = 0.4, - enable_kernel: bool = False, - enable_comm_overlap: bool = False, - enable_hierarchical_comm: bool = False, - return_gate_logits: bool = False, - ): - super().__init__() - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_experts = num_experts - self.gated = mlp_gated - self.return_gate_logits = return_gate_logits - self.enable_kernel = enable_kernel - self.enable_comm_overlap = enable_comm_overlap - self.expert_parallel = MOE_MANAGER.get_parallel() - self.router_loss = router_loss - self.router_norm = router_norm - - # moe router - noisy_func = get_noise_generator(router_noisy_policy, num_experts) - router_cls = get_router_cls(router_top_k) - self.topk = router_top_k - self.router: MoeRouter = router_cls( - capacity_factor_train=router_capacity_factor_train, - capacity_factor_eval=router_capacity_factor_eval, - min_capacity=router_min_capacity, - noisy_func=noisy_func, - drop_tks=router_drop_tks, - ) - - # gate - self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) - - # moe experts - self.experts = MLPExperts( - num_experts=self.num_experts, - expert_parallel=self.expert_parallel, - hidden_size=self.hidden_size, - intermediate_size=self.intermediate_size, - activation=mlp_activation, - gated=mlp_gated, - use_kernel=self.enable_kernel, - ) - - # get parallel settings - if self.expert_parallel is not None: - self.ep_group = get_ep_group(self.experts) - self.ep_size = get_ep_size(self.experts) - self.ep_hierarchical_group = None - if enable_hierarchical_comm: - self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group( - get_ep_group_ranks(self.experts) - ) - self.dp_group = get_dp_group(self.experts) - else: - self.ep_group = None - self.dp_group = None - self.num_local_experts = self.experts.num_local_experts - - # load balance - self.enable_load_balance = enable_load_balance - if self.enable_load_balance == True: - self.load_balancer = LoadBalancer( - experts=self.experts, - gate=self.gate_weight, - local_expert_num=self.num_local_experts, - expert_num=self.num_experts, - ep_group=self.ep_group, - dp_group=self.dp_group, - tolerance=load_balance_tolerance, - beam_width=load_balance_beam_width, - group_swap_factor=load_balance_group_swap_factor, - ) - - # init param - self.reset_parameters() - - @torch.no_grad() - def reset_parameters(self): - torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - """ - Args: - inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size) - - Returns: - torch.Tensor: The output tensor of shape (batch_size, seq_len, hidden_size) - """ - # reshape the input tokens - tokens = inputs.reshape(-1, self.hidden_size) - - # the data type of the inputs in the gating should be fp32 - gate_logits = F.linear(tokens, 
self.gate_weight) - gate_output = gate_logits.to(torch.float) - - # update expert load - if self.enable_load_balance == True: - with torch.no_grad(): - # TODO: optimize computation - expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1] - # TODO: bincount introduces synchronize, fix it - expert_load = torch.bincount(expert_load.view(-1)) - self.load_balancer.update_load(expert_load) - - # the result from the router - used_capacity, *route_result_list = self.router( - inputs=gate_output, - use_kernel=self.enable_kernel, - ep_group=self.ep_group, - use_loss=self.router_loss, - use_norm=self.router_norm, - ) - - # dispatch_data: (num_experts, capacity, hidden_size) - if self.enable_kernel: - dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) - dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size) - else: - sec_mask_f = route_result_list[1].type_as(inputs) - dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) - - # expert_output: (num_groups, num_experts, capacity, hidden_size) - if self.expert_parallel == "EP": - expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) - elif self.expert_parallel == "TP": - expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) - elif self.expert_parallel is None: - expert_output = self._local_process(dispatch_data) - else: - raise NotImplementedError( - "This kind of communication has not been implemented yet.\n" "Please use Experts build function." - ) - - if self.enable_kernel: - expert_output = expert_output.reshape(-1, self.hidden_size) - ans = MoeCombine.apply(expert_output, *route_result_list) - else: - combine_weights = route_result_list[0].type_as(inputs) - combine_weights = combine_weights.view(combine_weights.shape[0], -1) - expert_output = expert_output.view(-1, expert_output.shape[-1]) - ans = torch.matmul(combine_weights, expert_output) - - ans = ans.reshape(inputs.shape) - - if self.return_gate_logits: - return ans, gate_logits - else: - return ans - - def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: - expert_in = expert_in.unsqueeze(0) - expert_out = self.experts(expert_in) - return expert_out - - def _ep_process( - self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False - ) -> torch.Tensor: - """ - Expert Parallel - - Args: - dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size) - - Returns: - torch.Tensor: (num_experts, capacity, hidden_size) - """ - if not overlap or dist.get_world_size(self.ep_group) == 1: - if self.ep_hierarchical_group is not None: - expert_input = HierarchicalAllToAll.apply( - dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank - ) - expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) - expert_output = self.experts(expert_input) - expert_output = HierarchicalAllToAll.apply( - expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank - ) - return expert_output - else: - expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] - expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) - expert_output = self.experts(expert_input) - expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0] - return expert_output - else: - - @dataclasses.dataclass - class Capsule: - data: torch.Tensor - handle: Any = None - - NUM_CHUNK = 4 - NUM_STAGES = 4 - - assert dispatch_data.shape[1] % NUM_CHUNK == 0, 
"arbitrary chunk num is not supported yet" - chunk_size = dispatch_data.shape[1] // NUM_CHUNK - input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) - dispatch_data = dispatch_data.reshape(*input_shape) - chunk_data = torch.split(dispatch_data, chunk_size, dim=2) - output = torch.empty_like(dispatch_data) - - offset = 0 - _expert_in, expert_in, _expert_out, expert_out = None, None, None, None - - for i in range(NUM_CHUNK + NUM_STAGES - 1): - if expert_out is not None: - expert_out.handle.wait() - output[:, :, offset : offset + chunk_size, :] = expert_out.data - offset += chunk_size - expert_out = None - - # all2all last output - if _expert_out is not None: - expert_out = Capsule( - *AllToAll.apply(_expert_out.data, self.ep_group, True), - ) - _expert_out = None - - # all2all next input - if 0 <= i < NUM_CHUNK: - _expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True)) - - # compute - if expert_in is not None: - expert_in.handle.wait() - _expert_out = Capsule(data=self.experts(expert_in.data), handle=None) - expert_in = None - - if _expert_in is not None: - expert_in = _expert_in - _expert_in = None - - return output - - def _tp_process( - self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False - ) -> torch.Tensor: - """ - without overlap: - | C | - | A | | R | - - with overlap: - | C1 || C2 || C3 || C4 | - | A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 | - - where C is computation, A is all gather, R is reduce scatter. - - Args: - dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size) - - Returns: - torch.Tensor: (num_experts, capacity, hidden_size) - """ - if not overlap or dist.get_world_size(self.ep_group) == 1: - expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0] - expert_out = self.experts(expert_in) - expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0] - return expert_out - else: - - @dataclasses.dataclass - class Capsule: - data: torch.Tensor - handle: Any - indices: Tuple - - NUM_CHUNK = 4 - NUM_STAGES = 4 - - assert ( - dispatch_data.shape[0] % NUM_CHUNK == 0 - ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" - chunk_size = dispatch_data.shape[0] // NUM_CHUNK - chunk_data = torch.split(dispatch_data, chunk_size, dim=0) - output = torch.empty_like(dispatch_data) - - def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: - return (slice(idx * chunk_size, (idx + 1) * chunk_size),) - - _expert_in, expert_in, _expert_out, expert_out = None, None, None, None - - for i in range(NUM_CHUNK + NUM_STAGES - 1): - if expert_out is not None: - expert_out.handle.wait() - output[expert_out.indices] = expert_out.data - expert_out = None - - # reduce scatter last output - if _expert_out is not None: - expert_out = Capsule( - *ReduceScatter.apply(_expert_out.data, self.ep_group, True), - indices=_expert_out.indices, - ) - _expert_out = None - - # all gather next input - if 0 <= i < NUM_CHUNK: - _expert_in = Capsule( - *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True), - indices=get_chunk_slice(i, chunk_size), - ) - - # compute - if expert_in is not None: - expert_in.handle.wait() - _expert_out = Capsule( - self.experts(expert_in.data, expert_in.indices), - handle=None, - indices=expert_in.indices, - ) - expert_in = None - - if _expert_in is not None: - expert_in = _expert_in - _expert_in = None - - return output - - -def apply_load_balance(model: nn.Module, optim: Any) -> None: - """ - apply load 
balance to every experts in the model - """ - - def _apply_recursive(module: nn.Module): - for _, sub_module in module.named_children(): - if isinstance(sub_module, SparseMLP): - if sub_module.enable_load_balance == True: - sub_module.load_balancer.balance_load(optim) - _apply_recursive(sub_module) - - torch.cuda.empty_cache() - _apply_recursive(model) - torch.cuda.empty_cache() diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py deleted file mode 100644 index 85c12d73fa52..000000000000 --- a/colossalai/moe/load_balance.py +++ /dev/null @@ -1,442 +0,0 @@ -from copy import deepcopy -from typing import List, Optional, Tuple - -import torch -import torch.distributed as dist -from torch import Tensor, nn -from torch.distributed import ProcessGroup - -from colossalai.cluster import ProcessGroupMesh -from colossalai.moe.experts import MLPExperts -from colossalai.moe.manager import MOE_MANAGER -from colossalai.zero.low_level import LowLevelZeroOptimizer - - -class LoadBalancer: - def __init__( - self, - experts: MLPExperts, - gate: nn.Parameter, - local_expert_num: int, - expert_num: int, - ep_group: ProcessGroup, - dp_group: ProcessGroup, - tolerance: Optional[float] = 0.1, - beam_width: Optional[int] = 8, - group_swap_factor: Optional[float] = 0.4, - ) -> None: - self.experts: MLPExperts = experts - self.gate: nn.Parameter = gate - self.moe_ep_group: ProcessGroup = ep_group - self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks - self.moe_dp_group: ProcessGroup = dp_group - self.tolerance = tolerance - self.beam_width = beam_width - self.group_swap_factor = group_swap_factor - self.local_expert_num = local_expert_num - self.expert_num = expert_num - self.local_load = None - # TODO: use a global process group mesh - pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size - global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size) - self.global_dp_group = global_dp_group.get_group_along_axis(1) - self.global_dp_rank = dist.get_rank(self.global_dp_group) - self.global_dp_size = dist.get_world_size(self.global_dp_group) - - def _clear_load(self) -> None: - self.local_load = None - - def _sync_load(self) -> Tensor: - new_load = self.local_load.clone().detach() - # all reduce load between ep group - dist.all_reduce(new_load, group=self.moe_ep_group) - # all reduce load between dp group - dist.all_reduce(new_load, group=self.moe_dp_group) - return new_load - - @staticmethod - def _get_diff_from_avg(data: List, group: int, avg: float) -> float: - return abs(sum(data[group]) / len(data[group]) - avg) - - @staticmethod - def _swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None: - data[group_i][index_i], data[group_j][index_j] = ( - data[group_j][index_j], - data[group_i][index_i], - ) - - @staticmethod - def _normalize_data(data: List) -> List: - max_value = max(max(sublist) for sublist in data) - data = [[i / max_value for i in sublist] for sublist in data] - return data - - @staticmethod - def _get_swap_loss( - group_swap_factor: float, - swap_list: List, - group_i: int, - index_i: int, - group_j: int, - index_j: int, - ) -> float: - """ - Get swap loss. The swap loss is used to avoid the situation that - the same index is swapped twice and the same group is swapped for multiple times. 
- """ - swap_loss = 0 - for swap in swap_list: - for group_id, index_id in zip([group_i, group_j], [index_i, index_j]): - # the group has been swapped - if group_id in [swap[0], swap[2]]: - # the index has been swapped - # we want to avoid the situation that the same index is swapped twice - if index_id in [swap[1], swap[3]]: - swap_loss += 1e5 - # the index has not been swapped - # this is acceptable but as less as possible - else: - swap_loss += group_swap_factor - return swap_loss - - @staticmethod - def _check_convergence(data: List, avg: float, tolerance: float): - """ - Check whether the data is converged after swap. - """ - for sublist in data: - if abs(sum(sublist) / len(sublist) - avg) > tolerance * avg: - return False - return True - - def _beam_search( - self, - inputs: Tuple[List, float, List], - beam_width: int, - avg: float, - group_swap_factor: float, - ) -> List: - """ - Beam search for the best swap combination. - Specifically, we swap two elements from two groups and calculate the score. - The score is the difference between the origin group sum and the new group sum. - The larger the score, the better the swap combination. - - Args: - inputs (Tuple): (data, origin_score, swap_list) - beam_width (int): beam width for beam search - avg (float): average value of the data - group_swap_factor (float): group loss for group swap loss - - Returns: - List: results list - """ - data, origin_score, swap_list = inputs - results = [] - group_num = len(data) - group_size = len(data[0]) - origin_diff_list = [self._get_diff_from_avg(data, i, avg) for i in range(group_num)] - - for group_num_i in range(group_num): - for group_size_i in range(group_size): - for group_num_j in range(group_num_i + 1, group_num): - for group_size_j in range(group_size): - new_data = deepcopy(data) - # calculate origin group sum - origin_diff = origin_diff_list[group_num_i] + origin_diff_list[group_num_j] - # swap data - self._swap_data( - new_data, - group_num_i, - group_size_i, - group_num_j, - group_size_j, - ) - # calculate new group sum - new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg( - new_data, group_num_j, avg - ) - # caculate score - new_score = origin_diff - new_diff - if new_score > 0: - new_score = origin_score + new_score - # get swap loss - swap_loss = self._get_swap_loss( - group_swap_factor, - swap_list, - group_num_i, - group_size_i, - group_num_j, - group_size_j, - ) - new_score = new_score - swap_loss - # update swap list - new_swap_list = swap_list + [(group_num_i, group_size_i, group_num_j, group_size_j)] - results.append((new_data, new_score, new_swap_list)) - # sort results - results.sort(key=lambda x: x[1], reverse=True) - # select top k results - results = results[:beam_width] - return results - - def _load_to_list(self, load: Tensor) -> List: - load_len = len(load) - assert load_len % self.local_expert_num == 0 - load_list = [] - tmp_list = [] - for i in range(len(load)): - tmp_list.append(float(load[i])) - if (i + 1) % self.local_expert_num == 0: - load_list.append(tmp_list) - tmp_list = [] - return load_list - - def _search_balance( - self, - data: List, - tolerance: Optional[float] = 0.1, - beam_width: Optional[int] = 8, - group_swap_factor: Optional[float] = 0.4, - return_swapped_data: Optional[bool] = False, - ) -> Tuple[List, List]: - """ - Search for the best swap combination to balance the data within the specified tolerance. - And return the balanced data and the swap list. The swap list is used to record the swap. 
- The swap list is a list of tuples. Each tuple is a swap operation. - - Args: - data (List): expert load list. - E.g. [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]] - This means there are 4 devices and each devices has 2 experts. - The value is the load of the expert. - tolerance (float): tolerance for balance. - beam_width (int): beam width for beam search. - group_swap_factor (float): group swap factor for group swap loss. - The bigger it is, the less times a group will be swapped. - return_swapped_data (bool): whether to return the swapped data. - - Returns: - Tuple: (balanced data, swap list). - The swap list is a list of tuples. Each tuple is a swap operation. - E.g. [(0, 0, 1, 0), (...), (...)]. The first tuple means - the first expert of the first device is swapped with the first expert - of the second device. - """ - norm_data = self._normalize_data(data) - avg = sum(sum(sublist) / len(sublist) for sublist in norm_data) / len(norm_data) - results = [(norm_data, 0, [])] - stop_flag = False - - while stop_flag == False: - new_results = [] - best_score = results[0][1] - for i in range(len(results)): - new_results.extend(self._beam_search(results[i], beam_width, avg, group_swap_factor)) - if len(new_results) == 0: - stop_flag = True - break - new_results.sort(key=lambda x: x[1], reverse=True) - new_best_score = new_results[0][1] - if new_best_score == best_score: - stop_flag = True - break - new_results = new_results[:beam_width] - results = new_results - for i in results: - if self._check_convergence(results[0][0], avg, tolerance): - stop_flag = True - break - - swap_list = results[0][2] - if return_swapped_data: - out = deepcopy(data) - for swap in swap_list: - self._swap_data(out, *swap) - return out, swap_list - else: - return swap_list - - @staticmethod - def _swap_expert_single_tensor( - weight: nn.Parameter, - expert_idx: int, - comm_group: ProcessGroup, - send_first: bool, - comm_rank: int, - ): - # exchange weight - local_weight = weight.data[expert_idx] - new_weight = torch.empty_like(local_weight) - if send_first: - dist.send(local_weight, dst=comm_rank, group=comm_group) - dist.recv(new_weight, src=comm_rank, group=comm_group) - else: - dist.recv(new_weight, src=comm_rank, group=comm_group) - dist.send(local_weight, dst=comm_rank, group=comm_group) - weight.data[expert_idx] = new_weight - - def _swap_expert_param_and_optim( - self, - weight: nn.Parameter, - expert_idx: int, - comm_group: ProcessGroup, - send_first: bool, - comm_rank: int, - optim: LowLevelZeroOptimizer, - ): - # need to update master and working param if master param exists - # else just update working param - if weight in optim.optim.state: - master_weight_ptr = None - working_weight_ptr = weight - exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"] - exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"] - else: - master_weight_ptr = optim._param_store.working_to_master_param[id(weight)] - working_weight_ptr = weight - exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"] - exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"] - - # exchange weight - self._swap_expert_single_tensor( - working_weight_ptr, - expert_idx, - comm_group, - send_first, - comm_rank, - ) - if master_weight_ptr is not None: - # TODO: exchange master weight, skip for now - # master weight is shared by dp group - tmp = working_weight_ptr.view(-1).split( - working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group) - )[dist.get_rank(self.moe_dp_group)] - 
master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype)) - # exchange optim - self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank) - self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank) - - def _gather_global_dp_group(self, data: Tensor) -> Tensor: - data_list = [torch.zeros_like(data) for _ in range(self.global_dp_size)] - dist.all_gather(data_list, data, group=self.global_dp_group) - data_list = torch.cat(data_list, dim=0) - return data_list - - def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None: - """ - Swap moe param and optim. - We use different strategies to swap expert and gate. - For expert, we exchange the param and optim of the expert by p2p. - For gate, we all gather the gate choose the part we want. - - Args: - swap_list (List) - optim (LowLevelZeroOptimizer) - """ - # get all experts weights - local_rank = dist.get_rank(self.moe_ep_group) - if self.experts.gated: - weight_list = [self.experts.wi_up, self.experts.wi_gate] - else: - weight_list = [self.experts.wi] - weight_list.append(self.experts.wo) - - # gate optim should be obtained first - gate_shape = self.gate.shape - # get master weight and optim - master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)] - gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"] - gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"] - # gather - global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape) - global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape) - global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape) - assert ( - self.gate.shape - == global_master_gate_weight.shape - == global_gate_exp_avg.shape - == global_gate_exp_avg_sq.shape - ) - - for swap in swap_list: - source_group, source_idx, target_group, target_idx = swap - source_rank = self.moe_ep_ranks[source_group] - target_rank = self.moe_ep_ranks[target_group] - # exchange expert - if local_rank in [source_group, target_group]: - for weight in weight_list: - if local_rank == source_group: - self._swap_expert_param_and_optim( - weight, - source_idx, - self.moe_ep_group, - True, - target_rank, - optim, - ) - elif local_rank == target_group: - self._swap_expert_param_and_optim( - weight, - target_idx, - self.moe_ep_group, - False, - source_rank, - optim, - ) - # exchange gate - source_expert_pos = source_group * self.local_expert_num + source_idx - target_expert_pos = target_group * self.local_expert_num + target_idx - for gate in [ - self.gate, - global_master_gate_weight, - global_gate_exp_avg, - global_gate_exp_avg_sq, - ]: - origin_source = gate.data[source_expert_pos].clone().detach() - origin_target = gate.data[target_expert_pos].clone().detach() - gate.data[source_expert_pos], gate.data[target_expert_pos] = ( - origin_target, - origin_source, - ) - - # update gate - global_master_gate_weight = global_master_gate_weight.view(-1).split( - global_master_gate_weight.numel() // self.global_dp_size - )[self.global_dp_rank] - master_gate_weight.data.copy_(global_master_gate_weight) - global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // self.global_dp_size)[ - self.global_dp_rank - ] - gate_exp_avg.data.copy_(global_gate_exp_avg) - global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split( - global_gate_exp_avg_sq.numel() // 
self.global_dp_size - )[self.global_dp_rank] - gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq) - - @torch.no_grad() - def update_load(self, load: Tensor) -> None: - if len(load) != self.expert_num: - padding_size = self.expert_num - len(load) - padding = torch.zeros(padding_size, dtype=load.dtype, device=load.device) - load = torch.cat((load, padding), dim=0) - if self.local_load is None: - self.local_load = load - else: - self.local_load += load - - @torch.no_grad() - def balance_load(self, optim: LowLevelZeroOptimizer) -> None: - # prepare load - load = self._sync_load() - load = self._load_to_list(load) - # search balance - swap_list = self._search_balance(load) - if dist.get_rank() == 0: - if len(swap_list) > 0: - print(f"[Load Balance] Applying expert swap...") - else: - print(f"[Load Balance] Invalid swap, skip...") - # swap expert and gate - self._swap_moe_param(swap_list, optim) - # clear load - self._clear_load() diff --git a/colossalai/moe/loss.py b/colossalai/moe/loss.py deleted file mode 100644 index 75624510b452..000000000000 --- a/colossalai/moe/loss.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch.nn as nn -from torch.nn.modules.loss import _Loss - -from colossalai.moe.manager import MOE_MANAGER - - -class MoeCrossEntropyLoss(_Loss): - r"""torch.nn.CrossEntropyLoss added with auxiliary loss. - - Args: - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. - aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01. - - The ``args`` and ``kwargs`` should include parameters below: - :: - - weight (Tensor, optional) - size_average (bool, optional) - ignore_index (int, optional) - reduce (bool, optional) - reduction (str, optional) - label_smoothing (float, optional) - - More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in - `Cross_entropy `_. - """ - - def __init__(self, aux_weight: float = 0.01, *args, **kwargs): - super().__init__() - self.loss = nn.CrossEntropyLoss(*args, **kwargs) - self.aux_weight = aux_weight - - def forward(self, *args): - """ - The ``args`` should at least include parameters below: - :: - - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. - - More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in - `Cross_entropy `_. - """ - main_loss = self.loss(*args) - aux_loss = MOE_MANAGER.get_loss() - return main_loss + self.aux_weight * aux_loss - - -class MoeLoss(_Loss): - """A wrapper class for any loss module to add with auxiliary loss. - - Args: - aux_weight (float): Weight of auxiliary loss in total loss. - loss_fn (``Callable``): Loss function. - args (list): Args in loss function. - kwargs (dict): Kwargs in loss function - """ - - def __init__(self, aux_weight: float, loss_fn, *args, **kwargs): - super().__init__() - self.loss_fn = loss_fn(*args, **kwargs) - self.aux_weight = aux_weight - - def forward(self, *args, **kwargs): - """ - The ``args`` and ``kwargs`` should at least include parameters below: - :: - - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. 
- - Note: - The ``args`` and ``kwargs`` may include different parameters varying with different loss function. - """ - main_loss = self.loss_fn(*args, **kwargs) - aux_loss = MOE_MANAGER.get_loss() - return main_loss + self.aux_weight * aux_loss diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py deleted file mode 100644 index e40674c9bb44..000000000000 --- a/colossalai/moe/routers.py +++ /dev/null @@ -1,466 +0,0 @@ -import math -from abc import ABC -from typing import Callable, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from torch.distributed import ProcessGroup - -from colossalai.accelerator import get_accelerator -from colossalai.moe._operation import moe_cumsum -from colossalai.moe.manager import MOE_MANAGER - - -class MoeRouter(nn.Module, ABC): - """Base class for all MoE routers. - Args: - k_value (int): The value of top_k. - capacity_factor_train (float): Capacity factor in routing of training. - capacity_factor_eval (float): Capacity factor in routing of evaluation. - min_capacity (int): The minimum number of the capacity of each expert. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__( - self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - use_kernel: bool = False, - ): - super().__init__() - self.k_value = k_value - self.capacity_factor_train = capacity_factor_train - self.capacity_factor_eval = capacity_factor_eval - self.min_capacity = min_capacity - self.noisy_func = noisy_func - self.drop_tks = drop_tks - self._aux_loss = None - self._z_loss = None - self.use_kernel = use_kernel - - def get_capacity(self, num_tokens, num_experts, ep_group=None): - if ep_group is not None: - num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device()) - dist.all_reduce(num_tokens_tensor, group=ep_group) - num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group) - capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts) - capacity += capacity % 2 - capacity = max(capacity, self.min_capacity) - assert capacity > 0 - return int(capacity) - - def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None: - """Computes auxiliary load balancing loss as in Switch Transformer. - - See Switch Transformer (https://arxiv.org/abs/2101.03961). This function - implements the loss function presented in equations (4) - (6). It aims to - penalize those cases where the routing between experts is unbalanced. - - Args: - router_probs: Probability assigned to each expert per token. Shape: - [num_groups, tokens_per_group, num_experts]. - expert_indices: [num_groups, tokens_per_group, num_selected_experts] - indices identifying the top num_selected_experts for a given token. - """ - assert self._aux_loss is None - if router_probs.dim() == expert_indices.dim() == 2: - router_probs = router_probs.unsqueeze(0) - expert_indices = expert_indices.unsqueeze(0) - assert ( - router_probs.dim() == expert_indices.dim() == 3 - ), "router_probs must be 3D tensor and expert_indices must be 4D tensor" - - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. 
- expert_mask = F.one_hot(expert_indices, num_experts) - # For a given token, determine if it was routed to a given expert. - # Shape: [num_groups, tokens_per_group, num_experts] - expert_mask = expert_mask.max(dim=-2)[0] - - tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2) - router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2) - aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) - self._aux_loss = aux_loss - - def set_z_loss(self, router_logits: torch.Tensor): - """Compute router z-loss. - - The router z-loss was introduced in Designing Effective Sparse Expert Models - (https://arxiv.org/abs/2202.08906). It encourages router logits to remain - small in an effort to improve stability. - - Args: - router_logits: [num_groups, tokens_per_group, num_experts] router logits. - """ - assert self._z_loss is None - if router_logits.dim() == 2: - router_logits = router_logits.unsqueeze(0) - assert router_logits.dim() == 3, "router_logits must be 3D tensor" - num_groups, tokens_per_group, _ = router_logits.shape - log_z = torch.logsumexp(router_logits, dim=-1) - z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group) - self._z_loss = z_loss - - def pop_router_loss(self) -> torch.Tensor: - assert self._aux_loss is not None - MOE_MANAGER.add_loss(self._aux_loss, self._z_loss) - self._aux_loss = None - self._z_loss = None - - -class Top1Router(MoeRouter): - """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) - and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed - function can be found in the paper about Switch Transformer of Google. - - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert. - select_policy (str, optional): The policy about tokens selection. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__( - self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - ): - super().__init__( - k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks, - ) - self.select_policy = select_policy - assert select_policy in {"first", "random"} - if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_accelerator().get_current_device()), - high=torch.tensor(1.0, device=get_accelerator().get_current_device()), - ).rsample - - def forward( - self, - inputs: torch.Tensor, - use_kernel: bool = False, - ep_group: Optional[ProcessGroup] = None, - use_loss: bool = False, - use_norm: bool = False, - ) -> Tuple: - """ - Args: - inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). - - Returns: - 1. use_kernel is False: - The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). - The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). - 2. use_kernel is True: - ... 
- """ - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - probs = F.softmax(inputs, dim=-1) - num_experts = probs.size(-1) - num_tokens = inputs.size(0) - capacity = self.get_capacity(num_tokens, num_experts, ep_group) - - top1_idx = torch.argmax(inputs, dim=-1) - mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - - # calculate router loss - self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts) - self.set_z_loss(inputs) - self.pop_router_loss() - - if not self.training and not self.drop_tks and ep_group is not None: - max_num = torch.max(torch.sum(mask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - if self.select_policy == "random": - rand_mask = mask * self.uniform(mask.shape) - _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) - mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) - ranks = moe_cumsum(mask, use_kernel=self.use_kernel) - elif self.select_policy == "first": - ranks = moe_cumsum(mask, use_kernel=self.use_kernel) - mask = mask * torch.lt(ranks, capacity) - else: - raise NotImplementedError("Not support such select policy yet.") - - ranks = torch.sum(mask * ranks, dim=-1) - used_capacity = mask.sum(dim=0) - - if use_kernel: - mask = torch.sum(mask, dim=-1) - mask = torch.stack([mask], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) - return used_capacity, probs, mask, dest_idx, num_experts * capacity - else: - ranks = F.one_hot(ranks, num_classes=capacity) - weight = mask * probs.type_as(inputs) - combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) - sec_mask = combine_weights.bool() - return used_capacity, combine_weights, sec_mask, probs - - -class Top2Router(MoeRouter): - """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) - and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed - function can be found in the paper about ViT-MoE. - - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation. - """ - - def __init__( - self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - ): - super().__init__( - k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks, - ) - - def forward( - self, - inputs: torch.Tensor, - use_kernel: bool = False, - ep_group: Optional[ProcessGroup] = None, - use_norm: bool = False, - use_loss: bool = True, - ) -> Tuple: - """ - Args: - inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). - - Returns: - 1. use_kernel is False: - The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). - The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). - 2. use_kernel is True: - ... 
- """ - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - probs = F.softmax(inputs, dim=-1) - if use_norm: - routing_weights, _ = torch.topk(probs, 2, dim=-1) - probs = probs / routing_weights.sum(dim=-1, keepdim=True) - - num_experts = probs.size(-1) - num_tokens = inputs.size(0) - capacity = self.get_capacity(num_tokens, num_experts, ep_group) - - top1_idx = torch.argmax(probs, dim=-1) - mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - logits_except1 = probs.masked_fill(mask1.bool(), float("-inf")) - top2_idx = torch.argmax(logits_except1, dim=-1) - mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - - cmask = mask1 + mask2 # loss: [s, e] - cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 - - # calculate loss - if use_loss: - expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) - self.set_aux_loss(probs, expert_indices, num_experts) - self.set_z_loss(inputs) - self.pop_router_loss() - - if not self.training and not self.drop_tks and ep_group is not None: - max_num = torch.max(torch.sum(cmask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] - rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel) - rank2 += torch.sum(mask1, dim=-2, keepdim=True) - - mask1 *= torch.lt(rank1, capacity) - mask2 *= torch.lt(rank2, capacity) - used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0) - - rank1 = torch.sum(mask1 * rank1, dim=-1) - rank2 = torch.sum(mask2 * rank2, dim=-1) - - if use_kernel: - mask1 = torch.sum(mask1, dim=-1) - mask2 = torch.sum(mask2, dim=-1) - - mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - - return used_capacity, probs, mask, dest_idx, num_experts * capacity - else: - """ - The following code is equivalent to: - - ``` - weight1 = mask1 * probs.type_as(inputs) - weight2 = mask2 * probs.type_as(inputs) - rank1_sc = F.one_hot(rank1, num_classes=capacity) - rank2_sc = F.one_hot(rank2, num_classes=capacity) - - cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) - cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) - cb_weight = cb_weight1 + cb_weight2 - sec_mask = cb_weight.bool() - ``` - """ - - weight1 = mask1 * probs.type_as(inputs) - weight2 = mask2 * probs.type_as(inputs) - - cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device) - sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool) - indices = torch.arange(0, inputs.shape[0], device=inputs.device) - cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]] - cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]] - sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]] - sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]] - - return used_capacity, cb_weight, sec_mask - - -class TopKRouter(MoeRouter): - """Masked matmul router using tokens choose top-k experts assignment. - - NOTE: this is modified from flaxformer. - This router uses the same mechanism as in Switch Transformer - (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. 
Items are - sorted by router_probs and then routed to their choice of expert until the - expert's expert_capacity is reached. There is no guarantee that each token is - processed by an expert, or that each expert receives at least one token. - - Attributes: - num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are - oversubscribed / reach capacity. - """ - - def __init__( - self, - num_selected_experts: int, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - ): - super().__init__( - num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks - ) - - def forward( - self, - router_probs: torch.Tensor, - expert_capacity: int, - ) -> Tuple: - """Computes masks for the top-k experts per token. - - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - - Returns: - Dispatch and combine arrays for routing with masked matmuls. - """ - # TODO: FIXME: add parallel group - num_groups, _, num_experts = router_probs.shape - - # Top-k router probability and corresponding expert indices for each token. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - expert_gate, expert_index = torch.topk(router_probs, self.k_value) - - self.set_aux_loss(router_probs, expert_index, num_experts) - self.pop_router_loss() - - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 choices, - # etc. - expert_index = torch.transpose(expert_index, 1, 2) - # Shape: [num_groups, num_selected_experts * tokens_per_group] - expert_index = expert_index.reshape(num_groups, -1) - - # Create mask out of indices. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1 - # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts)) - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - token_priority = torch.transpose(token_priority, 1, 2) - # For each token, across all selected experts, select the only non-negative - # (unmasked) priority. Now, for group G routing to expert E, token T has - # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E - # is its targeted expert. - # Shape: [num_groups, tokens_per_group, num_experts]. - token_priority = torch.max(token_priority, dim=2)[0] - - # Token T can only be routed to expert E if its priority is positive and - # less than the expert capacity. One-hot matrix will ignore indices outside - # the range [0, expert_capacity). - # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. 
-        valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
-        token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
-        dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
-        valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
-        dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
-
-        # The combine array will be used for combining expert outputs, scaled by the
-        # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
-        # expert_capacity].
-        combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
-
-        return combine_array, dispatch_mask
-
-
-def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
-    if not grouped:
-        if top_k == 1:
-            return Top1Router
-        elif top_k == 2:
-            return Top2Router
-        else:
-            raise NotImplementedError("top_k > 2 is not supported yet")
-    else:
-        return TopKRouter
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/colossalai/shardformer/modeling/mixtral.py
similarity index 100%
rename from applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
rename to colossalai/shardformer/modeling/mixtral.py
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 69df021b0828..e33bd808981a 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -192,6 +192,12 @@ class PolicyLocation:
     "transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
         file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
     ),
+    "transformers.models.mixtral.modeling_mixtral.MixtralModel": PolicyLocation(
+        file_name="mixtral", class_name="MixtralModelPolicy"
+    ),
+    "transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
+        file_name="mixtral", class_name="MixtralForCausalLMPolicy"
+    ),
 }
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/colossalai/shardformer/policies/mixtral.py
similarity index 99%
rename from applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
rename to colossalai/shardformer/policies/mixtral.py
index c01e02c49a60..87e3476c9e14 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -17,11 +17,10 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 from colossalai.shardformer.shard import ShardConfig
 
-from .mixtral_layer import EPMixtralSparseMoeBlock
-
 __all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]