Commit 059c4d4

flybird11111, ver217, digger-yu, binmakeswell, yuanheng-zhao authored and committed
[shardformer] refactor embedding resize (hpcaitech#5603)
* [branch rebase] rebase main to Feature/resize_embedding (hpcaitech#5554) * fix * [release] update version (hpcaitech#5411) * [hotfix] fix typo s/keywrods/keywords etc. (hpcaitech#5429) * [devops] fix compatibility (hpcaitech#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (hpcaitech#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (hpcaitech#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (hpcaitech#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (hpcaitech#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [CI] run pre-commit (hpcaitech#5577) * fix * [release] update version (hpcaitech#5411) * [hotfix] fix typo s/keywrods/keywords etc. 
(hpcaitech#5429) * [devops] fix compatibility (hpcaitech#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (hpcaitech#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (hpcaitech#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (hpcaitech#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (hpcaitech#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme * run pre-commit --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [rebase] rebase main to resize-embedding (hpcaitech#5581) * [release] grok-1 314b inference (hpcaitech#5490) * [release] grok-1 inference * [release] grok-1 inference * [release] grok-1 inference * [example] update Grok-1 inference (hpcaitech#5495) * revise grok-1 example * remove unused arg in scripts * prevent re-installing torch * update readme * revert modifying colossalai requirements * add perf * trivial * add tokenizer url * [hotfix] set return_outputs=False in examples and polish code (hpcaitech#5404) * fix: simplify merge_batch * fix: use return_outputs=False to 
eliminate extra memory consumption * feat: add return_outputs warning * style: remove `return_outputs=False` as it is the default value * [release] grok-1 inference benchmark (hpcaitech#5500) * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [shardformer]Fix lm parallel. (hpcaitech#5480) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * fix lm forward distribution * fix * test ci * fix * [fix] fix grok-1 example typo (hpcaitech#5506) * [devops] fix example test ci (hpcaitech#5504) * Fix ColoTensorSpec for py11 (hpcaitech#5440) * fixed layout converter caching and updated tester * Empty-Commit * [shardformer] update colo attention to support custom mask (hpcaitech#5510) * [feature] refactor colo attention (hpcaitech#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (hpcaitech#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (hpcaitech#5514) * [test] fix shardformer tests * [test] fix shardformer tests * [format] applied code formatting on 
changed files in pull request 5510 (hpcaitech#5517) Co-authored-by: github-actions <github-actions@github.com> * [shardformer] fix pipeline forward error if custom layer distribution is used (hpcaitech#5189) * Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution * Change static methods for t5 layer distribution to member functions * Change static methods for whisper layer distribution to member functions * Replace whisper policy usage with self one * Fix test case to use non-static layer distribution methods * fix: fix typo --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> * [Fix] Grok-1 use tokenizer from the same pretrained path (hpcaitech#5532) * [fix] use tokenizer from the same pretrained path * trust remote code * [ColossalChat] Update RLHF V2 (hpcaitech#5286) * Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. 
Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li <tong.li352711588@gmail.com> * [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (hpcaitech#5508) * feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig` * feat: apply `GradientCheckpointConfig` to policy and llama_forward * feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager * fix: add optional args for `distribute_layer` and `get_stage_index` * fix: fix changed API calls * test: update llama tests * style: polish `GradientCheckpointConfig` * fix: fix pipeline utils tests * fix incorrect sharding without zero (hpcaitech#5545) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [shardformer] Sequence Parallelism Optimization (hpcaitech#5533) * sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when useing sp and flashattention together * 
fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (hpcaitech#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (hpcaitech#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (hpcaitech#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove 
useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * [hotfix] quick fixes to make legacy tutorials runnable (hpcaitech#5559) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [fix] fix typo s/muiti-node /multi-node etc. (hpcaitech#5448) * [hotfix] fix typo s/get_defualt_parser /get_default_parser (hpcaitech#5548) * [devops] remove post commit ci (hpcaitech#5566) * [devops] remove post commit ci * [misc] run pre-commit on all files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --------- Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [shardformer]enable padding vocabulary size. 
(hpcaitech#5489) * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * padding vocab * padding vocabe * fix * fix * fxi * test ci * fix fix fix fix * fix fix * fix * fix * Update hybrid_parallel_plugin.py fix fix fix * fix fix * fix fix * fix * resolve super init resolve super init resolve super init resolve super init * resolve comments * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * vocab checkpointio * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix fix fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * padding vocab * fix * fix fix * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * cherry-pick * revert moe modify * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix fix fix fix fix fix fix fix * resolve comments resolve comments resolve comments resolve comments resolve comments * ptensor ptensor resolve comments fix fix fix fix fix resolve comments resolve comments resolve comments resolve comments resolve comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix rebase * fix rebase --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell 
<binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 87d09ab commit 059c4d4

36 files changed: +1347 -284 lines changed

colossalai/booster/plugin/gemini_plugin.py

+3 -3
@@ -44,10 +44,10 @@
 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
     # 1. A mapping from integer param_id to param32 shape.
-
     if optim is None:
         return {}
     param_info = {"id2shape": {}}
+
     start_index = 0
     for group in optim.param_groups:
         for param_id, param in enumerate(group["params"], start_index):
@@ -527,7 +527,7 @@ def configure(
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
-        optimizer_params_info = get_param_info(optimizer)
+        params_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
             # FIXME(ver217): gemini does not support sync bn
@@ -558,7 +558,7 @@ def configure(
                 **self.zero_optim_config,
                 **self.optim_kwargs,
                 tp_group=self.tp_group,
-                optimizer_params_info=optimizer_params_info,
+                params_info=params_info,
                 verbose=self.verbose,
             )

colossalai/booster/plugin/hybrid_parallel_plugin.py

+5 -6
@@ -213,12 +213,7 @@ def get_param_info(optim: Optimizer):

     if optim is None:
         return {}
-    param_info = {
-        "param_groups": [],
-        "param2id": {},
-        "id2param": {},
-        "param2shape": {},
-    }
+    param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
     start_index = 0
     for group in optim.param_groups:
         packed_group = {k: v for k, v in group.items() if k != "params"}
@@ -947,6 +942,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
         gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
+        make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
+
     """

     def __init__(
@@ -989,6 +986,7 @@ def __init__(
         num_model_chunks: int = 1,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
+        make_vocab_size_divisible_by: int = 64,
     ) -> None:
         super().__init__()
         assert (
@@ -1095,6 +1093,7 @@ def __init__(
             sequence_parallelism_mode=sequence_parallelism_mode,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
+            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
         )
         self.amp_config = dict(
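
In this file the new `make_vocab_size_divisible_by` argument is only passed through to the shard configuration; the rounding itself appears in the `embedding.py` changes of the same commit. A minimal sketch of that arithmetic follows, assuming a standalone helper named `pad_vocab_size` (a hypothetical name used for illustration, not a ColossalAI function):

```python
def pad_vocab_size(num_embeddings: int, make_vocab_size_divisible_by: int = 64, tp_size: int = 1) -> int:
    """Round a vocabulary size up to the next multiple of divisible_by * tp_size.

    Hypothetical helper mirroring the rounding in PaddingEmbedding (tp_size == 1)
    and VocabParallelEmbedding1D (tp_size > 1) as shown in this commit.
    """
    multiple = make_vocab_size_divisible_by * tp_size
    remainder = num_embeddings % multiple
    if remainder == 0:
        return num_embeddings
    return num_embeddings + multiple - remainder


print(pad_vocab_size(32000))             # 32000, already a multiple of 64
print(pad_vocab_size(50257))             # 50304
print(pad_vocab_size(50257, tp_size=4))  # 50432
```

Multiplying by the tensor-parallel size keeps each rank's slice of the padded vocabulary a multiple of `make_vocab_size_divisible_by` as well.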

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

+28 -1
@@ -14,6 +14,12 @@

 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.tensor.padded_tensor import (
+    init_as_padded_tensor,
+    is_padded_tensor,
+    to_padded_tensor,
+    to_unpadded_tensor,
+)
 from colossalai.utils import get_current_device

 from .general_checkpoint_io import GeneralCheckpointIO
@@ -32,6 +38,7 @@
     save_param_groups,
     save_state_dict,
     save_state_dict_shards,
+    search_padding_dim,
     search_tp_partition_dim,
     sharded_optimizer_loading_epilogue,
 )
@@ -89,6 +96,8 @@ def _model_sharder(
             if param is None:
                 continue
             # Gather tensor pieces when using tensor parallel.
+            if is_padded_tensor(param):
+                param = to_unpadded_tensor(param)
             param_ = gather_distributed_param(param, keep_vars=False)
             block, block_size = state_dict_sharder.append_param(prefix + name, param_)
             if block is not None:
@@ -231,7 +240,6 @@ def save_sharded_model(
             # When pipeline is used, each stage produces its own shard files and index files.
             # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
             # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
-
             final_index_file_path = copy.deepcopy(save_index_file)
             tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
             Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
@@ -251,6 +259,7 @@ def save_sharded_model(
                 use_safetensors=use_safetensors,
                 use_pp_format=True,
             )
+
             if control_saving:
                 assert (
                     self.dp_rank == 0 and self.tp_rank == 0
@@ -867,6 +876,11 @@ def gather_from_sharded_optimizer_state(
                     dist.all_gather(gather_tensor, v, group=tp_group)
                     v = torch.cat(gather_tensor, dim=partition_dim)

+                padding_dim = search_padding_dim(v.shape, original_shape)
+                if padding_dim is not None:
+                    v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
+                    v = to_unpadded_tensor(v)
+
                 state_[k] = v.detach().clone().to(device)

         return state_
@@ -899,6 +913,19 @@ def shard_from_complete_optimizer_state(
             if isinstance(v, torch.Tensor) and k != "step":
                 # Shard state along tensor parallel group.
                 partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
+                global_shape = current_shape
+                if partition_dim is not None:
+                    # pad embedding params
+                    global_shape = (
+                        *current_shape[:partition_dim],
+                        current_shape[partition_dim] * self.tp_size,
+                        *current_shape[partition_dim + 1 :],
+                    )
+
+                padding_dim = search_padding_dim(global_shape, original_shape)
+                if padding_dim is not None:
+                    v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)
+
                 if partition_dim is not None:
                     slice_size = current_shape[partition_dim]
                     v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
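
The two optimizer-state hunks above are mirror images: loading pads the unpadded state up to the global shape and then takes this rank's slice, while saving gathers the slices and drops the padding. The sketch below walks through that round trip on plain Python lists. All helper names (`global_shape_from_shard`, `pad_rows`, `shard_rows`) are hypothetical stand-ins for illustration; they are not the `to_padded_tensor` / `to_unpadded_tensor` implementations, and the tiny sizes are chosen for readability rather than taken from the commit.

```python
from typing import List, Optional, Tuple


def global_shape_from_shard(
    current_shape: Tuple[int, ...], partition_dim: Optional[int], tp_size: int
) -> Tuple[int, ...]:
    """Recover the padded, unsharded shape from one rank's shard shape."""
    if partition_dim is None:
        return tuple(current_shape)
    return (
        *current_shape[:partition_dim],
        current_shape[partition_dim] * tp_size,
        *current_shape[partition_dim + 1:],
    )


def pad_rows(rows: List[List[float]], target_len: int) -> List[List[float]]:
    """Append zero rows until there are target_len rows (padding along dim 0)."""
    width = len(rows[0])
    return rows + [[0.0] * width for _ in range(target_len - len(rows))]


def shard_rows(rows: List[List[float]], tp_size: int, tp_rank: int) -> List[List[float]]:
    """Take one rank's contiguous slice along dim 0."""
    slice_size = len(rows) // tp_size
    return rows[tp_rank * slice_size:(tp_rank + 1) * slice_size]


# A 5-row state saved unpadded; each of 2 ranks expects a (4, 2) shard of a padded (8, 2) tensor.
original = [[float(i), float(i)] for i in range(5)]
tp_size = 2
global_shape = global_shape_from_shard((4, 2), partition_dim=0, tp_size=tp_size)

# Loading: pad first, then shard.
padded = pad_rows(original, global_shape[0])
shards = [shard_rows(padded, tp_size, rank) for rank in range(tp_size)]

# Saving: gather the shards, then drop the padded rows.
gathered = [row for shard in shards for row in shard]
restored = gathered[:len(original)]

print(global_shape)          # (8, 2)
print(restored == original)  # True
```

Padding before splitting matters: splitting the unpadded 5-row state into two pieces would not give every rank the 4 rows its shard expects.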

colossalai/checkpoint_io/utils.py

+9
@@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
     return partition_dim


+def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
+    padding_dim = None
+    for dim, length in enumerate(global_shape):
+        if length > original_shape[dim]:
+            padding_dim = dim
+            break
+    return padding_dim
+
+
 # ======================================
 # Helper classes and functions for saving shard file
 # ======================================
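
To illustrate what `search_padding_dim` returns, here is a small self-contained sketch. It restates the function from the hunk above with plain sequences instead of `torch.Size` (an assumption made so the example runs without PyTorch; the comparison logic is unchanged), and the shapes are illustrative values only.

```python
from typing import Optional, Sequence


def search_padding_dim(global_shape: Sequence[int], original_shape: Sequence[int]) -> Optional[int]:
    # Return the first dimension where the (padded) global shape exceeds the original shape.
    for dim, length in enumerate(global_shape):
        if length > original_shape[dim]:
            return dim
    return None


print(search_padding_dim((50304, 768), (50257, 768)))  # 0
print(search_padding_dim((768, 50304), (768, 50257)))  # 1
print(search_padding_dim((50257, 768), (50257, 768)))  # None
```

Only the first larger dimension is reported, so the helper assumes padding is applied along a single dimension.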

colossalai/shardformer/layer/__init__.py

+5 -2
@@ -1,8 +1,8 @@
 from ._operation import all_to_all_comm
 from .attn import AttnMaskType, ColoAttention
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
-from .embedding import Embedding1D, VocabParallelEmbedding1D
-from .linear import Linear1D_Col, Linear1D_Row
+from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
+from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
 from .loss import cross_entropy_1d
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
@@ -25,6 +25,9 @@
     "FusedRMSNorm",
     "FusedLinear1D_Col",
     "ParallelModule",
+    "PaddingEmbedding",
+    "PaddingLMHead",
+    "VocabParallelLMHead1D",
     "AttnMaskType",
     "ColoAttention",
     "all_to_all_comm",

colossalai/shardformer/layer/embedding.py

+95 -16
@@ -21,10 +21,10 @@
 )

 from ._operation import gather_forward_split_backward, reduce_forward
-from .parallel_module import ParallelModule
+from .parallel_module import PaddingParallelModule, ParallelModule
 from .utils import create_randomizer_with_offset

-__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]
+__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]


 class Embedding1D(ParallelModule):
@@ -161,7 +161,80 @@ def forward(self, input_: Tensor) -> Tensor:
         return output_parallel


-class VocabParallelEmbedding1D(ParallelModule):
+class PaddingEmbedding(PaddingParallelModule):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int = None,
+        dtype: torch.dtype = None,
+        device: torch.device = None,
+        weight: Optional[nn.Parameter] = None,
+        make_vocab_size_divisible_by: int = 64,
+        *args,
+        **kwargs,
+    ):
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.embed_args = args
+        self.embed_kwargs = kwargs
+        self.padding_idx = padding_idx
+        if num_embeddings % make_vocab_size_divisible_by != 0:
+            self.num_embeddings = (
+                num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
+            )
+        # create weight and bias
+        if weight is None:
+            factory_kwargs = {"device": device, "dtype": dtype}
+            weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+
+        super().__init__(self.num_embeddings, num_embeddings, weight)
+
+        if weight is None:
+            self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        init.normal_(self.weight)
+        self._fill_padding_idx_with_zero()
+
+    def _fill_padding_idx_with_zero(self) -> None:
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> PaddingParallelModule:
+        r"""
+        Convert a native pytorch embedding module to a parallel module.
+        """
+        LazyInitContext.materialize(module)
+        # get the origin attributes
+        num_embeddings = module.num_embeddings
+        embedding_dim = module.embedding_dim
+        padding_idx = module.padding_idx
+        device = module.weight.device
+        # create the parallel module
+        padding_embedding = PaddingEmbedding(
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+            padding_idx=padding_idx,
+            device=device,
+            weight=module.weight,
+            *args,
+            **kwargs,
+        )
+
+        return padding_embedding
+
+
+class VocabParallelEmbedding1D(PaddingParallelModule):
     r"""Embedding parallelized in the vocabulary dimension.

     Args:
@@ -201,10 +274,10 @@ def __init__(
         process_group: ProcessGroup = None,
         weight: Optional[nn.Parameter] = None,
         weight_initializer: Callable = init.normal_(),
+        make_vocab_size_divisible_by: int = 64,
         *args,
         **kwargs,
     ):
-        super().__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.embed_args = args
@@ -214,8 +287,23 @@ def __init__(
         tensor_parallel_size = dist.get_world_size(group=process_group)
         tensor_parallel_rank = dist.get_rank(group=process_group)

-        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
-        self.num_embeddings = self.num_embeddings_per_partition
+        # generate weight and bias
+        if weight is None:
+            factory_kwargs = {"device": device, "dtype": dtype}
+            weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+
+        # calculate new padding size
+        multiple = make_vocab_size_divisible_by * tensor_parallel_size
+        if num_embeddings % multiple != 0:
+            self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)
+
+        # resize vocabulary size
+        super().__init__(self.num_embeddings, num_embeddings, weight)
+
+        # deal with tensor parallelism
+        self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

@@ -226,13 +314,6 @@ def __init__(
         seed = torch.random.initial_seed()
         self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

-        # parameter
-        if weight is None:
-            factory_kwargs = {"device": device, "dtype": dtype}
-            self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
-        else:
-            weight.data = weight.data.to(device=device, dtype=dtype)
-            self.weight = weight
         if not is_distributed_tensor(self.weight):
             sharded_weight = shard_rowwise(self.weight.data, process_group)
             sharded_tensor_to_existing_param(sharded_weight, self.weight)
@@ -243,7 +324,7 @@ def __init__(
     @staticmethod
     def from_native_module(
         module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
-    ) -> ParallelModule:
+    ) -> PaddingParallelModule:
         r"""
         Convert a native pytorch embedding module to a parallel module.
         """
@@ -303,11 +384,9 @@ def forward(self, input_: Tensor) -> Tensor:
         # Mask the input.
         masked_input = input_.clone() - self.vocab_start_index
         masked_input[input_mask] = 0
-
         output_parallel = F.embedding(
             masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
         )
-
         # Mask the output embedding.
         embedding_output = output_parallel.clone()
         embedding_output[input_mask, :] = 0.0
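
The masking in `forward` makes each rank contribute real embedding rows only for the token ids inside its own `[vocab_start_index, vocab_end_index)` range and zero rows for everything else. The sketch below simulates this on one process with plain Python lists. It assumes the per-rank outputs are then summed across the tensor-parallel group (the file imports `reduce_forward`, but the reduction itself is not part of the hunk shown), and `local_lookup` and `vocab_parallel_lookup` are hypothetical names used only for illustration.

```python
from typing import List


def local_lookup(
    token_ids: List[int], local_weight: List[List[float]], vocab_start: int, vocab_end: int
) -> List[List[float]]:
    """One rank's contribution: real rows for tokens it owns, zero rows for the rest."""
    dim = len(local_weight[0])
    out = []
    for token in token_ids:
        if vocab_start <= token < vocab_end:
            out.append(list(local_weight[token - vocab_start]))
        else:
            out.append([0.0] * dim)
    return out


def vocab_parallel_lookup(token_ids: List[int], full_weight: List[List[float]], tp_size: int) -> List[List[float]]:
    """Simulate every rank in turn and sum the partial results."""
    per_rank = len(full_weight) // tp_size
    total = [[0.0] * len(full_weight[0]) for _ in token_ids]
    for rank in range(tp_size):
        start, end = rank * per_rank, (rank + 1) * per_rank
        partial = local_lookup(token_ids, full_weight[start:end], start, end)
        # The sum stands in for the cross-rank reduction assumed above.
        total = [[a + b for a, b in zip(t, p)] for t, p in zip(total, partial)]
    return total


weight = [[float(i), float(10 * i)] for i in range(8)]  # 8 padded vocabulary rows, embedding dim 2
tokens = [0, 5, 7]
result = vocab_parallel_lookup(tokens, weight, tp_size=2)
print(result == [weight[t] for t in tokens])  # True
```

Because the padded vocabulary size is a multiple of the tensor-parallel size, every rank owns the same number of rows and the start and end indices follow directly from the rank.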
