diff --git a/internlm/model/moe/dropless_layer.py b/internlm/model/moe/dropless_layer.py
index 6f92bda8..d0342430 100644
--- a/internlm/model/moe/dropless_layer.py
+++ b/internlm/model/moe/dropless_layer.py
@@ -152,6 +152,7 @@ def __init__(
         moe_grouped_mlp: bool = True,
         enable_fused_permute: bool = True,
         token_dispatch_policy: str = "alltoall",
+        deterministic_mode: bool = False,
     ) -> None:
         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
             "Unsupported noisy_gate_policy: " + noisy_gate_policy
@@ -175,6 +176,7 @@ def __init__(
                     mlp_layer_fusion=mlp_layer_fusion,
                     multiple_of=multiple_of,
                     activation_type=activation_type,
+                    is_expert=True,
                 )
                 for _ in range(num_experts // ep_size)
             ]
@@ -199,6 +201,7 @@ def __init__(
         assert len(self.local_expert_indices) > 0, "Expected at least one local expert index"
         self.topk = top_k
         self.moe_grouped_mlp = moe_grouped_mlp
+        self.deterministic_mode = deterministic_mode

         self.drop_and_pad = drop_and_pad
         self.capacity_factor = capacity_factor
@@ -255,10 +258,13 @@ def forward(self, *inputs: Tensor) -> Tensor:
         # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
         reshaped_inputs = inputs[0].reshape(-1, d_model)

-        self.gates = self.gate(reshaped_inputs)
-        expert_weights, indices = self.topk_softmax_with_capacity(self.gates)
+        gates = self.gate(reshaped_inputs)
+        expert_weights, indices, tokens_per_expert_before_capacity = self.topk_softmax_with_capacity(gates)
+        self.l_aux = self.load_balancing_loss(tokens_per_expert_before_capacity, gates)

-        (dispatched_input, tokens_per_expert) = self.token_permutation_func(reshaped_inputs, expert_weights, indices)
+        (dispatched_input, tokens_per_expert) = self.token_permutation_func(
+            reshaped_inputs, expert_weights, indices, tokens_per_expert_before_capacity
+        )
         if self.moe_grouped_mlp:
             expert_output = self.experts(dispatched_input, batch_sizes=tokens_per_expert)
         else:
@@ -272,11 +278,18 @@ def topk_softmax_with_capacity(self, gates):
         expert_weights, indices = torch.topk(gates, self.topk, dim=1)
         expert_weights /= expert_weights.sum(dim=-1, keepdim=True)

+        # We compute num_local_tokens_per_expert here. If there is no drop-and-pad,
+        # num_local_tokens_per_expert is already the final value; otherwise we recompute it in self.preprocess(.)
+        # histc(.) can be faster than bincount(.), but causes non-deterministic behavior
+        if self.deterministic_mode:
+            num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
+        else:
+            num_local_tokens_per_expert = torch.histc(indices, bins=self.num_experts, min=0, max=self.num_experts)

         # without capacity
         if self.capacity_factor is None:
             # shape: [num_token, topk]
-            return expert_weights, indices
+            return expert_weights, indices, num_local_tokens_per_expert

         # with capacity
         expert_capacity = get_capacity(
@@ -311,7 +324,9 @@
         final_expert_weights = expert_weights * torch.logical_not(exceed_mask)
         final_indices = indices.clone().masked_fill_(exceed_mask, torch.iinfo(torch.long).max)

-        return final_expert_weights, final_indices
+        tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
+
+        return final_expert_weights, final_indices, tokens_per_expert_before_capacity

     def _gather_along_first_dim_expert_parallel(self, input_):
         """Gather tensors and concatenate along the first dimension."""
@@ -329,7 +344,7 @@

         return output

-    def preprocess(self, indices, expert_weight) -> torch.Tensor:
+    def preprocess(self, indices, expert_weight, tokens_per_expert_before_capacity) -> torch.Tensor:
         """
         Preprocess token indices for AlltoAll communication and token permutation. This method computes
         the number of tokens assigned to each expert based on the input indices.
@@ -342,7 +357,13 @@
         """
         # NOTE: bincount seem slower than histc
        # num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
-        num_local_tokens_per_expert = torch.histc(indices, bins=self.num_experts, min=0, max=self.num_experts)
+        if self.capacity_factor is not None:
+            if self.deterministic_mode:
+                num_local_tokens_per_expert = torch.bincount(indices.view(-1), minlength=self.num_experts)
+            else:
+                num_local_tokens_per_expert = torch.histc(indices, bins=self.num_experts, min=0, max=self.num_experts)
+        else:
+            num_local_tokens_per_expert = tokens_per_expert_before_capacity

         # num_local_tokens_per_expert: [num_experts]
         if self.drop_and_pad:
@@ -415,8 +436,6 @@
                 -1, self.num_local_experts
             ).to(torch.device("cpu"), non_blocking=True)

-        self.l_aux = self.load_balancing_loss(num_local_tokens_per_expert, self.gates)
-
         return num_tokens_per_local_expert

     def permute_with_padded_tokens(self, tokens, indices):
@@ -570,6 +589,7 @@ def token_permutation_by_alltoall(
         reshaped_inputs: torch.Tensor,
         expert_weights: torch.Tensor,
         indices: torch.Tensor,
+        tokens_per_expert_before_capacity: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Dispatch tokens to local experts using AlltoAll communication.
@@ -585,7 +605,7 @@
         # Preprocess: Get the metadata for communication, permutation and computation operations.
         assert expert_weights.dim() == 2, "Expected 2D tensor for expert weights"
         assert indices.dim() == 2, "Expected 2D tensor for indices"
-        tokens_per_expert = self.preprocess(indices, expert_weights)
+        tokens_per_expert = self.preprocess(indices, expert_weights, tokens_per_expert_before_capacity)

         # Permutation 1: input to AlltoAll input
         self.hiddden_shape_before_permute = reshaped_inputs.shape
@@ -674,6 +694,7 @@ def token_permutation_by_all_gather(
         reshaped_inputs: torch.Tensor,
         expert_weights: torch.Tensor,
         indices: torch.Tensor,
+        tokens_per_expert_before_capacity: torch.Tensor,  # pylint: disable=W0613
     ):
         """Dispatch tokens to local experts. It's composed of two stages:
         (1) Permute the tokens across the expert parallel devices. After this stage,
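
Note on the `deterministic_mode` switch: per the patch comment, `histc(.)` can be faster than `bincount(.)` but is non-deterministic (on CUDA it accumulates with atomic adds, so the order of accumulation can vary between runs). A minimal standalone sketch of the two counting paths the flag selects between, using a made-up top-2 routing over 4 experts (the tensors here are illustrative, not from the repo):

```python
import torch

num_experts = 4
# Hypothetical top-2 routing result for 6 tokens: shape [num_tokens, topk]
indices = torch.tensor([[0, 1], [2, 3], [0, 2], [1, 1], [3, 0], [2, 2]])

# Deterministic path: bincount over the flattened expert indices.
counts_bincount = torch.bincount(indices.view(-1), minlength=num_experts)

# Faster (but non-deterministic on CUDA) path: histc. Note histc expects a
# floating-point input, hence the cast in this standalone sketch.
counts_histc = torch.histc(indices.float(), bins=num_experts, min=0, max=num_experts).long()

assert torch.equal(counts_bincount, counts_histc)  # both give tensor([3, 3, 4, 2])
```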
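For the capacity branch, the diff shows only the tail of `topk_softmax_with_capacity`: `exceed_mask` zeroes the weights of over-capacity assignments, `final_indices` marks them with the `torch.iinfo(torch.long).max` sentinel, and `topk_mask.sum(dim=0)` yields the new pre-capacity counts. The sketch below reconstructs a plausible version of that branch; the construction of `topk_mask` and `exceed_mask` (cumulative position in each expert's arrival queue) is an assumption, since those lines fall outside the hunk.

```python
import torch

def capacity_drop_sketch(indices, expert_weights, num_experts, expert_capacity):
    """Hypothetical reconstruction; only the last three lines mirror the patch."""
    num_tokens, _topk = indices.shape
    # topk_mask[t, e] = 1 if token t routed one of its top-k slots to expert e
    # (assumed construction; not shown in the diff).
    topk_mask = torch.zeros(num_tokens, num_experts).scatter_(1, indices, 1.0)
    # Position of each assignment in its expert's arrival order (1-based).
    position_in_expert = torch.cumsum(topk_mask, dim=0) * topk_mask
    # Back to [num_tokens, topk]: an assignment exceeds if it arrives past capacity.
    exceed_mask = torch.gather(position_in_expert, 1, indices) > expert_capacity

    # These three lines correspond to code visible in the patch.
    final_expert_weights = expert_weights * torch.logical_not(exceed_mask)
    final_indices = indices.clone().masked_fill_(exceed_mask, torch.iinfo(torch.long).max)
    tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
    return final_expert_weights, final_indices, tokens_per_expert_before_capacity
```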
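Moving `self.l_aux` out of `preprocess(.)` and into `forward(.)` means the auxiliary loss is now computed from `tokens_per_expert_before_capacity`, so tokens later dropped by the capacity limit still contribute to the balancing signal, and the loss no longer depends on a dispatcher that calls `preprocess` (the all-gather path simply ignores the new argument, per the `pylint: disable=W0613`). The repo's `load_balancing_loss` is not shown in this diff; the sketch below is a common Switch-Transformer-style formulation, given only to illustrate the two inputs the new call site passes:

```python
import torch

def load_balancing_loss_sketch(tokens_per_expert: torch.Tensor, gates: torch.Tensor, topk: int) -> torch.Tensor:
    """Assumed Switch-Transformer-style aux loss; not the repo's actual implementation.

    tokens_per_expert: [num_experts], pre-capacity assignment counts.
    gates:             [num_tokens, num_experts], router probabilities.
    """
    num_tokens, num_experts = gates.shape
    fraction_routed = tokens_per_expert.float() / (num_tokens * topk)  # f_i
    mean_gate_prob = gates.mean(dim=0)                                 # P_i
    # Equals 1.0 when routing is perfectly uniform; grows as experts become unbalanced.
    return num_experts * torch.sum(fraction_routed * mean_gate_prob)
```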