diff --git a/examples/nlp/language_modeling/megatron_change_num_partitions.py b/examples/nlp/language_modeling/megatron_change_num_partitions.py
index 0949d90f4b96..944565d8bd43 100644
--- a/examples/nlp/language_modeling/megatron_change_num_partitions.py
+++ b/examples/nlp/language_modeling/megatron_change_num_partitions.py
@@ -54,7 +54,7 @@
     --target_pipeline_model_parallel_size=1 \
     --target_pipeline_model_parallel_split_rank=0 \
     --precision=bf16
-    
+
 ### Only Tensor Parallelism conversion ###
 
 To the above commands, add the following argument: `--tp_conversion_only`
@@ -99,13 +99,14 @@
 """
+
 
 
 #################
 ### Utilities ###
 #################
 
 
 def compute_tp_splits(
-    param_name, param, partitions, global_idx, tp_size, pp_size, pp_rank, pp_split_rank, megatron_legacy
+    param_name, param, partitions, global_idx, tp_size, pp_size, pp_rank, pp_split_rank, megatron_legacy, model_cfg
 ):
     """
     Function to compute the splits required for tensor-parallelism.
@@ -120,6 +121,7 @@ def compute_tp_splits(
         pp_rank: Int, pipeline-parallelism rank.
         pp_split_rank: Int, pipeline-parallelism split rank. This should be > 1 if TP is being used with EncDec models (T5)
         megatron_legacy: Bool, whether the model is a legacy Megatron model or not.
+        model_cfg: The model config as an OmegaConf DictConfig.
 
     Returns:
         List of torch tensors, each of which is a split of the current parameter.
@@ -127,6 +129,8 @@ def compute_tp_splits(
     # alias the global index to idx
     idx = global_idx
 
+    swiglu_activation = 'swiglu' in str(model_cfg.get('activation', '')).lower()
+
     if param.shape == partitions[0][idx].shape:
         split = [partitions[0][idx].data] * tp_size
         logging.debug(">> Perfect match, no splitting needed")
@@ -156,6 +160,15 @@ def compute_tp_splits(
             for i in range(tp_size):
                 tp_qkv = torch.cat([tp_qkv_splits[item] for item in range(i, tp_size * 2, tp_size)])
                 split.append(tp_qkv)
+        elif 'dense_h_to_4h.weight' in param_name and swiglu_activation:
+            # For Megatron GPT model with Swiglu activation
+            # Handle gated linear units
+            # concat all the first halves ('W's) and all the second halves ('V's)
+            w_split, k_split = torch.chunk(partitions[0][idx].data, 2, dim=0)
+            w_split = torch.chunk(w_split, tp_size, dim=0)
+            k_split = torch.chunk(k_split, tp_size, dim=0)
+            split = [torch.cat(weights, dim=0) for weights in zip(w_split, k_split)]  # split per tp rank
+
         # Regular split for Megatron and NeMo-Megatron models.
         else:
             split = torch.split(partitions[0][idx].data, param.shape[0], dim=0)
@@ -163,7 +176,7 @@ def compute_tp_splits(
     return split
 
 
-def compute_tp_merge(idx, name, param, partitions_pp):
+def compute_tp_merge(idx, name, param, partitions_pp, model_cfg):
     """
     Function to compute the partition merge required for tensor-parallelism.
 
@@ -173,10 +186,13 @@ def compute_tp_merge(idx, name, param, partitions_pp):
         param: The parameter to be merged under TP 1 PP 1.
         partitions_pp: List of all TP partitions of the flattened parameter of the current model for a given PP rank
             (TP X PP Y). Indexed as partitions_pp[tp_rank][idx].
+        model_cfg: The model config as an OmegaConf DictConfig.
 
     Returns:
         The concatenated parameter for TP 1 PP 1.
""" + swiglu_activation = 'swiglu' in str(model_cfg.get('activation', '')).lower() + # Logic from original TP rank change if param.shape == partitions_pp[0][idx].shape: concated = partitions_pp[0][idx].data @@ -184,6 +200,19 @@ def compute_tp_merge(idx, name, param, partitions_pp): concated = torch.cat([partitions_pp[i][idx].data for i in range(len(partitions_pp))], dim=-1) else: concated = torch.cat([partitions_pp[i][idx].data for i in range(len(partitions_pp))], dim=0) + + # Logic for Swiglu activation + if 'dense_h_to_4h.weight' in name and swiglu_activation: + # concat all the first halves ('W's) and all the second halves ('V's) + wk_splits = [] + for tpr in range(len(partitions_pp)): + wk_splits.append(torch.chunk(partitions_pp[tpr][idx].data, 2, dim=0)) + + w_split = torch.cat([w[0] for w in wk_splits], dim=0) + k_split = torch.cat([w[1] for w in wk_splits], dim=0) + concated = torch.cat([w_split, k_split], dim=0) + + # Trim padding if concated.shape != param.shape: logging.info( f"Warning: Shape mismatch for parameter {name} required shape: {param.shape}, merged shape: {concated.shape}. Narrowing to match required size." @@ -301,7 +330,16 @@ def compute_splits(self, model, partitions, idx, tp_rank, pp_rank, pp_split_rank # Tensor Parallel Splitting split = compute_tp_splits( - param_name, param, partitions, idx, tp_size, pp_size, pp_rank, pp_split_rank, self.megatron_legacy + param_name, + param, + partitions, + idx, + tp_size, + pp_size, + pp_rank, + pp_split_rank, + self.megatron_legacy, + model.cfg, ) splits.append(split) @@ -419,7 +457,16 @@ def compute_splits(self, model, partitions, idx, tp_rank, pp_rank, pp_split_rank # Tensor Parallel Splitting split = compute_tp_splits( - param_name, param, partitions, idx, tp_size, pp_size, pp_rank, pp_split_rank, self.megatron_legacy + param_name, + param, + partitions, + idx, + tp_size, + pp_size, + pp_rank, + pp_split_rank, + self.megatron_legacy, + model.cfg, ) splits.append(split) @@ -445,12 +492,13 @@ def compute_splits(self, model, partitions, idx, tp_rank, pp_rank, pp_split_rank param_name, param, partitions, - 0, - tp_size, - pp_size, - pp_rank, - pp_split_rank, - self.megatron_legacy, + global_idx=0, + tp_size=tp_size, + pp_size=pp_size, + pp_rank=pp_rank, + pp_split_rank=pp_split_rank, + megatron_legacy=self.megatron_legacy, + model_cfg=model.cfg, ) splits.insert(self.intermediate_shared_embedding_location, split) break @@ -534,7 +582,7 @@ def merge_partition(model, partitions: Dict[int, List[List[torch.Tensor]]], writ ) # Original TP rank change logic - concated = compute_tp_merge(idx, name, param, partitions_pp) + concated = compute_tp_merge(idx, name, param, partitions_pp, model.cfg) # Update the model parameter with the merged tensor param.data = concated @@ -656,6 +704,7 @@ def split_tp_partition_only(model, partitions, tp_size, write_path=None, megatro pp_rank=0, pp_split_rank=0, megatron_legacy=megatron_legacy, + model_cfg=model.cfg, ) splits.append(split) idx += 1