Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
blankde committed Feb 5, 2024
1 parent 6d3855a commit 074567c
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 28 deletions.
14 changes: 10 additions & 4 deletions configs/7B_MoE4_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,23 @@
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
weight parallel (dict):
1. size: int, the size of weight parallel.
1. size: int, the size of weight parallel for non-moe module.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
expert parallel (dict):
1. size: int, the size of expert parallel, each device would save {num_expert/ep_size} local experts.
expert parallel (dict):
1. size: int, the size of weight parallel for each expert module. the overlap and memory_pool would
inherit from weight parallel setting.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=1, mode="mtp"),
tensor=dict(size=1, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=8, overlap=False, memory_pool=True),
expert=dict(size=4),
expert_weight=dict(size=1, overlap=False, memory_pool=True),
expert_weight=dict(size=1),
)

cudnn_deterministic = False
Expand Down
16 changes: 9 additions & 7 deletions internlm/core/communication/isp.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,11 +451,11 @@ def switch_current_model_chunk(self, chunk_id: int) -> None:
# communication operation interfaces

def all_gather(self, tensor: torch.Tensor, module: nn.Module, is_bias: bool = False):
if dist.get_world_size(self.process_group) <= 1:
if dist.get_world_size(module.process_group) <= 1:
return tensor

if not self.overlap:
result, _ = all_gather_raw(tensor, self.process_group, async_op=False)
result, _ = all_gather_raw(tensor, module.process_group, async_op=False)
elif is_bias:
result = self._bias_global_output[module]
else:
Expand All @@ -470,11 +470,11 @@ def reduce_scatter(
op: dist.ReduceOp,
is_bias: bool = False,
):
if dist.get_world_size(self.process_group) <= 1:
if dist.get_world_size(model.process_group) <= 1:
return tensor, None

