feat(pipeline): Zero Bubble V Shape Memory Efficient Edition #357

Merged · 11 commits · Oct 29, 2024
3 changes: 2 additions & 1 deletion configs/7B_internlm2.py
@@ -174,14 +174,15 @@
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
3. mode: str, the pipeline parallel mode, should be in ['1f1b', 'zbh1', 'zbv']. The default is 1f1b.
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=2, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
pipeline=dict(size=1, interleaved_overlap=True, mode="1f1b"),
weight=dict(size=2, overlap=True),
)

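Note: for reference, a minimal sketch of what a config enabling the new ZBV schedule might look like, based on the mode option documented above. The pipeline size and num_chunks values are illustrative assumptions, not taken from this PR:

# Illustrative only: pipeline size and num_chunks below are assumptions.
parallel = dict(
    zero1=dict(size=-1),
    tensor=dict(size=2, mode="isp"),
    pipeline=dict(size=4, interleaved_overlap=True, mode="zbv"),
    weight=dict(size=2, overlap=True),
)
model = dict(
    num_chunks=2,  # assumed: ZBV places two model chunks on every pipeline rank
)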
3 changes: 1 addition & 2 deletions configs/7B_sft.py
@@ -186,15 +186,14 @@
1. size: int, the size of pipeline parallel (Default is 1F1B).
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
3. zero_bubble: bool, enable/disable zero bubble pipeline parallelism (ZB-H1), defaults to False.
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True, zero_bubble=False),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True),
)

8 changes: 4 additions & 4 deletions doc/code-docs/locales/en/LC_MESSAGES/parallel.po
@@ -563,8 +563,8 @@ msgstr ""
msgid "返回类型"
msgstr "Return type"

#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler.forward_backward_step:19
#: internlm.core.scheduler.pipeline_scheduler.PipelineScheduler.forward_backward_step:19
#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler.forward_backward_step:19
#: internlm.core.scheduler.pipeline_scheduler_1f1b.PipelineScheduler.forward_backward_step:19
#: of
msgid "Tuple[:class:`torch.Tensor`]"
msgstr ""
@@ -579,11 +579,11 @@ msgstr ""
"To use interleaved pipeline scheduler, users need to set "
"``model.num_chunks > 1`` in the config file."

#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler:1 of
#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler:1 of
msgid "Interleaved Pipeline Scheduler."
msgstr ""

#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler.forward_backward_step:1
#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler.forward_backward_step:1
#: of
msgid ""
"Run interleaved 1F1B schedule (model split into model chunks), with "
4 changes: 2 additions & 2 deletions doc/code-docs/source/parallel.rst
@@ -137,14 +137,14 @@ InternEvo 在流水线并行中使用 `1F1B <https://arxiv.org/pdf/2104.04473.pd
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
如果要使用非交错式调度, 需要设置 ``model.num_chunks = 1`` 。

.. autoclass:: internlm.core.scheduler.pipeline_scheduler.PipelineScheduler
.. autoclass:: internlm.core.scheduler.pipeline_scheduler_1f1b.PipelineScheduler
:members:

交错式流水线调度
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
如果要使用交错式调度, 需要设置 ``model.num_chunks > 1`` 。

.. autoclass:: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler
.. autoclass:: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler
:members:

值得注意的是,在使用交错式流水线调度器时可启用通信优化功能,即在 1F1B 阶段启用异步通信,以充分利用上行/下行带宽并实现通信与计算重叠。
3 changes: 2 additions & 1 deletion doc/en/structure.md
@@ -12,7 +12,8 @@ The system code file structure is shown below:
│ │ │ └── process_group_initializer.py
│ │ ├── scheduler # Scheduling module, which manages schedulers for parallel training, including non-pipeline and pipeline parallel schedulers
│ │ │ ├── no_pipeline_scheduler.py
│ │ │ └── pipeline_scheduler.py
│ │ │ ├── pipeline_scheduler_1f1b.py
│ │ │ └── pipeline_scheduler_zb.py
│ │ ├── engine.py # Responsible for managing the training and evaluation process of the model
│ │ └── trainer.py # Responsible for managing the training engine and scheduler
│ ├── data # Data module, responsible for managing dataset generation and processing
3 changes: 2 additions & 1 deletion doc/structure.md
@@ -12,7 +12,8 @@
│ │ │ └── process_group_initializer.py
│ │ ├── scheduler # 调度模块,管理并行训练的调度器,包括非流水线并行调度器和流水线并行调度器
│ │ │ ├── no_pipeline_scheduler.py
│ │ │ └── pipeline_scheduler.py
│ │ │ ├── pipeline_scheduler_1f1b.py
│ │ │ └── pipeline_scheduler_zb.py
│ │ ├── engine.py # 负责管理模型的训练和评估过程
│ │ └── trainer.py # 负责管理训练引擎和调度器
│ ├── data # 数据模块,负责管理数据集生成和处理
18 changes: 15 additions & 3 deletions internlm/core/context/parallel_context.py
@@ -163,6 +163,7 @@ def __init__(self):
self.virtual_pipeline_parallel_rank = None
self._expert_parallel_group_names = []
self.is_evaluating = False
self.v_shape = False

@property
def config(self):
@@ -292,8 +293,13 @@ def is_rank_for_log(self):
and self.is_first_rank(ParallelMode.WEIGHT)
and self.is_first_rank(ParallelMode.DATA)
and self.is_first_rank(ParallelMode.WEIGHT_DATA)
and self.is_last_rank(ParallelMode.PIPELINE)
)

if not self.v_shape:
is_log_rank = is_log_rank and self.is_last_rank(ParallelMode.PIPELINE)
else:
is_log_rank = is_log_rank and self.is_first_rank(ParallelMode.PIPELINE)

return is_log_rank

def is_last_rank(self, parallel_mode: ParallelMode):
@@ -327,11 +333,17 @@ def is_pipeline_last_stage(self, ignore_virtual=False):
and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1
):
return False
return self.is_last_rank(ParallelMode.PIPELINE)
if not self.v_shape:
return self.is_last_rank(ParallelMode.PIPELINE)
else:
return self.is_first_rank(ParallelMode.PIPELINE)

def is_no_pp_or_last_stage(self):
# NOTICE!!!, this will ignore the virtual stage
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_last_rank(ParallelMode.PIPELINE)
if not self.v_shape:
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_last_rank(ParallelMode.PIPELINE)
else:
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_first_rank(ParallelMode.PIPELINE)

def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`.
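Note: with the V shape, the final model chunk folds back onto pipeline rank 0, so the "last stage" checks above (the rank that holds the loss and does logging) flip from the last to the first pipeline rank. A standalone sketch of that mapping, with the function name and signature assumed for illustration:

# Standalone sketch (assumed names): which rank holds the loss-computing stage.
def is_pipeline_last_stage_sketch(pipeline_rank: int, pipeline_size: int, v_shape: bool) -> bool:
    if v_shape:
        # V shape: the last model chunk lives on the first pipeline rank
        return pipeline_rank == 0
    return pipeline_rank == pipeline_size - 1

assert is_pipeline_last_stage_sketch(0, 4, v_shape=True)
assert is_pipeline_last_stage_sketch(3, 4, v_shape=False)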
6 changes: 6 additions & 0 deletions internlm/core/parallel/comm/isp.py
@@ -692,6 +692,12 @@ def after_backward(self, scheduler, inputs_grad) -> None: # pylint: disable=W06
if self._isp_communicator and self._isp_communicator.overlap:
self._zero_optim.accumulate_left_grads_after_backward()

if (
getattr(gpc.config.parallel["pipeline"], "mode", "1F1B").upper() in ["ZBV", "ZBH1"]
and not self._zero_optim.skip_grad_reduce
):
self._zero_optim.reduce_left_grads_after_backward()

def post_helper_func(self, scheduler, outputs, label) -> None: # pylint: disable=W0613
pass

14 changes: 10 additions & 4 deletions internlm/core/parallel/shard.py
@@ -186,10 +186,16 @@ def partition_uniform(num_items: int, pipeline_parallel_size: int, num_chunks: i
if chunk_size == 0:
raise ValueError("Some nodes in Pipeline have no requests")

for p in range(pipeline_parallel_size):
st = base_idx
base_idx += chunk_size + (p >= left)
parts[p].append((st, base_idx))
if getattr(gpc.config.parallel["pipeline"], "mode", "1F1B").upper() == "ZBV" and idx == 1:
for p in range(pipeline_parallel_size - 1, -1, -1):
st = base_idx
base_idx += chunk_size + ((pipeline_parallel_size - p - 1) >= left)
parts[p].append((st, base_idx))
else:
for p in range(pipeline_parallel_size):
st = base_idx
base_idx += chunk_size + (p >= left)
parts[p].append((st, base_idx))

indexes = []
for _parts in parts:
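Note: a standalone sketch of the V-shape partition introduced above — chunk 0 is laid out in rank order and chunk 1 in reverse rank order, so each rank holds a front layer block plus the matching back block. The helper name and the example inputs are assumptions for illustration:

# Standalone sketch (assumed helper name): V-shape variant of partition_uniform.
def partition_uniform_zbv_sketch(num_items: int, pp_size: int, num_chunks: int = 2):
    assert num_items % num_chunks == 0
    parts = [[] for _ in range(pp_size)]
    partition_items = num_items // num_chunks
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        chunk_size = partition_items // pp_size
        left = pp_size - partition_items % pp_size
        # chunk 1 is assigned to ranks in reverse order -> the "V" shape
        ranks = range(pp_size - 1, -1, -1) if idx == 1 else range(pp_size)
        for i, p in enumerate(ranks):
            st = base_idx
            base_idx += chunk_size + (i >= left)
            parts[p].append((st, base_idx))
    return parts

# 8 layers on 4 pipeline ranks ->
# [[(0, 1), (7, 8)], [(1, 2), (6, 7)], [(2, 3), (5, 6)], [(3, 4), (4, 5)]]
print(partition_uniform_zbv_sketch(8, 4))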
7 changes: 4 additions & 3 deletions internlm/core/scheduler/__init__.py
@@ -1,9 +1,9 @@
from .base_scheduler import BaseScheduler
from .no_pipeline_scheduler import NonPipelineScheduler
from .pipeline_scheduler import (
InterleavedPipelineScheduler,
PipelineScheduler,
from .pipeline_scheduler_1f1b import InterleavedPipelineScheduler, PipelineScheduler
from .pipeline_scheduler_zb import (
ZeroBubblePipelineScheduler,
ZeroBubblePipelineVShapeScheduler,
)

__all__ = [
@@ -12,4 +12,5 @@
"InterleavedPipelineScheduler",
"PipelineScheduler",
"ZeroBubblePipelineScheduler",
"ZeroBubblePipelineVShapeScheduler",
]
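Note: a hypothetical dispatch sketch (not the PR's actual wiring) showing how the newly exported ZeroBubblePipelineVShapeScheduler might be selected from the configured pipeline mode:

# Hypothetical helper (assumed; not part of this PR): map the configured
# pipeline mode onto the scheduler classes exported above.
from internlm.core.scheduler import (
    PipelineScheduler,
    ZeroBubblePipelineScheduler,
    ZeroBubblePipelineVShapeScheduler,
)

_SCHEDULERS = {
    "1F1B": PipelineScheduler,  # InterleavedPipelineScheduler when num_chunks > 1
    "ZBH1": ZeroBubblePipelineScheduler,
    "ZBV": ZeroBubblePipelineVShapeScheduler,
}

def pick_scheduler_cls(mode: str = "1f1b"):
    return _SCHEDULERS[mode.upper()]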
6 changes: 2 additions & 4 deletions internlm/core/scheduler/comm/__init__.py
@@ -1,13 +1,12 @@
from .p2p import (
AsynCommunicator,
fused_send_recv_tensor,
recv_backward,
recv_forward,
send_backward,
send_backward_and_recv_next_backward_async,
send_backward_recv_backward,
send_backward_recv_forward,
send_forward,
send_forward_and_recv_next_forward_async,
send_forward_backward_recv_forward_backward,
send_forward_recv_backward,
send_forward_recv_forward,
@@ -26,7 +25,6 @@
"recv_forward",
"send_obj_meta",
"recv_obj_meta",
"send_backward_and_recv_next_backward_async",
"send_forward_and_recv_next_forward_async",
"AsynCommunicator",
"fused_send_recv_tensor",
]