From 011edcf27e3a5a020d9295e914e0542a8047debf Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Mon, 29 Jan 2024 16:50:33 +0800
Subject: [PATCH] feat(utils/parallel.py): add func is_using_isp

---
 .../solver/optimizer/hybrid_zero_optim.py |  6 ++----
 internlm/solver/optimizer/utils.py        |  5 ++---
 internlm/train/training_internlm.py       | 20 ++++++++-----------
 internlm/train/utils.py                   |  4 ++--
 internlm/utils/model_checkpoint.py        | 17 ++++++++--------
 internlm/utils/parallel.py                | 18 ++++++++---------
 6 files changed, 32 insertions(+), 38 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 44111cb9..d603539b 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -38,6 +38,7 @@
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.parallel import is_using_isp
 from internlm.utils.timeout import llm_timeout

 from .base_optimizer import BaseOptimizer
@@ -85,10 +86,7 @@ def __init__(
             clip_grad_norm = zero_cfg.clip_grad_norm
         self._overlap_sync_grad = zero_cfg.overlap_sync_grad
         self._overlap_sync_param = zero_cfg.overlap_sync_param
-        self.use_isp = (
-            isinstance(gpc.config.parallel["tensor"], dict)
-            and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp"
-        )
+        self.use_isp = is_using_isp()

         super().__init__(optim=optimizer)

diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index ff707a42..ffa06477 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -22,6 +22,7 @@
     is_tensor_data_parallel_parameter,
     is_tensor_expert_data_parallel_parameter,
     is_tensor_zero_parallel_parameter,
+    is_using_isp,
     is_weight_zero_parallel_parameter,
 )

@@ -312,9 +313,7 @@ def compute_norm(
         Total norm of the parameters, need total_norm**(1/norm) before using.
     """

-    weight_parallel_mode = (
-        ParallelMode.WEIGHT if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.TENSOR
-    )
+    weight_parallel_mode = ParallelMode.WEIGHT if is_using_isp() else ParallelMode.TENSOR
     enable_cuda_kernels = gradients[0].device.type == "cuda"
     # Norm parameters.
     norm_type = float(norm_type)
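The two hunks above replace hand-rolled mode checks with a single shared predicate (defined at the end of this patch in internlm/utils/parallel.py). A minimal standalone sketch of its truth table; check_isp and cfg are illustrative stand-ins, not internlm APIs:

def check_isp(cfg) -> bool:
    # The tensor-parallel setting may be a bare int (size-only form) or a
    # dict such as {"size": 4, "mode": "isp"}; only the dict form can select
    # ISP, and the mode falls back to "mtp" when unset.
    return isinstance(cfg, dict) and cfg.get("mode", "mtp") == "isp"

assert check_isp({"size": 4, "mode": "isp"}) is True
assert check_isp({"size": 4}) is False  # mode defaults to "mtp"
assert check_isp(4) is False            # int config can never be ISP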
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 62a9d060..2fe61b7d 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -78,6 +78,7 @@
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
+    is_using_isp,
     is_replica_zero_parallel_parameter,
     is_tensor_data_parallel_parameter,
     is_tensor_expert_data_parallel_parameter,
@@ -105,8 +106,6 @@ def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]):


 def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]):
-    tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp")
-
     def _check_module(module):
         # layer_norm
         if isinstance(module, (RMSNorm, nn.LayerNorm)):
@@ -120,9 +119,9 @@ def _check_module(module):
         # embedding and head
         if isinstance(module, (Embedding1D, ParallelGPT2Embeddings, BaseScaleColumnParallelLinear)):
             for param in module.parameters():
-                if gpc.is_initialized(ParallelMode.TENSOR) and tp_mode == "isp":
+                if gpc.is_initialized(ParallelMode.TENSOR) and is_using_isp():
                     setattr(param, IS_TENSOR_DATA_PARALLEL, True)
-                elif gpc.is_initialized(ParallelMode.TENSOR) and tp_mode != "isp":
+                elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
                     setattr(param, IS_TENSOR_ZERO_PARALLEL, True)

         # for linear module
@@ -131,9 +130,9 @@ def _check_module(module):
             if gpc.is_initialized(ParallelMode.EXPERT_DATA) and is_moe_param(param):
                 # module should be MoE experts' linear
                 setattr(param, IS_TENSOR_EXPERT_DATA_PARALLEL, True)
-            elif not is_moe_param(param) and gpc.is_initialized(ParallelMode.TENSOR) and tp_mode != "isp":
+            elif not is_moe_param(param) and gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
                 setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
-            elif not is_moe_param(param) and gpc.is_initialized(ParallelMode.WEIGHT) and tp_mode == "isp":
+            elif not is_moe_param(param) and gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
                 setattr(param, IS_WEIGHT_ZERO_PARALLEL, True)

     if not isinstance(model, nn.ModuleList):
@@ -208,9 +207,7 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f

     # Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random
     # states in the same dp group are all the same.
-    random_mode = (
-        ParallelMode.WEIGHT_DATA if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.DATA
-    )
+    random_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA
     set_mode(random_mode)

     # if fsdp enabled, wrap the model
@@ -262,9 +259,8 @@ def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]):
     Returns:
         An isp communicator for managing comp/comm overlap and memory pool.
     """
-    if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
-        isp_communicator = None
-    else:
+    isp_communicator = None
+    if is_using_isp():
         isp_communicator = ISPCommunicator(
             model,
             ISPCommModelConfig(
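set_parallel_attr_for_param_groups routes each parameter into a parallel/ZeRO bucket by stamping boolean attributes on it. A condensed sketch of that tagging pattern for the embedding/head case, assuming torch is available; the flag strings are hypothetical stand-ins for the IS_* constants internlm defines:

import torch.nn as nn

IS_TENSOR_DATA_PARALLEL = "is_tensor_data_parallel"  # assumed flag name
IS_TENSOR_ZERO_PARALLEL = "is_tensor_zero_parallel"  # assumed flag name

def tag_embedding_params(module: nn.Module, using_isp: bool) -> None:
    # Under ISP the embedding/head weights are synchronized over the
    # data-parallel group; otherwise they join the tensor-zero bucket.
    flag = IS_TENSOR_DATA_PARALLEL if using_isp else IS_TENSOR_ZERO_PARALLEL
    for param in module.parameters():
        setattr(param, flag, True)

emb = nn.Embedding(8, 4)
tag_embedding_params(emb, using_isp=True)
assert getattr(emb.weight, IS_TENSOR_DATA_PARALLEL)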
""" - if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp": - isp_communicator = None - else: + isp_communicator = None + if is_using_isp(): isp_communicator = ISPCommunicator( model, ISPCommModelConfig( diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 2f57f11a..4980255a 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -5,7 +5,7 @@ from internlm.core.context.parallel_context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.model.utils import is_moe_param -from internlm.utils.parallel import is_tensor_data_parallel_parameter +from internlm.utils.parallel import is_tensor_data_parallel_parameter, is_using_isp def split_params_into_different_groups_for_optimizer( @@ -39,7 +39,7 @@ def split_params_into_different_groups_for_optimizer( # create new groups for IS_TENSOR_DATA_PARALLEL parameter group new_groups = {} - if isinstance(gpc.config.parallel["tensor"], dict) and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp": + if is_using_isp(): new_groups["embed_head"] = {"name": "embed_head", "params": [], "optimizer_mode": ParallelMode.DATA} # create new groups for fp32 parameter group new_groups["fp32"] = {"name": "fp32", "params": [], "optimizer_mode": ParallelMode.ZERO1} diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 962f6415..9ace19ab 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -30,6 +30,7 @@ from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.parallel import is_using_isp from internlm.utils.storage_manager import ( get_fns, get_storage_manager, @@ -325,7 +326,7 @@ def save_model_checkpoint(folder, model): # even if pp is not considered, it will definitely not be written on the same machine. 
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 962f6415..9ace19ab 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -30,6 +30,7 @@
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.parallel import is_using_isp
 from internlm.utils.storage_manager import (
     get_fns,
     get_storage_manager,
@@ -325,7 +326,7 @@ def save_model_checkpoint(folder, model):
     # even if pp is not considered, it will definitely not be written on the same machine.

     # for tensor parallel mode with isp
-    if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
+    if is_using_isp():
         if wdp_rank == 0 or dp_rank == 0:
             fn = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
             fp = os.path.join(folder, fn)
@@ -564,7 +565,7 @@ def load_model_checkpoint(folder, model):
     for fn in fns:
         if fn.startswith("model_t") and not fn.endswith(".md5"):
             segements = os.path.splitext(fn)[0].split("_")
-            if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
+            if is_using_isp():
                 max_pp = max(max_pp, int(segements[-1][2:]))
                 max_wp = max(max_wp, int(segements[-2][2:]))
                 max_tp = max(max_tp, int(segements[-3][2:]))
@@ -590,7 +591,7 @@ def load_model_checkpoint(folder, model):
             dp_size == max_zo + 1
         ), f"The weights are saved for {max_zo+1} FSDP shards, while current has {dp_size} FSDP shards"

-    if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
+    if is_using_isp():
         should_load_name = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
     elif gpc.config.parallel.zero1.fsdp:
         should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
@@ -702,7 +703,7 @@ def save_optimizer_checkpoint(optim, state_path):

     states = optim.state_dict()
     if isinstance(optim, HybridZeroOptimizer):
-        if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
+        if is_using_isp():
             fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
             llm_save(os.path.join(state_path, fp), states)
         else:
@@ -752,7 +753,7 @@ def load_optimizer_checkpoint(folder, optim):
     max_tp, max_wp, max_pp, max_zero, max_dp = 0, 0, 0, 0, 0
     for fn in fns:
         if fn.startswith("optimizer_") and not fn.endswith(".md5"):
-            if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
+            if is_using_isp():
                 _, tp, wp, pp, dp = os.path.splitext(fn)[0].split("_")
                 max_dp = max(max_dp, int(dp[2:]))
                 max_tp = max(max_tp, int(tp[2:]))
@@ -770,12 +771,12 @@ def load_optimizer_checkpoint(folder, optim):
     pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
     dp_size = gpc.get_world_size(ParallelMode.DATA)

-    if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
+    if is_using_isp():
         assert dp_size == max_dp + 1, (
             f"The optimizer states are saved for {max_dp+1} data parallelism, "
             f"while current has {dp_size} data parallelism"
         )
-    if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
+    if not is_using_isp():
         assert zero_size == max_zero + 1, (
             f"The optimizer states are saved for {max_zero+1} zero parallel, "
             f"while current has {zero_size} zero broadcast range."
@@ -795,7 +796,7 @@ def load_optimizer_checkpoint(folder, optim):
     wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
     pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
     dp_rank = gpc.get_local_rank(ParallelMode.DATA)
-    if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
+    if is_using_isp():
         fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
     else:
         fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
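Checkpoint filenames encode which ranks own a shard, and ISP adds a weight-parallel (wp) coordinate while dropping the zero coordinate, which is why the loader above parses two different layouts. A small sketch of the ISP-side naming round-trip; the helper names are illustrative:

def isp_model_ckpt_name(tp: int, wp: int, pp: int) -> str:
    # Mirrors the f-string used by save_model_checkpoint under ISP.
    return f"model_tp{tp}_wp{wp}_pp{pp}.pt"

def parse_ranks(fn: str) -> dict:
    # Mirrors the loader's split("_") walk over segments like "tp0", "wp1".
    segments = fn.rsplit(".", 1)[0].split("_")[1:]
    return {seg[:2]: int(seg[2:]) for seg in segments}

assert parse_ranks(isp_model_ckpt_name(0, 1, 2)) == {"tp": 0, "wp": 1, "pp": 2}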
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 5a491d33..76cd8d95 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -20,6 +20,10 @@
 RMSNorm = try_import_RMSNorm()


+def is_using_isp():
+    return isinstance(gpc.config.parallel["tensor"], dict) and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp"
+
+
 def is_replica_zero_parallel_parameter(p):
     return hasattr(p, IS_REPLICA_ZERO_PARALLEL) and getattr(p, IS_REPLICA_ZERO_PARALLEL)

@@ -27,7 +31,7 @@ def is_replica_zero_parallel_parameter(p):
 def is_tensor_data_parallel_parameter(p):
     return (
         gpc.is_initialized(ParallelMode.TENSOR)
-        and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp"
+        and is_using_isp()
         and hasattr(p, IS_TENSOR_DATA_PARALLEL)
         and getattr(p, IS_TENSOR_DATA_PARALLEL)
     )
@@ -36,7 +40,7 @@ def is_tensor_data_parallel_parameter(p):
 def is_tensor_zero_parallel_parameter(p):
     return (
         gpc.is_initialized(ParallelMode.TENSOR)
-        and gpc.config.parallel["tensor"].get("mode", "mtp") != "isp"
+        and not is_using_isp()
         and hasattr(p, IS_TENSOR_ZERO_PARALLEL)
         and getattr(p, IS_TENSOR_ZERO_PARALLEL)
     )
@@ -45,7 +49,7 @@ def is_tensor_zero_parallel_parameter(p):
 def is_weight_zero_parallel_parameter(p):
     return (
         gpc.is_initialized(ParallelMode.WEIGHT)
-        and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp"
+        and is_using_isp()
         and hasattr(p, IS_WEIGHT_ZERO_PARALLEL)
         and getattr(p, IS_WEIGHT_ZERO_PARALLEL)
     )
@@ -67,9 +71,7 @@ def sync_model_param(model):
     """

     sync_moe_param = gpc.is_using_parallel_mode(ParallelMode.EXPERT_DATA)
-    sync_parallel_mode = (
-        ParallelMode.WEIGHT_DATA if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.DATA
-    )
+    sync_parallel_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA
     for param in model.parameters():
         if sync_moe_param and getattr(param, "is_expert", False):
             ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA)
@@ -90,9 +92,7 @@ def sync_model_replica_param_group(model):
         model (:class:`torch.nn.Module`): A PyTorch model on whose parameters you check the consistency.
     """

-    parallel_mode = (
-        ParallelMode.WEIGHT if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.TENSOR
-    )
+    parallel_mode = ParallelMode.WEIGHT if is_using_isp() else ParallelMode.TENSOR
     if gpc.is_using_parallel_mode(parallel_mode):
         for param in model.parameters():
             if is_replica_zero_parallel_parameter(param):
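With the predicate centralized here, every "which communication group?" decision in the patch reduces to one boolean. A closing sketch of the three selections rewritten above, using a stand-in enum rather than internlm's ParallelMode:

from enum import Enum

class Mode(Enum):
    TENSOR = "tensor"
    WEIGHT = "weight"
    DATA = "data"
    WEIGHT_DATA = "weight_data"

def group_choices(using_isp: bool) -> dict:
    # Grad-norm reduction, RNG-state sync, and replica-consistency checks
    # all flip from tensor/data groups to weight/weight-data groups under ISP.
    return {
        "grad_norm": Mode.WEIGHT if using_isp else Mode.TENSOR,
        "rng_sync": Mode.WEIGHT_DATA if using_isp else Mode.DATA,
        "replica_check": Mode.WEIGHT if using_isp else Mode.TENSOR,
    }

assert group_choices(True)["rng_sync"] is Mode.WEIGHT_DATA
assert group_choices(False)["grad_norm"] is Mode.TENSOR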