feat(utils/parallel.py): add func is_using_isp
huangting4201 committed Jan 29, 2024
1 parent e74f2dd · commit 011edcf
Showing 6 changed files with 32 additions and 38 deletions.
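The commit replaces the inline check on gpc.config.parallel["tensor"] — previously duplicated across the optimizer, training, checkpointing, and parallel-utility modules — with a single helper, is_using_isp(), defined in internlm/utils/parallel.py. As a standalone sketch of the predicate (a plain dict stands in for the real gpc.config object, purely for illustration; the config values below are hypothetical):

def is_using_isp_sketch(parallel_cfg):
    # Same logic as the helper added in internlm/utils/parallel.py:
    # the "tensor" entry must be a dict whose "mode" (default "mtp") equals "isp".
    tensor_cfg = parallel_cfg["tensor"]
    return isinstance(tensor_cfg, dict) and tensor_cfg.get("mode", "mtp") == "isp"

print(is_using_isp_sketch({"tensor": {"size": 4, "mode": "isp"}}))  # True
print(is_using_isp_sketch({"tensor": {"size": 4}}))                 # False: mode defaults to "mtp"
print(is_using_isp_sketch({"tensor": 4}))                           # False: legacy non-dict config
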
6 changes: 2 additions & 4 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -38,6 +38,7 @@
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import is_using_isp
from internlm.utils.timeout import llm_timeout

from .base_optimizer import BaseOptimizer
@@ -85,10 +86,7 @@ def __init__(
clip_grad_norm = zero_cfg.clip_grad_norm
self._overlap_sync_grad = zero_cfg.overlap_sync_grad
self._overlap_sync_param = zero_cfg.overlap_sync_param
self.use_isp = (
isinstance(gpc.config.parallel["tensor"], dict)
and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp"
)
self.use_isp = is_using_isp()

super().__init__(optim=optimizer)

5 changes: 2 additions & 3 deletions internlm/solver/optimizer/utils.py
@@ -22,6 +22,7 @@
is_tensor_data_parallel_parameter,
is_tensor_expert_data_parallel_parameter,
is_tensor_zero_parallel_parameter,
is_using_isp,
is_weight_zero_parallel_parameter,
)

@@ -312,9 +313,7 @@ def compute_norm(
Total norm of the parameters, need total_norm**(1/norm) before using.
"""

weight_parallel_mode = (
ParallelMode.WEIGHT if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.TENSOR
)
weight_parallel_mode = ParallelMode.WEIGHT if is_using_isp() else ParallelMode.TENSOR
enable_cuda_kernels = gradients[0].device.type == "cuda"
# Norm parameters.
norm_type = float(norm_type)
20 changes: 8 additions & 12 deletions internlm/train/training_internlm.py
@@ -78,6 +78,7 @@
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import (
is_using_isp,
is_replica_zero_parallel_parameter,
is_tensor_data_parallel_parameter,
is_tensor_expert_data_parallel_parameter,
@@ -105,8 +106,6 @@ def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]):


def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]):
tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp")

def _check_module(module):
# layer_norm
if isinstance(module, (RMSNorm, nn.LayerNorm)):
@@ -120,9 +119,9 @@ def _check_module(module):
# embedding and head
if isinstance(module, (Embedding1D, ParallelGPT2Embeddings, BaseScaleColumnParallelLinear)):
for param in module.parameters():
if gpc.is_initialized(ParallelMode.TENSOR) and tp_mode == "isp":
if gpc.is_initialized(ParallelMode.TENSOR) and is_using_isp():
setattr(param, IS_TENSOR_DATA_PARALLEL, True)
elif gpc.is_initialized(ParallelMode.TENSOR) and tp_mode != "isp":
elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
setattr(param, IS_TENSOR_ZERO_PARALLEL, True)

# for linear module
@@ -131,9 +130,9 @@ def _check_module(module):
if gpc.is_initialized(ParallelMode.EXPERT_DATA) and is_moe_param(param):
# module should be MoE experts's linear
setattr(param, IS_TENSOR_EXPERT_DATA_PARALLEL, True)
elif not is_moe_param(param) and gpc.is_initialized(ParallelMode.TENSOR) and tp_mode != "isp":
elif not is_moe_param(param) and gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
elif not is_moe_param(param) and gpc.is_initialized(ParallelMode.WEIGHT) and tp_mode == "isp":
elif not is_moe_param(param) and gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
setattr(param, IS_WEIGHT_ZERO_PARALLEL, True)

if not isinstance(model, nn.ModuleList):
@@ -208,9 +207,7 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f

# Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random
# state in the same dp group are all the same.
random_mode = (
ParallelMode.WEIGHT_DATA if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.DATA
)
random_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA
set_mode(random_mode)

# if fsdp enabled, wrap the model
@@ -262,9 +259,8 @@ def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]):
Returns:
An isp communicator for managing comp/comm overlap and memory pool.
"""
if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
isp_communicator = None
else:
isp_communicator = None
if is_using_isp():
isp_communicator = ISPCommunicator(
model,
ISPCommModelConfig(
4 changes: 2 additions & 2 deletions internlm/train/utils.py
@@ -5,7 +5,7 @@
from internlm.core.context.parallel_context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.model.utils import is_moe_param
from internlm.utils.parallel import is_tensor_data_parallel_parameter
from internlm.utils.parallel import is_tensor_data_parallel_parameter, is_using_isp


def split_params_into_different_groups_for_optimizer(
@@ -39,7 +39,7 @@ def split_params_into_different_groups_for_optimizer(

# create new groups for IS_TENSOR_DATA_PARALLEL parameter group
new_groups = {}
if isinstance(gpc.config.parallel["tensor"], dict) and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
if is_using_isp():
new_groups["embed_head"] = {"name": "embed_head", "params": [], "optimizer_mode": ParallelMode.DATA}
# create new groups for fp32 parameter group
new_groups["fp32"] = {"name": "fp32", "params": [], "optimizer_mode": ParallelMode.ZERO1}
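With the new helper, the grouping logic above reads as follows: when ISP is in use, embedding and head parameters get their own "embed_head" group synchronized over the data-parallel group, alongside the existing "fp32" group. A sketch of the resulting group skeletons, built only from the fields visible in the diff (the real code appends parameters to the "params" lists afterwards):

from internlm.core.context.parallel_context import ParallelMode

new_groups = {
    "embed_head": {"name": "embed_head", "params": [], "optimizer_mode": ParallelMode.DATA},  # created only if is_using_isp()
    "fp32": {"name": "fp32", "params": [], "optimizer_mode": ParallelMode.ZERO1},
}
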
17 changes: 9 additions & 8 deletions internlm/utils/model_checkpoint.py
@@ -30,6 +30,7 @@
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import is_using_isp
from internlm.utils.storage_manager import (
get_fns,
get_storage_manager,
@@ -325,7 +326,7 @@ def save_model_checkpoint(folder, model):
# even if pp is not considered, it will definitely not be written on the same machine.

# for tensor parallel mode with isp
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
if is_using_isp():
if wdp_rank == 0 or dp_rank == 0:
fn = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
fp = os.path.join(folder, fn)
@@ -564,7 +565,7 @@ def load_model_checkpoint(folder, model):
for fn in fns:
if fn.startswith("model_t") and not fn.endswith(".md5"):
segements = os.path.splitext(fn)[0].split("_")
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
if is_using_isp():
max_pp = max(max_pp, int(segements[-1][2:]))
max_wp = max(max_wp, int(segements[-2][2:]))
max_tp = max(max_tp, int(segements[-3][2:]))
@@ -590,7 +591,7 @@ def load_model_checkpoint(folder, model):
dp_size == max_zo + 1
), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards"

if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
if is_using_isp():
should_load_name = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
elif gpc.config.parallel.zero1.fsdp:
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
@@ -702,7 +703,7 @@ def save_optimizer_checkpoint(optim, state_path):

states = optim.state_dict()
if isinstance(optim, HybridZeroOptimizer):
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
if is_using_isp():
fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
llm_save(os.path.join(state_path, fp), states)
else:
@@ -752,7 +753,7 @@ def load_optimizer_checkpoint(folder, optim):
max_tp, max_wp, max_pp, max_zero, max_dp = 0, 0, 0, 0, 0
for fn in fns:
if fn.startswith("optimizer_") and not fn.endswith(".md5"):
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
if is_using_isp():
_, tp, wp, pp, dp = os.path.splitext(fn)[0].split("_")
max_dp = max(max_dp, int(dp[2:]))
max_tp = max(max_tp, int(tp[2:]))
@@ -770,12 +771,12 @@ def load_optimizer_checkpoint(folder, optim):
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
dp_size = gpc.get_world_size(ParallelMode.DATA)

if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
if is_using_isp():
assert dp_size == max_dp + 1, (
f"The optimizer states are save for {max_dp+1} data parallelism, "
f"while current has {dp_size} data parallelism"
)
if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
if not is_using_isp():
assert zero_size == max_zero + 1, (
f"The optimizer states are save for {max_zero+1} zero parallel, "
f"while current has {zero_size} zero broadcast range."
@@ -795,7 +796,7 @@ def load_optimizer_checkpoint(folder, optim):
wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
dp_rank = gpc.get_local_rank(ParallelMode.DATA)
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
if is_using_isp():
fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
else:
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
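For context on the checkpoint changes above: under ISP the shard filenames carry a weight-parallel rank, e.g. model_tp{tp}_wp{wp}_pp{pp}.pt and optimizer_tp{tp}_wp{wp}_pp{pp}_dp{dp}.pt, whereas the non-ISP optimizer path keeps optimizer_tp{tp}_pp{pp}_zo{zo}.pt. The loaders recover ranks by splitting the filename stem on "_" and stripping each field's two-letter prefix; a minimal sketch of that idiom (the helper name is hypothetical):

import os

def parse_rank_fields(filename):
    # e.g. "model_tp0_wp1_pp2.pt" -> {"tp": 0, "wp": 1, "pp": 2}
    fields = os.path.splitext(filename)[0].split("_")[1:]
    return {field[:2]: int(field[2:]) for field in fields}

print(parse_rank_fields("model_tp0_wp1_pp2.pt"))          # {'tp': 0, 'wp': 1, 'pp': 2}
print(parse_rank_fields("optimizer_tp0_wp1_pp2_dp3.pt"))  # {'tp': 0, 'wp': 1, 'pp': 2, 'dp': 3}
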
18 changes: 9 additions & 9 deletions internlm/utils/parallel.py
@@ -20,14 +20,18 @@
RMSNorm = try_import_RMSNorm()


def is_using_isp():
return isinstance(gpc.config.parallel["tensor"], dict) and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp"


def is_replica_zero_parallel_parameter(p):
return hasattr(p, IS_REPLICA_ZERO_PARALLEL) and getattr(p, IS_REPLICA_ZERO_PARALLEL)


def is_tensor_data_parallel_parameter(p):
return (
gpc.is_initialized(ParallelMode.TENSOR)
and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp"
and is_using_isp()
and hasattr(p, IS_TENSOR_DATA_PARALLEL)
and getattr(p, IS_TENSOR_DATA_PARALLEL)
)
@@ -36,7 +40,7 @@ def is_tensor_data_parallel_parameter(p):
def is_tensor_zero_parallel_parameter(p):
return (
gpc.is_initialized(ParallelMode.TENSOR)
and gpc.config.parallel["tensor"].get("mode", "mtp") != "isp"
and not is_using_isp()
and hasattr(p, IS_TENSOR_ZERO_PARALLEL)
and getattr(p, IS_TENSOR_ZERO_PARALLEL)
)
@@ -45,7 +49,7 @@ def is_tensor_zero_parallel_parameter(p):
def is_weight_zero_parallel_parameter(p):
return (
gpc.is_initialized(ParallelMode.WEIGHT)
and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp"
and is_using_isp()
and hasattr(p, IS_WEIGHT_ZERO_PARALLEL)
and getattr(p, IS_WEIGHT_ZERO_PARALLEL)
)
@@ -67,9 +71,7 @@ def sync_model_param(model):
"""

sync_moe_param = gpc.is_using_parallel_mode(ParallelMode.EXPERT_DATA)
sync_parallel_mode = (
ParallelMode.WEIGHT_DATA if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.DATA
)
sync_parallel_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA
for param in model.parameters():
if sync_moe_param and getattr(param, "is_expert", False):
ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA)
@@ -90,9 +92,7 @@ def sync_model_replica_param_group(model):
model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
"""

parallel_mode = (
ParallelMode.WEIGHT if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.TENSOR
)
parallel_mode = ParallelMode.WEIGHT if is_using_isp() else ParallelMode.TENSOR
if gpc.is_using_parallel_mode(parallel_mode):
for param in model.parameters():
if is_replica_zero_parallel_parameter(param):
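After this refactor, a call site can pick its parallel mode in one line instead of re-reading the config. A hedged sketch of the pattern used throughout the diff (is_using_isp, ParallelMode, and the import paths are the repository's own; the wrapper function name is illustrative only):

from internlm.core.context.parallel_context import ParallelMode
from internlm.utils.parallel import is_using_isp

def select_weight_parallel_mode():
    # ISP shards weights over the WEIGHT group; otherwise the TENSOR group applies,
    # matching the one-liners in compute_norm and sync_model_replica_param_group.
    return ParallelMode.WEIGHT if is_using_isp() else ParallelMode.TENSOR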
