diff --git a/doc/code-docs/source/parallel.rst b/doc/code-docs/source/parallel.rst index 211efe86..8a2a2455 100644 --- a/doc/code-docs/source/parallel.rst +++ b/doc/code-docs/source/parallel.rst @@ -10,41 +10,112 @@ InternEvo 的并行设置由配置文件中的 ``parallel`` 字段指定,用 .. code-block:: python parallel = dict( - zero1=8, - tensor=1, + zero1=dict(size=8), + tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), - sequence_parallel=False, + weight=dict(size=1, overlap=True, memory_pool=True), ) - zero1:zero 并行策略,分如下三种情况,默认值为 -1 - - 当 ``zero1 <= 0``,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配 - - 当 ``zero1 == 1``,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数 - - 当 ``zero1 > 1`` 且 ``zero1 <= data_parallel_world_size``,则 zero1 进程组是数据并行进程组的子集 + - 当 ``size <= 0``,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配 + - 当 ``size == 1``,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数 + - 当 ``size > 1`` 且 ``zero1 <= data_parallel_size``,则 zero1 进程组是数据并行进程组的子集 + +- tensor:张量并行策略 + + - size:张量并行大小,通常是每个节点的 GPU 数量,默认值为 1 + - mode:张量并行模式,支持['mtp', 'msp', 'fsp', 'isp'],其中, + + - mtp:表示使用 Megatron-LM 的张量并行实现方案,不包含序列并行,mtp 为默认模式 + - msp:表示使用 Megatron-LM 的序列化并行实现方案,序列并行大小 = 张量并行大小 + - fsp:表示使用 flash-attn 实现模式的张量并行与序列并行,序列并行大小 = 张量并行大小 + - isp:InternEvo 系统自研的序列化并行方案,可以与权重并行结合使用,序列并行大小与权重并行大小互相独立 -- tensor:张量并行大小,通常是每个节点的 GPU 数量,默认值为 1 - pipeline:流水线并行策略 - size:流水线并行大小,默认值为 1 - interleaved_overlap:bool 类型,交错式调度时,开启或关闭通信优化,默认值为 False -- sequence_parallel:是否开启序列化并行,默认值为 False +- weight:权重并行策略,只能与 isp 张量并行模式结合使用 + + - size:权重并行大小,默认值为 1 + - overlap:是否开启计算与通信的 overlap,默认值为 False + - memory_pool:是否开启权重显存池,默认值为 False 注意:数据并行大小 = 总的 GPU 数目 / 流水线并行大小 / 张量并行大小 张量并行 ----------------- -InternEvo 的张量并行实现方案基于 `flash attention `_, 主要对 `attention `_ 和 -`linear `_ 这两个模块进行张量并行操作。 +InternEvo 系统 ``v0.3.0`` 版本在张量并行策略上有较大更新,目前的张量并行支持四种模式配置['mtp', 'msp', 'fsp', 'isp'],前三种模式为基于 Megatron-LM 的张量并行和序列化并行的策略实现,最后一种模式为 InternEvo 系统自研,可与权重并行(Weight Parallel)结合使用的一种新的序列化并行方式。接下来详细介绍这几种张量并行模式的区别。 + +- MTP -用户可通过配置文件中的 ``parallel.tensor`` 字段来设置张量并行大小。 +MTP(Megatron-LM Tensor Parallel), 为默认的张量并行模型,引用自 `Megatron-LM Tensor Parallel `_ 并行方案,如下图所示,带有张量并行的 ``Transformer`` 层: -.. figure:: ../../imgs/tensor_parallel.png +.. figure:: ../../imgs/mtp.png :scale: 50% :class: with-border - 张量并行,采用自 `flash-attention `_ + Transformer layer with tensor parallelism. Adopted from `Megatron-LM Tensor Parallel `_ + +MTP 主要对 `attention `_ 和 `linear `_ 这两个模块进行张量并行操作。假设张量并行大小为 ``tp``,输入数据的序列长度为 ``seqlen``,隐藏层大小为 ``hidden size``,则张量并行期间产生的激活值 ``shape`` 为 ``[seqlen, hidden_size/tp]``。 + +MTP 张量并行引入的通信如上图所示,其中 ``f`` 和 ``f̄`` 是共轭的。在前向传递中 ``f`` 是无操作,而在反向传递中进行 ``all-reduce`` 操作。而 ``f̄`` 在前向传递中进行 ``all-reduce`` 操作,在反向传递中是无操作。 + +- MSP + +MSP(Megatron-LM Sequence Parallel),引用自 `Megatron-LM Sequence Parallel `_ 并行方案,如下图所示,带有张量并行和序列化并行的 ``Transformer`` 层: + +.. figure:: ../../imgs/msp.png + :scale: 50% + :class: with-border + + Transformer layer with tensor and sequence parallelism. Adopted from `Megatron-LM Sequence Parallel `_ + +与 MTP 对比,我们可以发现,MSP 主要针对未进行张量并行的模块,如 ``LayerNorm`` 和 ``Dropout`` 等模型进行序列化并行操作。需要注意的是,序列化并行大小与张量并行大小相等,且共用通信组。假设张量并行大小为 ``tp``,输入数据的序列长度为 ``seqlen``,隐藏层大小为 ``hidden size``,则序列化并行期间产生的激活值形状为 ``[seqlen/tp, hidden_size]``,张量并行期间产生的激活值形状为 ``[seqlen, hidden_size/tp]``。 + +MSP与MTP相比,通信原语有所变化,如上图所示 ``g`` 和 ``ḡ`` 是共轭的。在前向传递中 ``g`` 进行 ``all-gather`` 操作,而在反向传递中进行 ``reduce-scatter`` 操作。而 ``ḡ`` 在前向传递中进行 ``reduce-scatter`` 操作,在反向传递中进行 ``all-gather`` 操作。 + +在前向传递中 ``g`` 通信处于序列化并行和张量并行的交接处,进行的是激活值在 ``seqlen`` 维度的 ``all-gather`` 操作,该通信完成后,激活值形状变成完整的 ``[seqlen, hidden_size]``,然后进入张量并行模块范围。``ḡ`` 通信处于张量并行和序列化并行的交接处,需要把 MTP 中的 ``all-reduce`` 通信操作变成 ``reduce-scatter``,才能完成 ``seqlen`` 维度的切分,激活值形状变成 ``[seqlen/tp, hidden_size]``,从而正常进入序列化并行的阶段。而反向传递时,则是同样的道理。 + +- FSP + +FSP(Flash-Attn Sequence Parallel),引用自 `flash attention `_ 的序列化并行实现方案。该实现方案与 MSP 的唯一区别在于,在 ``g`` 进行 ``all-gather`` 通信后,MSP 会存储一份完整的输入数据用于 backward 计算,而 FSP 则只存储 ``seqlen`` 切分后的输入数据,因此在进行 backward 计算时,需要再额外 ``all-gather`` 一次输入。 + +因此,FSP 与 MSP 性能对比的话,会有更小的显存占用,但是由于引入额外的 ``all-gather`` 通信,会导致训练速度 TGS 降低。 + + +- ISP + +ISP(Intern Sequence Parallel),InternEvo 系统自研的灵活可扩展序列化并行方案,支持张量并行与序列化并行解耦,通过计算和通信的overlap提高训练性能,并基于显存池管理降低显存碎片化的可能性,提高显存利用率。 + +以 `configs/7B_isp_sft.py `_ 配置文件为例,将 ``tensor.mode`` 字段设置为 ``isp``,而 ``tensor.size`` 字段代表的是数据 ``seqlen`` 维度切分大小。ISP 算法可与 ``weight parallel`` 结合使用,其中 ``weight.size`` 字段代表的是模型权重切分大小,将 ``weight.overlap`` 字段设置为 ``True`` 即为开启计算与通信的 ``overlap``,可提高训练性能。将 ``weight.memory_pool`` 字段设置为 ``True`` 即为开启显存池管理功能,可一定程度降低 GPU 显存碎片的可能性,提高显存利用率。 + +.. code-block:: python + + parallel = dict( + zero1=dict(size=-1), + tensor=dict(size=2, mode="isp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=4, overlap=True, memory_pool=True), + ) + +如下图所示,带有序列化并行和权重并行的 ``Transformer`` 层: + +.. figure:: ../../imgs/isp.png + :scale: 30% + :class: with-border + +如图所示,ISP 的序列化并行范围覆盖整个 ``Transformer`` 模型层,模型权重并行主要针对 ``Attention`` 和 ``MLP Block`` 的 ``Linear module``。 + +通信原语变化情况为,在前向传递时,每个 ``Linear`` 需要进行模型权重的 ``all-gather`` 通信;在后向传递时,每个 ``Linear`` 在进行后向计算前需要进行模型权重的 ``all-gather`` 通信,在后向计算后,需要进行模型权重的梯度的 ``reduce-scatter`` 通信操作。 + +需要注意的是,与 MSP 和 FSP 相比,在进行 ``attention score`` 计算时,ISP 也有通信原语的一些变化,如 ``Self-Atten`` 前后各增加了一个 ``all-to-all`` 通信操作,用于完成激活值形状的转置,目的是在进行 ``attention score`` 计算时能保持原有的张量并行的模式。 + +关于 ISP 算法更多的设计思路和性能评测,请参考论文 `InternEvo: Efficient Long-sequence Large Language Model Training via Hybrid Parallelism and Redundant Sharding `_。 + 流水线并行 ----------------- @@ -55,7 +126,7 @@ InternEvo 在流水线并行中使用 `1F1B `_ @@ -97,20 +168,6 @@ InternEvo 在流水线并行中使用 `1F1B `_。这个并行策略有助于降低模型的内存消耗,提高了模型在资源受限环境中的可扩展性。 - -如果要启用序列并行, 用户需要设置 ``parallel.sequence_parallel = True``。 - -.. figure:: ../../imgs/sequence_parallel.png - :scale: 50% - :class: with-border - - 序列并行, 采用自 flash-attention - 数据并行 ----------------- diff --git a/doc/imgs/isp.png b/doc/imgs/isp.png new file mode 100644 index 00000000..5e308137 Binary files /dev/null and b/doc/imgs/isp.png differ diff --git a/doc/imgs/msp.png b/doc/imgs/msp.png new file mode 100644 index 00000000..c0384206 Binary files /dev/null and b/doc/imgs/msp.png differ diff --git a/doc/imgs/mtp.png b/doc/imgs/mtp.png new file mode 100644 index 00000000..16f111a0 Binary files /dev/null and b/doc/imgs/mtp.png differ