if not self.overlap:
result, handle = reduce_scatter_raw(tensor, self.process_group, op=op, async_op=True)
result, handle = reduce_scatter_raw(tensor, model.process_group, op=op, async_op=True)
else:
if is_bias:
assert hasattr(model.bias, "isp_reduce_scatter_name")
Expand All @@ -485,16 +485,18 @@ def reduce_scatter(

self.reduce_scatter_handlers[key] = reduce_scatter_raw(
tensor,
self.process_group,
model.process_group,
op=op,
async_op=True,
memory_pool_allocator=self.memory_pool.allocate_reduce_scatter_memory,
memory_pool_allocator=self.memory_pool.allocate_reduce_scatter_memory
if self.enable_memory_pool
else None,
)

result, handle = (
self._get_constant_zero(
(
tensor.shape[0] // dist.get_world_size(self.process_group),
tensor.shape[0] // dist.get_world_size(model.process_group),
*tensor.shape[1:],
)
),
Expand Down
2 changes: 2 additions & 0 deletions internlm/core/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
IS_TENSOR_DATA_PARALLEL,
IS_TENSOR_EXPERT_DATA_PARALLEL,
IS_TENSOR_ZERO_PARALLEL,
IS_WEIGHT_EXPERT_DATA_PARALLEL,
IS_WEIGHT_ZERO_PARALLEL,
Config,
ParallelContext,
Expand Down Expand Up @@ -36,6 +37,7 @@
"IS_REPLICA_ZERO_PARALLEL",
"IS_WEIGHT_ZERO_PARALLEL",
"IS_TENSOR_EXPERT_DATA_PARALLEL",
"IS_WEIGHT_EXPERT_DATA_PARALLEL",
"global_context",
"ParallelContext",
"ParallelMode",
Expand Down
5 changes: 3 additions & 2 deletions internlm/core/context/parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
IS_TENSOR_ZERO_PARALLEL = "is_tensor_zero_parallel"
IS_WEIGHT_ZERO_PARALLEL = "is_weight_zero_parallel"
IS_TENSOR_EXPERT_DATA_PARALLEL = "is_tensor_expert_data_parallel"
IS_WEIGHT_EXPERT_DATA_PARALLEL = "is_weight_expert_data_parallel"

logger = get_logger(__file__)

Expand Down Expand Up @@ -448,7 +449,7 @@ def check_sanity(self):
eps = self.expert_parallel_size
ewps = self.expert_weight_parallel_size
edps = self.expert_data_parallel_size
if self.config.parallel["tensor"]["mode"] == "isp":
if isinstance(self.config.parallel["tensor"], dict) and self.config.parallel["tensor"]["mode"] == "isp":
assert ws == eps * edps * ewps * pps, (
f"Expected the world size {ws} to be equal to expert parallel "
f"size ({eps}) * expert data parallel size ({edps}) * expert "
Expand Down Expand Up @@ -508,7 +509,7 @@ def init_parallel_groups(self):
if "expert" not in parallel_config:
parallel_config._add_item("expert", dict(size=1))
if "expert_weight" not in parallel_config:
parallel_config._add_item("expert_weight", dict(size=1, overlap=False, memory_pool=False))
parallel_config._add_item("expert_weight", dict(size=1))

# get value from config
self._set_parallel_size_from_config(parallel_config, "weight", "weight_parallel_size")
Expand Down
18 changes: 7 additions & 11 deletions internlm/initialize/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def args_sanity_check():
gpc.config.parallel._add_item("expert", dict(size=1))

if "expert_weight" not in gpc.config.parallel:
gpc.config.parallel._add_item("expert_weight", dict(size=1, overlap=False, memory_pool=False))
gpc.config.parallel._add_item("expert_weight", dict(size=1))

if isinstance(gpc.config.parallel.pipeline, int):
pp = gpc.config.parallel.pipeline
Expand Down Expand Up @@ -353,14 +353,6 @@ def args_sanity_check():
if gpc.config.parallel["tensor"]["mode"] != "isp":
assert gpc.config.parallel["weight"]["size"] <= 1, "weight parallel is only supported with isp"

# set default value for expert weight parallel
if gpc.config.parallel["expert_weight"].get("overlap", None) is None:
gpc.config.parallel["expert_weight"]["overlap"] = False
if gpc.config.parallel["expert_weight"].get("memory_pool", None) is None:
gpc.config.parallel["expert_weight"]["memory_pool"] = False
if gpc.config.parallel["tensor"]["mode"] != "isp":
assert gpc.config.parallel["expert_weight"]["size"] <= 1, "expert weight parallel is only supported with isp"

# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
assert (
Expand Down Expand Up @@ -422,10 +414,14 @@ def args_sanity_check():
gpc.get_world_size(ParallelMode.DATA),
), "moe only support zero1, set zero1=dict(size=-1,...) can fix this"
if gpc.config.parallel["tensor"]["mode"] == "isp":
assert gpc.config.parallel["expert_weight"]["overlap"] is False
assert gpc.config.parallel["weight"]["overlap"] is False
if gpc.config.parallel["tensor"]["mode"] != "isp":
assert (
gpc.config.parallel["expert_weight"]["size"] <= 1
), "expert weight parallel is only supported with isp"
else:
assert (
gpc.config.parallel["expert"]["size"] == 1 and gpc.config.parallel["expert_weight"]["size"] == 1
gpc.config.parallel["expert"]["size"] <= 1 and gpc.config.parallel["expert_weight"]["size"] <= 1
), "expert parallel is only supported in MoE setting"


Expand Down
2 changes: 1 addition & 1 deletion internlm/moe/gshard_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def __init__(
if isinstance(gpc.config.parallel["tensor"], dict)
else "mtp"
)
parallel_mode = ParallelMode.WEIGHT if tp_mode == "isp" else ParallelMode.TENSOR
parallel_mode = ParallelMode.EXPERT_WEIGHT if tp_mode == "isp" else ParallelMode.TENSOR
mlp_cls = get_mlp_cls(tp_mode)
super().__init__(
TopKGate(
Expand Down
4 changes: 4 additions & 0 deletions internlm/solver/optimizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
is_tensor_expert_data_parallel_parameter,
is_tensor_zero_parallel_parameter,
is_using_isp,
is_weight_expert_data_parallel_parameter,
is_weight_zero_parallel_parameter,
)

Expand Down Expand Up @@ -287,6 +288,9 @@ def append_grad(g, p):
elif is_tensor_expert_data_parallel_parameter(p):
# process all ranks for IS_TENSOR_EXPERT_DATA_PARALLEL parameter group
append_grad(g, p)
elif is_weight_expert_data_parallel_parameter(p):
# process all ranks for IS_TENSOR_EXPERT_DATA_PARALLEL parameter group
append_grad(g, p)
elif gpc.get_local_rank(weight_parallel_mode) != 0:
continue
else:
Expand Down
9 changes: 6 additions & 3 deletions internlm/train/training_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
IS_TENSOR_DATA_PARALLEL,
IS_TENSOR_EXPERT_DATA_PARALLEL,
IS_TENSOR_ZERO_PARALLEL,
IS_WEIGHT_EXPERT_DATA_PARALLEL,
IS_WEIGHT_ZERO_PARALLEL,
ParallelMode,
)
Expand Down Expand Up @@ -83,6 +84,7 @@
is_tensor_expert_data_parallel_parameter,
is_tensor_zero_parallel_parameter,
is_using_isp,
is_weight_expert_data_parallel_parameter,
is_weight_zero_parallel_parameter,
set_model_params_layer_name,
sync_model_param,
Expand Down Expand Up @@ -127,9 +129,10 @@ def _check_module(module):
# for linear module
if isinstance(module, (ColumnParallelLinear, RowParallelLinear)):
for param in module.parameters():
if gpc.is_initialized(ParallelMode.EXPERT_DATA) and is_moe_param(param):
# module should be MoE experts's linear
if is_moe_param(param) and gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
setattr(param, IS_TENSOR_EXPERT_DATA_PARALLEL, True)
elif is_moe_param(param) and gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
setattr(param, IS_WEIGHT_EXPERT_DATA_PARALLEL, True)
elif not is_moe_param(param) and gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
elif not is_moe_param(param) and gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
Expand All @@ -153,6 +156,7 @@ def _check_module(module):
or is_tensor_zero_parallel_parameter(param)
or is_weight_zero_parallel_parameter(param)
or is_tensor_expert_data_parallel_parameter(param)
or is_weight_expert_data_parallel_parameter(param)
), f"parameter with name:{name} has no parallel attribution."


Expand Down Expand Up @@ -270,7 +274,6 @@ def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]):
),
gpc.config.parallel.weight.overlap,
gpc.config.parallel.weight.memory_pool,
gpc.get_group(ParallelMode.WEIGHT),
)
# register communicator for isp linear.
ISPLinear.register_communicator(isp_communicator)
Expand Down
9 changes: 9 additions & 0 deletions internlm/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
IS_TENSOR_DATA_PARALLEL,
IS_TENSOR_EXPERT_DATA_PARALLEL,
IS_TENSOR_ZERO_PARALLEL,
IS_WEIGHT_EXPERT_DATA_PARALLEL,
IS_WEIGHT_ZERO_PARALLEL,
ParallelMode,
)
Expand Down Expand Up @@ -63,6 +64,14 @@ def is_tensor_expert_data_parallel_parameter(p):
)


def is_weight_expert_data_parallel_parameter(p):
return (
gpc.is_initialized(ParallelMode.TENSOR)
and hasattr(p, IS_WEIGHT_EXPERT_DATA_PARALLEL)
and getattr(p, IS_WEIGHT_EXPERT_DATA_PARALLEL)
)


def sync_model_param(model):
r"""Make sure data parameters are consistent during Data Parallel Mode.
Expand Down

0 comments on commit 074567c

Please sign in to comment.