From 1c3b892b949921fd2a21e48c330f84e30c866854 Mon Sep 17 00:00:00 2001 From: Yang Gao Date: Thu, 8 Feb 2024 12:15:18 +0800 Subject: [PATCH] feat(model): update modeling_internlm2 with configs (#15) --- .gitignore | 1 + configs/_base_/default_runtime.py | 37 + configs/_base_/models/internlm2_20B.py | 73 ++ configs/_base_/models/internlm2_7B.py | 73 ++ configs/_base_/models/internlm_20B.py | 68 ++ configs/_base_/models/internlm_7B.py | 68 ++ configs/demo.py | 155 ++++ internlm/model/__init__.py | 2 + internlm/model/modeling_internlm2.py | 1149 ++++++++++++++++++++++++ internlm/utils/utils.py | 18 + train.py | 2 +- 11 files changed, 1645 insertions(+), 1 deletion(-) create mode 100644 configs/_base_/default_runtime.py create mode 100644 configs/_base_/models/internlm2_20B.py create mode 100644 configs/_base_/models/internlm2_7B.py create mode 100644 configs/_base_/models/internlm_20B.py create mode 100644 configs/_base_/models/internlm_7B.py create mode 100644 configs/demo.py create mode 100644 internlm/model/modeling_internlm2.py create mode 100644 internlm/utils/utils.py diff --git a/.gitignore b/.gitignore index 5e78704c..2944a354 100644 --- a/.gitignore +++ b/.gitignore @@ -128,6 +128,7 @@ aim_logs/ nvmelogs/ run_backup/ runs/ +RUN/ runs_bak/ LLM_ALERT small_demo/ diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py new file mode 100644 index 00000000..eb7ee6c6 --- /dev/null +++ b/configs/_base_/default_runtime.py @@ -0,0 +1,37 @@ +# Copyright (c) InternLM. All rights reserved. + +cudnn_deterministic = False +cudnn_benchmark = False + +enable_tb = True + +grad_profiling = dict( + # calculate layer norms and parameter norms, and show them on tensorboard + grad_norm_profiling=False, + # count zero gradients, and show them on tensorboard + zero_grad_profiling=False, + # [optional] layers displayed on tensorboard, default: layers=["ScaleColumnParallelLinear"] + # if not set, display all layers + layers=["ScaleColumnParallelLinear"], + vocab_grad_norm_profiling=False, + interval_steps=5, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) diff --git a/configs/_base_/models/internlm2_20B.py b/configs/_base_/models/internlm2_20B.py new file mode 100644 index 00000000..d1176977 --- /dev/null +++ b/configs/_base_/models/internlm2_20B.py @@ -0,0 +1,73 @@ +# Copyright (c) InternLM. All rights reserved. + +model_type = "INTERNLM2" + +VOCAB_SIZE = 92544 +HIDDEN_SIZE = 6144 +NUM_ATTENTION_HEAD = 48 +NUM_KV_ATTENTION_HEAD = 8 +MLP_RATIO = 8 / 3 +NUM_LAYER = 48 + +model = dict( + num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. 
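+    # For reference: with NUM_ATTENTION_HEAD=48 and NUM_KV_ATTENTION_HEAD=8, each KV head is shared by
+    # 48 / 8 = 6 query heads (grouped-query attention), and MLP_RATIO=8/3 corresponds to an intermediate
+    # size of roughly 6144 * 8 / 3 = 16384, before any rounding applied by the MLP implementation.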
+ checkpoint=1.0, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + embed_split_hidden=True, + num_layers=NUM_LAYER, + hidden_size=HIDDEN_SIZE, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=NUM_ATTENTION_HEAD, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, + mlp_ratio=MLP_RATIO, + norm_type="rmsnorm", + adapt_hf=True, + apply_post_layer_norm=False, + no_bias=True, + layer_norm_epsilon=1e-5, + rope_base=1000000, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +""" +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. + 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. memory_pool: bool, enable/disable memory pool, defaults to False. +""" +parallel = dict( + zero1=dict(size=16), + tensor=dict(size=2, mode="msp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=1, overlap=True, memory_pool=True), +) diff --git a/configs/_base_/models/internlm2_7B.py b/configs/_base_/models/internlm2_7B.py new file mode 100644 index 00000000..43ab9036 --- /dev/null +++ b/configs/_base_/models/internlm2_7B.py @@ -0,0 +1,73 @@ +# Copyright (c) InternLM. All rights reserved. + +model_type = "INTERNLM2" + +VOCAB_SIZE = 92544 +HIDDEN_SIZE = 4096 +NUM_ATTENTION_HEAD = 32 +NUM_KV_ATTENTION_HEAD = 8 +MLP_RATIO = 3.5 +NUM_LAYER = 32 + +model = dict( + num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. 
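+    # For reference: 32 query heads with 8 KV heads means each KV head is shared by 32 / 8 = 4 query
+    # heads, and MLP_RATIO=3.5 corresponds to an intermediate size of roughly 4096 * 3.5 = 14336.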
+ checkpoint=0.2, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + embed_split_hidden=True, + num_layers=NUM_LAYER, + hidden_size=HIDDEN_SIZE, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=NUM_ATTENTION_HEAD, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, + mlp_ratio=MLP_RATIO, + norm_type="rmsnorm", + adapt_hf=False, + apply_post_layer_norm=False, + no_bias=True, + layer_norm_epsilon=1e-5, + rope_base=1000000, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +""" +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. + 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. memory_pool: bool, enable/disable memory pool, defaults to False. +""" +parallel = dict( + zero1=dict(size=8), + tensor=dict(size=1, mode="mtp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=1, overlap=True, memory_pool=True), +) diff --git a/configs/_base_/models/internlm_20B.py b/configs/_base_/models/internlm_20B.py new file mode 100644 index 00000000..fd17e77a --- /dev/null +++ b/configs/_base_/models/internlm_20B.py @@ -0,0 +1,68 @@ +# Copyright (c) InternLM. All rights reserved. + +model_type = "INTERNLM" + +VOCAB_SIZE = 103168 +HIDDEN_SIZE = 5120 +NUM_ATTENTION_HEAD = 40 +MLP_RATIO = 8 / 3 +NUM_LAYER = 60 + +model = dict( + num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. 
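+    # A fractional checkpoint value selects the leading share of layers for activation checkpointing,
+    # e.g. checkpoint=0.5 would recompute activations for the first half of the layers in each pipeline
+    # chunk; False disables activation recomputation entirely.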
+ checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + embed_split_hidden=True, + num_layers=NUM_LAYER, + hidden_size=HIDDEN_SIZE, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=NUM_ATTENTION_HEAD, + mlp_ratio=MLP_RATIO, + norm_type="rmsnorm", + apply_post_layer_norm=False, + layer_norm_epsilon=1e-5, +) + +hybrid_zero_optimizer = dict( + # Enable overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +""" +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. + 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. memory_pool: bool, enable/disable memory pool, defaults to False. +""" +parallel = dict( + zero1=dict(size=8), + tensor=dict(size=4, mode="mtp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=1, overlap=True, memory_pool=True), +) diff --git a/configs/_base_/models/internlm_7B.py b/configs/_base_/models/internlm_7B.py new file mode 100644 index 00000000..7334c146 --- /dev/null +++ b/configs/_base_/models/internlm_7B.py @@ -0,0 +1,68 @@ +# Copyright (c) InternLM. All rights reserved. + +model_type = "INTERNLM" + +VOCAB_SIZE = 103168 +HIDDEN_SIZE = 4096 +NUM_ATTENTION_HEAD = 32 +MLP_RATIO = 8 / 3 +NUM_LAYER = 32 + +model = dict( + num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. 
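+    # Note: the INTERNLM (v1) configs use standard multi-head attention, so there is no
+    # num_kv_attention_heads entry here, unlike the INTERNLM2 configs, which enable grouped-query attention.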
+ checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + embed_split_hidden=True, + num_layers=NUM_LAYER, + hidden_size=HIDDEN_SIZE, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=NUM_ATTENTION_HEAD, + mlp_ratio=MLP_RATIO, + norm_type="rmsnorm", + apply_post_layer_norm=False, + layer_norm_epsilon=1e-5, +) + +hybrid_zero_optimizer = dict( + # Enable overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +""" +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. + 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. memory_pool: bool, enable/disable memory pool, defaults to False. +""" +parallel = dict( + zero1=dict(size=8), + tensor=dict(size=1, mode="mtp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=1, overlap=True, memory_pool=True), +) diff --git a/configs/demo.py b/configs/demo.py new file mode 100644 index 00000000..52b3fb9f --- /dev/null +++ b/configs/demo.py @@ -0,0 +1,155 @@ +# Copyright (c) InternLM. All rights reserved. 
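+# The base configs imported below supply runtime and model defaults; because they are pulled in via a
+# star import, any name assigned later in this file overrides the value of the same name from _base_.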
+from internlm.utils.utils import read_base
+
+with read_base():
+    from configs._base_.default_runtime import *  # pylint: disable=W0401,W0614 # noqa: F401
+    from configs._base_.models.internlm2_7B import *  # pylint: disable=W0401,W0614 # noqa: F401
+
+JOB_NAME = "7b_train"
+
+DO_ALERT = False
+
+SEQ_LEN = 2048
+HIDDEN_SIZE = 4096
+NUM_ATTENTION_HEAD = 32
+MLP_RATIO = 8 / 3
+NUM_LAYER = 32
+VOCAB_SIZE = 103168
+
+# Ckpt folder format:
+# fs: 'local:/mnt/nfs/XXX'
+SAVE_CKPT_FOLDER = "local:llm_ckpts"
+LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
+LOAD_CKPT_FOLDER = None
+
+# boto3 Ckpt folder format:
+# import os
+# BOTO3_IP = os.environ["BOTO3_IP"]  # boto3 bucket endpoint
+# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
+# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
+CHECKPOINT_EVERY = 50
+ckpt = dict(
+    enable_save_ckpt=False,  # enable ckpt save.
+    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
+    # 'load_ckpt_info' setting guide:
+    # 1. the 'path' indicates the ckpt path,
+    # 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
+    # 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internlm", "llama", "hf_llama".
+    load_ckpt_info=dict(path=LOAD_CKPT_FOLDER, content=("model",), ckpt_type="internlm"),
+    # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
+    # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
+    # with an automatic restart mechanism upon training reboot.
+    # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
+    # path specified in `load_ckpt_info` by default.
+    # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
+    # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
+    auto_resume=False,
+    checkpoint_every=CHECKPOINT_EVERY,
+    async_upload=True,  # async ckpt upload. (only works for boto3 ckpt)
+    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
+    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
+)
+
+TRAIN_FOLDER = None  # "/path/to/dataset"
+VALID_FOLDER = None  # "/path/to/dataset"
+data = dict(
+    seq_len=SEQ_LEN,
+    # micro_num means the number of micro_batches contained in one gradient update
+    micro_num=4,
+    # packed_length = micro_bsz * SEQ_LEN
+    micro_bsz=2,
+    # defaults to the value of micro_num
+    valid_micro_num=4,
+    # defaults to 0, which disables evaluation
+    valid_every=50,
+    pack_sample_into_one=False,
+    total_steps=50000,
+    skip_batches="",
+    # rampup_batch_size (str): A string with three space-separated integers representing the
+    #       starting batch size, the increment, and the number of steps between
+    #       each increment. For example, "192 24 8" means that the batch size (micro_num)
+    #       starts at 192 and increases by 24 every 8 steps. Defaults to None.
+    #       (IMPORTANT): The interval step size is 'micro_bsz'.
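+    # With the settings above (micro_num=4, micro_bsz=2, SEQ_LEN=2048), one optimizer step consumes
+    # roughly micro_num * micro_bsz * SEQ_LEN = 16384 tokens per data-parallel rank.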
+ rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, + diag_outlier_ratio=1.1, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +loss = dict( + label_smoothing=0, +) + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, + ), +) + +use_fp32_norm = False + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" diff --git a/internlm/model/__init__.py b/internlm/model/__init__.py index c10552c3..c93d6c7a 100644 --- a/internlm/model/__init__.py +++ b/internlm/model/__init__.py @@ -5,6 +5,7 @@ from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear from .metrics import AccPerplex from .modeling_internlm import build_model_with_cfg +from .modeling_internlm2 import build_model_with_cfg as build_model_with_cfg2 from .modeling_llama import build_model_with_cfg as build_model_with_llama_cfg from .modeling_moe import build_model_with_moe_cfg from .moe import MoE @@ -23,6 +24,7 @@ "DistributedAttention", "gather_forward_split_backward", "build_model_with_cfg", + "build_model_with_cfg2", "build_model_with_moe_cfg", "build_model_with_llama_cfg", ] diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py new file mode 100644 index 00000000..25ad261c --- /dev/null +++ b/internlm/model/modeling_internlm2.py @@ -0,0 +1,1149 @@ +# Copyright (c) InternLM. All rights reserved. 
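+# This module implements the InternLM2 model (a LLaMA-style decoder with grouped-query attention).
+# Configs that set model_type = "INTERNLM2" are routed to build_model_with_cfg at the bottom of this
+# file through the MODEL_INITIALIZER registry.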
+import math +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange +from flash_attn import flash_attn_varlen_kvpacked_func +from flash_attn.modules.embedding import ParallelGPT2Embeddings +from flash_attn.modules.mha import ( + CrossAttention, + FlashCrossAttention, + FlashSelfAttention, + SelfAttention, + _update_kv_cache, +) +from flash_attn.modules.mlp import ParallelFusedMLP +from flash_attn.ops.layer_norm import dropout_add_layer_norm +from torch import nn + +from internlm.core.context import ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.initialize.initialize_tensor import ( + normal_, + scaled_init_method_normal, + scaled_init_method_uniform, + uniform_, +) +from internlm.model.embedding import ( + DynamicNTKScalingRotaryEmbedding, + Embedding1D, + RotaryEmbedding, +) +from internlm.model.linear import ( + MegatronScaleColumnParallelLinear, + RewardModelLinear, + ScaleColumnParallelLinear, + get_linear_cls, + get_mlp_cls, +) +from internlm.model.multi_head_attention import DistributedAttention +from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm +from internlm.solver.pipeline_utils import partition_uniform +from internlm.utils.checkpoint import activation_checkpoint +from internlm.utils.common import filter_kwargs +from internlm.utils.logger import get_logger +from internlm.utils.registry import MODEL_INITIALIZER + +MODEL_TYPE = "INTERNLM2" + +logger = get_logger(__file__) +RMSNorm = try_import_RMSNorm() + + +class MHA(nn.Module): + """ + Multi-head self-attention and cross-attention. + + Args: + embed_dim (int): The dimention of hidden state. + num_heads (int): The number of attention heads. + num_kv_heads (int): The number of attention heads for key and value. + process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. + sequence_process_group (torch.distributed.ProcessGroup): The group for `sequence_parallel`. + bias (bool): Whether the bias is needed for linears. Will be used when initializing QKV matrix and + output projection. False by default. + dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. + softmax_scale (float): The temperature to use for the softmax attention. + causal (boolean): Whether to apply causal attention mask. False by default. + layer_idx (int): The index of current layer. None by default. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. + rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements + XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. + use_flash_attn (bool): Whether to use flash attention or not.If False, vanilla attention module will be used. + False by default. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + rot_embed_HF_impl (Optional[bool]): Whether to use the rotary embedding implementation from HuggingFace. + True by default. + tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], + "mtp" by default. 
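+
+    For illustration only (numbers taken from the 7B config, not enforced here): with embed_dim=4096,
+    num_heads=32 and num_kv_heads=8, head_dim = 4096 // 32 = 128, q_per_kv = 32 // 8 = 4, and the packed
+    wqkv projection has out_features = embed_dim + 2 * num_kv_heads * head_dim = 4096 + 2048 = 6144,
+    i.e. one group of (q_per_kv + 2) * head_dim = 768 output channels per KV head.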
+ """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + process_group: Optional[torch.distributed.ProcessGroup], + sequence_process_group: Optional[torch.distributed.ProcessGroup], + max_position_embeddings: int = 2048, + bias: bool = False, + dropout: float = 0.0, + softmax_scale: float = None, + causal: bool = False, + layer_idx: int = None, + use_dynamic_ntk_rope: bool = False, + use_flash_attn: bool = True, + rope_base: int = 10000, + rotary_emb_dim: int = 0, + rotary_emb_scale_base: int = 0, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + rot_embed_HF_impl: Optional[bool] = True, + tp_mode: str = "mtp", + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" + + self.head_dim = self.embed_dim // num_heads + self.num_kv_heads = num_kv_heads + self.kv_dim = self.head_dim * num_kv_heads + self.causal = causal + self.layer_idx = layer_idx + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.dtype = dtype + + self.q_per_kv = num_heads // num_kv_heads + + self.rot_embed_HF_impl = rot_embed_HF_impl + sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) + + self.max_position_embeddings = max_position_embeddings + self.use_dynamic_ntk_rope = use_dynamic_ntk_rope + self.tp_mode = tp_mode + + if self.rotary_emb_dim > 0: + if self.use_dynamic_ntk_rope: + self.rotary_emb = DynamicNTKScalingRotaryEmbedding( + self.rotary_emb_dim, + base=rope_base, + scale_base=rotary_emb_scale_base, + device=device, + max_position_embeddings=max_position_embeddings, + scaling_factor=1.0, # Currently do not support dynamic scaling. 
+ ) + else: + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device + ) + + Wqkv_cls = get_linear_cls(self.tp_mode, "column") + self.wqkv = Wqkv_cls( + embed_dim, + embed_dim + 2 * self.kv_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + + inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = inner_cross_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + + self.inner_cross_attn_causal = causal + self.inner_cross_attn_softmax_scale = softmax_scale + self.inner_cross_attn_dropout = dropout + + if self.tp_mode == "isp": + self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=sequence_process_group) + self.inner_cross_attn = DistributedAttention( + self.inner_cross_attn, sequence_process_group=sequence_process_group + ) + + wo_cls = get_linear_cls(self.tp_mode, "row") + self.wo = wo_cls( + embed_dim, + embed_dim, + process_group, + bias=bias, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + + def forward(self, x, seqlen=None, inference_params=None, **kwargs): + if kwargs.get("indexes", None) is not None: + return self._packed_forward(x=x, inference_params=inference_params, **kwargs) + else: + return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) + + def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: disable=W0613 + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). 
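+
+        Note: after the wqkv projection, qkv is reshaped to (batch, seqlen, num_kv_heads, q_per_kv + 2,
+        head_dim); the first q_per_kv slices of the group dimension are the queries, and the last two
+        slices are the key and value shared by that query group.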
+ """ + bsz, _, _ = x.shape + qkv = self.wqkv(x) + + if seqlen is None: + qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim) + else: + qkv = rearrange(qkv, "(b s) (h gs d) -> b s h gs d", s=seqlen, gs=self.q_per_kv + 2, d=self.head_dim) + + q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]) + + q = rearrange(q, "b s h gs d -> b s (h gs) d") + + if not self.rot_embed_HF_impl: + q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) + k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) + + if inference_params is None: + if self.rotary_emb_dim > 0: + q = self.rotary_emb._single_eval_forward(q) + k = self.rotary_emb._single_eval_forward(k) + kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) + if self.dtype is torch.float32 and self.use_flash_attn: + if q.dtype not in [torch.float16, torch.bfloat16]: + q = q.to(torch.bfloat16) + if kv.dtype not in [torch.float16, torch.bfloat16]: + kv = kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + context = self.inner_cross_attn(q, kv).to(self.dtype) + else: + context = self.inner_cross_attn(q, kv) + + else: + assert self.rotary_emb_dim > 0 + if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: + empties = inference_params.attention_mask[..., -1].sum(dim=-1) + moved_q = q.clone() + moved_k = k.clone() + if inference_params.sequence_len_offset == 0: + for i in range(len(empties)): + if empties[i] != 0: + moved_q[i][: -empties[i]] = q[i][empties[i] :] + moved_k[i][: -empties[i]] = k[i][empties[i] :] + moved_q = self.rotary_emb._single_eval_forward( + moved_q, seqlen_offset=inference_params.sequence_len_offset + ) + moved_k = self.rotary_emb._single_eval_forward( + moved_k, seqlen_offset=inference_params.sequence_len_offset + ) + for i in range(len(empties)): + if empties[i] != 0: + q[i][empties[i] :] = moved_q[i][: -empties[i]] + k[i][empties[i] :] = moved_k[i][: -empties[i]] + else: + q[i] = moved_q[i] + k[i] = moved_k[i] + else: + q = q.squeeze(1) + k = k.squeeze(1) + q = self.rotary_emb._single_forward( + q, + inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) + - empties, + ).unsqueeze(1) + k = self.rotary_emb._single_forward( + k, + inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) + - empties, + ).unsqueeze(1) + else: + raise NotImplementedError( + "You should make sure you are aware that you are changing the method of generating." + "According to your generation function instead of inference/seq_generator_module.py, " + "You may implement here for normal running." + ) + + kv = torch.stack([k, v], dim=2) + + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + if hasattr(inference_params, "window_size") and inference_params.window_size is not None: + if inference_params.window_size <= inference_params.sequence_len_offset: + assert kv.size(1) == 1, "update kv lenth more than 1" + inference_params.key_value_memory_dict[self.layer_idx][ + :, inference_params.keep_first : inference_params.window_size - 1, ... + ] = inference_params.key_value_memory_dict[self.layer_idx][ + :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ... 
+ ].clone() + inference_params.real_sequence_len_offset = inference_params.sequence_len_offset + inference_params.sequence_len_offset = inference_params.window_size - 1 + + kv = _update_kv_cache(kv, inference_params, self.layer_idx) + + inference_params.sequence_len_offset = inference_params.real_sequence_len_offset + else: + kv = _update_kv_cache(kv, inference_params, self.layer_idx) + else: + kv = _update_kv_cache(kv, inference_params, self.layer_idx) + + # When using FP16, there is a high probability of NAN in the KV. + # Since NAN cannot be removed by multiplying with and 0, it needs + # to be removed manually here. + kv = torch.where(torch.isnan(kv), 0, kv) + + if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: + assert self.use_flash_attn is True + if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) + attn_mask = inference_params.attention_mask[:, None, ...] + attn_mask = torch.logical_or( + torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask + ) + attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1) + cu_seqlens = torch.concat( + [ + torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device), + attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32), + ], + dim=0, + ) + cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32) + max_seqlen_q = attn_mask4flsh.shape[-1] + max_seqlen_k = attn_mask4flsh.shape[-1] + total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) + total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view( + -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] + ) + if self.dtype is torch.float32: + if total_q.dtype not in [torch.float16, torch.bfloat16]: + total_q = total_q.to(torch.bfloat16) + if total_kv.dtype not in [torch.float16, torch.bfloat16]: + total_kv = total_kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + output = flash_attn_varlen_kvpacked_func( + q=total_q, + kv=total_kv, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + causal=True, + ).to(self.dtype) + else: + output = flash_attn_varlen_kvpacked_func( + q=total_q, + kv=total_kv, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + causal=True, + ) + + context = torch.zeros_like(q) + context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) + + else: + attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1) + if hasattr(inference_params, "window_size") and inference_params.window_size is not None: + if inference_params.window_size <= inference_params.sequence_len_offset: + attn_mask = torch.concat( + [ + attn_mask[..., : inference_params.keep_first], + attn_mask[..., -(inference_params.window_size - inference_params.keep_first) :], + ], + dim=-1, + ) + + k, v = torch.chunk(kv, 2, dim=2) + k = k.squeeze(2) + v = v.squeeze(2) + sp = k.shape + expansion = q.size(2) // k.size(2) + scores = torch.einsum( + "blhd,bnhd->bhln", + q, + k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), + ) / math.sqrt(q.size(-1)) + scores = scores.masked_fill(attn_mask, -65000.0) + scores = F.softmax(scores, dim=-1) # bsz x h x L x L + context = torch.einsum( + "bhmn,bnhd->bmhd", + scores, + v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), + ) + else: + if self.dtype is 
torch.float32 and self.use_flash_attn: + if q.dtype not in [torch.float16, torch.bfloat16]: + q = q.to(torch.bfloat16) + if kv.dtype not in [torch.float16, torch.bfloat16]: + kv = kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + context = self.inner_cross_attn(q, kv, causal=True).to(self.dtype) + else: + context = self.inner_cross_attn(q, kv, causal=True) + + if seqlen is None: + context = rearrange(context, "b s h d -> b s (h d)") + else: + context = rearrange(context, "b s h d -> (b s) (h d)") + + out = self.wo(context) + return out + + def _packed_forward(self, x, inference_params=None, **kwargs): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + """ + assert self.use_flash_attn is True + + qkv = self.wqkv(x) + + qkv = rearrange(qkv, "t (h gs d) -> t h gs d", gs=self.q_per_kv + 2, d=self.head_dim) + + q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]) + + q = rearrange(q, "t h gs d -> t (h gs) d") + + # qkv shift + # the rotary embedding in flash attention module in performed by separating the front and back parts, while + # most of others are done by odd-even methods. + if not self.rot_embed_HF_impl: + q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) + k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) + + indexes = kwargs.pop("indexes") + q = self.rotary_emb._single_forward(q, indexes=indexes) + k = self.rotary_emb._single_forward(k, indexes=indexes) + + if inference_params is None: + kv = torch.concat([k.unsqueeze(1), v.unsqueeze(1)], dim=1) + if self.dtype is torch.float32: + if q.dtype not in [torch.float16, torch.bfloat16]: + q = q.to(torch.bfloat16) + if kv.dtype not in [torch.float16, torch.bfloat16]: + kv = kv.to(torch.bfloat16) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + context = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=kwargs["cu_seqlens"], + cu_seqlens_k=kwargs["cu_seqlens"], + max_seqlen_q=kwargs["max_seqlen"], + max_seqlen_k=kwargs["max_seqlen"], + dropout_p=self.inner_cross_attn_dropout, + softmax_scale=self.inner_cross_attn_softmax_scale, + causal=self.inner_cross_attn_causal, + ).to(self.dtype) + else: + context = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=kwargs["cu_seqlens"], + cu_seqlens_k=kwargs["cu_seqlens"], + max_seqlen_q=kwargs["max_seqlen"], + max_seqlen_k=kwargs["max_seqlen"], + dropout_p=self.inner_cross_attn_dropout, + softmax_scale=self.inner_cross_attn_softmax_scale, + causal=self.inner_cross_attn_causal, + ) + else: + raise RuntimeError("Not support this right now") + + context = rearrange(context, "b h d -> b (h d)") # recover shape + out = self.wo(context) + return out + + +class PackedFlashLlamaLayer1D(nn.Module): + """ + InternLM2 layer. + + Args: + hidden_size (int): The hidden size of model. 768 by default. + num_attention_heads (int): The number of attention heads. 12 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0 by default. + drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + max_position_embeddings (int): The maximum position embeddings. 2048 by default. + dtype (torch.dtype): Type of data. torch.float by default. 
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + layer_idx (int): The index of current layer. 0 by default. + use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + device (Optional[Union[str, torch.device]]): The device will be used. + norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. + use_flash_attn (bool): Whether use flash-attn. True by default. + tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], + "mtp" by default. + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + """ + + def __init__( + self, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + mlp_ratio: int = 4, + attn_drop_rate: float = 0, + drop_rate: float = 0.0, + max_position_embeddings: int = 2048, + dtype: torch.dtype = torch.float, + layer_norm_epsilon: float = 1e-6, + checkpoint: bool = False, + layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + residual_in_fp32: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm: bool = False, + fused_dropout_add_ln: bool = True, + no_bias: bool = False, + norm_type: str = "rmsnorm", + adapt_hf: bool = True, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + use_flash_attn: bool = True, + tp_mode: str = "mtp", + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, + ): + super().__init__() + self.checkpoint = checkpoint + # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
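+        # When full activation checkpointing is enabled, the entire layer forward (including dropout and
+        # norm) is already recomputed during backward, so additionally checkpointing the dropout/norm pair
+        # would be redundant.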
+ self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False + self.layer_idx = layer_idx + self.use_flash_attn = use_flash_attn + self.prenorm = not apply_post_layer_norm + assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" + self.fused_dropout_add_ln = fused_dropout_add_ln + self.attn_wqkv_init_std = attn_wqkv_init_std + self.attn_other_init_std = attn_other_init_std + self.ffn_uplayer_init_std = ffn_uplayer_init_std + self.ffn_other_init_std = ffn_other_init_std + + self.max_position_embeddings = max_position_embeddings + self.use_dynamic_ntk_rope = use_dynamic_ntk_rope + self.tp_mode = tp_mode + parallel_mode = ParallelMode.WEIGHT if self.tp_mode == "isp" else ParallelMode.TENSOR + + head_dim = hidden_size // num_attention_heads + self.attention = MHA( + embed_dim=hidden_size, + num_heads=num_attention_heads, + num_kv_heads=num_kv_attention_heads, + process_group=gpc.get_group(parallel_mode), + sequence_process_group=gpc.get_group(ParallelMode.TENSOR), + dropout=attn_drop_rate, + max_position_embeddings=max_position_embeddings, + softmax_scale=1 / math.sqrt(head_dim), + causal=True, + layer_idx=layer_idx, + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + rotary_emb_dim=head_dim, + rotary_emb_scale_base=0, + use_flash_attn=use_flash_attn, + device=device, + dtype=dtype, + rot_embed_HF_impl=adapt_hf, + bias=not no_bias, + rope_base=rope_base, + tp_mode=self.tp_mode, + ) + + self.dropout1 = nn.Dropout(drop_rate) + if norm_type == "rmsnorm": + self.attention_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + else: + self.attention_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + if self.fused_dropout_add_ln: + assert dropout_add_layer_norm is not None, "dropout_add_ln is not installed" + assert isinstance(self.attention_norm, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout) + + sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) + if use_swiglu: + ffn = get_mlp_cls(self.tp_mode) + self.feed_forward = ffn( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + process_group=gpc.get_group(parallel_mode), + bias=False, + device=device, + dtype=dtype, + ) + else: + self.feed_forward = ParallelFusedMLP( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + activation="gelu_approx", + process_group=gpc.get_group(parallel_mode), + bias1=False, + bias2=False, + sequence_parallel=sequence_parallel, + checkpoint_lvl=0, + heuristic="auto", + device=device, + dtype=dtype, + ) + + self.dropout2 = nn.Dropout(drop_rate) + self.use_swiglu = use_swiglu + self.use_scaled_init = use_scaled_init + self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm + self.return_residual = False + + if init_type == "normal": + self.init_func = normal_ + self.scaled_init_func = scaled_init_method_normal + else: + self.init_func = uniform_ + self.scaled_init_func = scaled_init_method_uniform + + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + for name, param in self.attention.named_parameters(): + if param.ndim == 1: + param.data.zero_() + elif "wq" in name or "wk" in name or "wv" in name: + self.init_func(std=self.attn_wqkv_init_std)(param.data) + elif self.use_scaled_init: # wo + self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + 
self.init_func(std=self.attn_other_init_std)(param.data) + + for name, param in self.feed_forward.named_parameters(): + if self.use_swiglu: + if self.use_scaled_init and "w2" in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func( + std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std + )(param.data) + else: + if self.use_scaled_init and "fc1" not in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( + param.data + ) + + def forward( + self, hidden_states, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None + ): + if self.checkpoint and self.training: + return activation_checkpoint( + self._forward, False, hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen + ) + else: + return self._forward(hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen) + + def _forward( + self, hidden_states=None, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Attn/MLP(LN(residual)) + cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 + indexes: the length of index is same as hidden states, which stand for the current position + """ + if self.prenorm: + + def _dropout_and_norm_attn(_residual, _hidden_states): + _dropped = self.dropout1(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) + else: + residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + mixer_kwargs = { + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "indexes": indexes, + "inference_params": inference_params, + } + hidden_states = self.attention(hidden_states, **mixer_kwargs) + + if not isinstance(self.feed_forward, nn.Identity): + if not self.fused_dropout_add_ln: + + def _dropout_and_norm_ffn(_residual, _hidden_states): + _dropped = self.dropout2(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.ffn_norm(_residual.to(torch.float32)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint( + _dropout_and_norm_ffn, False, residual, hidden_states + ) + else: + residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + hidden_states = self.feed_forward(hidden_states) + + return hidden_states + residual + else: + assert residual is None + mixer_kwargs = { + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "indexes": indexes, + "inference_params": inference_params, + } + mixer_out = self.attention(hidden_states, **mixer_kwargs) + if self.return_residual: # mixer out is actually a pair here + mixer_out, hidden_states = mixer_out + hidden_states = 
self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( + dtype=self.attention_norm.weight.dtype + ) + if not isinstance(self.feed_forward, nn.Identity): + mlp_out = self.feed_forward(hidden_states) + if self.return_residual: # mlp out is actually a pair here + mlp_out, hidden_states = mlp_out + hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( + dtype=self.ffn_norm.weight.dtype + ) + return hidden_states + + +class PackedFlashLlama1D(nn.Module): + """ + 1D Packed Flash InternLM2. + + Args: + num_layers (int): The number of layer. 12 by default. + hidden_size (int): The size of hidden state. 768 by default. + num_attention_heads (int): The number of attention head. 12 by default. + vocab_size (int): The size of vocabulary. 50304 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. + drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + max_position_embeddings (int): The maximum position embeddings. 2048 by default. + dtype (torch.dtype): The type of data. torch.float by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 1.0 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + first (bool): Whether input embedding layer or not. False by default. + last (bool): Whether output embedding layer or not. False by default. + embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. + True by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. + device (Optional[Union[str, torch.device]]): The device will be used. None by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. + use_flash_attn (bool): Whether to use flash-attn. True by default. + embedding_init_std (float): std used to init embedding weight. 0.02 by default, + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], + "mtp" by default. 
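+
+    Note: `first` and `last` select which pipeline stage owns the embedding and the head. Only the stage
+    holding layer 0 builds `tok_embeddings`, and only the stage holding the final layer builds `norm` and
+    `output` (see `_build_generic_model_1d` below, which sets these flags per pipeline rank).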
+ """ + + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + vocab_size: int = 50304, + mlp_ratio: int = 4, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + max_position_embeddings: int = 2048, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + checkpoint_fraction: float = 1.0, + layer_norm_epsilon: float = 1e-5, + first: bool = False, + last: bool = False, + embed_split_hidden: bool = False, + embed_grad_scale: float = 0.1, + parallel_output: bool = True, + start_layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm=False, + no_bias=False, + residual_in_fp32: bool = False, + norm_type: str = "rmsnorm", + adapt_hf: bool = True, + is_reward: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + use_flash_attn: bool = True, + embedding_init_std: float = 0.02, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + out_head_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, + tp_mode: str = "mtp", + ): + super().__init__() + + self.use_flash_attn = use_flash_attn + + if checkpoint_fraction <= 0: + checkpoint = False + if not checkpoint: + checkpoint_fraction = 0 + checkpoint_layer_num = num_layers * checkpoint_fraction + + self.tp_mode = tp_mode + if isinstance(gpc.config.parallel["tensor"], dict): + self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") + + if is_reward: + head_cls = RewardModelLinear + else: + head_cls = ( + ScaleColumnParallelLinear + if self.tp_mode in ["mtp", "fsp", "isp"] + else MegatronScaleColumnParallelLinear + ) + + sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) + + if first: + if embed_split_hidden: + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + else: + self.tok_embeddings = ParallelGPT2Embeddings( + embed_dim=hidden_size, + vocab_size=vocab_size, + max_position_embeddings=-1, + process_group=gpc.get_group(ParallelMode.TENSOR), + padding_idx=None, + sequence_parallel=sequence_parallel, + device=device, + dtype=dtype, + ) + for _, param in self.tok_embeddings.named_parameters(): + if init_type == "normal": + normal_(std=embedding_init_std)(param) + else: + uniform_(std=embedding_init_std)(param) + + self.embed_grad_scale = embed_grad_scale + + self.layers = nn.ModuleList( + [ + PackedFlashLlamaLayer1D( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_kv_attention_heads=num_kv_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + max_position_embeddings=max_position_embeddings, + dtype=dtype, + layer_norm_epsilon=layer_norm_epsilon, + checkpoint=lid < checkpoint_layer_num, + layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + residual_in_fp32=residual_in_fp32, + device=device, + apply_post_layer_norm=apply_post_layer_norm, + fused_dropout_add_ln=False, + no_bias=no_bias, + norm_type=norm_type, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_swiglu=use_swiglu, + use_flash_attn=use_flash_attn, + adapt_hf=adapt_hf, + attn_wqkv_init_std=attn_wqkv_init_std, + attn_other_init_std=attn_other_init_std, + 
ffn_uplayer_init_std=ffn_uplayer_init_std, + ffn_other_init_std=ffn_other_init_std, + init_type=init_type, + tp_mode=self.tp_mode, + rope_base=rope_base, + ) + for lid in range(num_layers) + ] + ) + + if last: + if not apply_post_layer_norm: + if norm_type == "rmsnorm": + self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + else: + self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + + self.output = head_cls( + in_features=hidden_size, + out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, + process_group=gpc.get_group(ParallelMode.TENSOR), + bias=False, + device=device, + dtype=dtype, + weight_scale=embed_grad_scale, + ) + for _, param in self.output.named_parameters(): + if init_type == "normal": + normal_(std=out_head_init_std)(param) + else: + uniform_(std=out_head_init_std)(param) + + self.parallel_output = parallel_output + + def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + # attention_mask: compute attention on the places where the value is 1 + if hasattr(self, "tok_embeddings"): + hidden_states = self.tok_embeddings(input_ids) + if self.embed_grad_scale != 1: + hidden_states = ( + self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() + ) + if isinstance(cu_seqlens, list): + assert len(cu_seqlens) == 1 + cu_seqlens = cu_seqlens[0].to(hidden_states.device) + + if cu_seqlens is not None: + cu_seqlens = cu_seqlens.squeeze(0) + hidden_states = hidden_states.squeeze(0) # If cu_seqlens is passed in,it indicated a packed state, + # the batch dimension with a size of 1 should be directly squeezed off. + + if indexes is not None: + assert len(indexes) == 1 + # The indexes are used to indicate the actual position IDs of each token in the packed input. + indexes = indexes[0] + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None + + for _, block in enumerate(self.layers): + hidden_states = block( + hidden_states, + residual=None, + cu_seqlens=cu_seqlens, + indexes=indexes, + inference_params=inference_params, + max_seqlen=max_seqlen, + ) + + if hasattr(self, "norm"): + hidden_states = self.norm(hidden_states.float()) + if hasattr(self, "output"): + hidden_states = self.output(hidden_states) + + if not self.parallel_output: + hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) + + return hidden_states + + +def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): + """ + build generic model 1d + + Args: + num_layers (int): The number of layer. + num_chunks (int): The number of partitions in pipeline parallel. + device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default. + + """ + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) + parts = all_parts[pipeline_rank] + if gpc.is_rank_for_log(): + logger.info(f"The layer sharding is {all_parts}.") + + models = [] + kwargs["checkpoint_fraction"] = float(kwargs.get("checkpoint", False)) + start_idx, end_idx = 0, 0 + for start, end in parts: + start_idx, end_idx = start, end + kwargs["num_layers"] = end - start + kwargs["first"] = start == 0 + # If there is no content in the final layer, assign the last layer. 
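+        # For example, with num_layers=48, pipeline_size=4 and num_chunks=1, partition_uniform is expected
+        # to give each pipeline rank one contiguous span such as (0, 12), (12, 24), (24, 36) or (36, 48),
+        # so only the last rank sets "last" below and builds the norm/output head.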
+ kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0
+ kwargs["device"] = device
+ kwargs["start_layer_idx"] = start
+ chunk = PackedFlashLlama1D(**filter_kwargs(PackedFlashLlama1D.__init__, kwargs)).to(device)
+
+ models.append(chunk)
+ torch.distributed.barrier()
+ if len(models) == 1:
+ model = models[0]
+ else:
+ model = nn.ModuleList(models)
+ setattr(model, "first_layer", start_idx)
+ setattr(model, "last_layer", end_idx)
+ return model
+
+
+@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
+def build_model_with_cfg(
+ num_chunks=1,
+ checkpoint=False,
+ dtype=torch.float,
+ embed_split_hidden=False,
+ num_layers=48,
+ hidden_size=2048,
+ vocab_size=50304,
+ embed_grad_scale=1,
+ parallel_output=True,
+ num_attention_heads=32,
+ num_kv_attention_heads=None,
+ mlp_ratio=4.0,
+ residual_in_fp32=False,
+ norm_type="rmsnorm",
+ adapt_hf=True,
+ drop_rate=0,
+ attn_drop_rate=0,
+ apply_post_layer_norm=False, # pylint: disable=W0613
+ no_bias=False,
+ deepnorm=False,
+ layer_norm_epsilon=1e-5,
+ is_reward=False,
+ dropout_selective_checkpoint=True,
+ use_scaled_init: bool = True,
+ use_swiglu: bool = True,
+ use_flash_attn: bool = True,
+ embedding_init_std: float = 0.02,
+ attn_wqkv_init_std: float = 0.02,
+ attn_other_init_std: float = 0.02,
+ ffn_uplayer_init_std: float = 0.02,
+ ffn_other_init_std: float = 0.02,
+ out_head_init_std: float = 0.02,
+ init_type: str = "normal",
+ rope_base: int = 10000,
+ max_position_embeddings=2048,
+ use_dynamic_ntk_rope=False,
+):
+ """
+ Build a model with the given config.
+
+ Args:
+ num_chunks (int): The number of partitions in pipeline parallel. 1 by default.
+ checkpoint (bool): Whether to use activation checkpointing to save VRAM. False by default.
+ dtype (torch.dtype): The type of data. torch.float by default.
+ embed_split_hidden (bool): Whether to split the embedding layer along the hidden-state dimension
+ instead of the vocabulary dimension. False by default.
+ num_layers (int): The number of layers. 48 by default.
+ hidden_size (int): The size of the hidden state. 2048 by default.
+ vocab_size (int): The size of the vocabulary. 50304 by default.
+ embed_grad_scale (float): Scale factor applied to the embedding gradient for training stability
+ (see GLM-130B). 0.1 by default.
+ parallel_output (bool): Whether to keep the output split across tensor-parallel ranks
+ instead of gathering it. True by default.
+ num_attention_heads (int): The number of attention heads. 32 by default.
+ num_kv_attention_heads (int): The number of key/value attention heads.
+ Falls back to num_attention_heads when not set.
+ mlp_ratio (float): The expansion ratio of the MLP hidden size. 4.0 by default.
+ residual_in_fp32 (bool): Whether to keep the residual in fp32. False by default. It cannot be used
+ for now because it requires passing inconsistent data types between pipeline stages,
+ which would need significant modifications to internlm.
+ norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
+ drop_rate (float): The dropout rate of the input hidden state. 0 by default.
+ attn_drop_rate (float): The dropout rate of the attention module. 0 by default.
+ apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
+ is_reward (bool): Whether to build a reward model. False by default.
+ dropout_selective_checkpoint (bool): Selectively checkpoint dropout; it can only be enabled
+ when checkpoint is disabled. True by default.
+ use_scaled_init (bool): Whether to use scaled init. True by default.
+ use_swiglu (bool): Whether to use SwiGLU. True by default.
+ use_flash_attn (bool): Whether to use flash-attn. True by default.
+ embedding_init_std (float): std used to init the embedding weight. 0.02 by default.
+ attn_wqkv_init_std (float): std used to init the attn_wqkv weight. 0.02 by default.
+ attn_other_init_std (float): std used to init the attn_other weight. 0.02 by default.
+ ffn_uplayer_init_std (float): std used to init the w1 and w2 weights in the FFN when using GLU,
+ otherwise the fc1 weight in the FFN. 0.02 by default.
+ ffn_other_init_std (float): std used to init the ffn_other weight. 0.02 by default.
+ out_head_init_std (float): std used to init the output lm_head weight. 0.02 by default.
+ init_type (str): Initialization type. Use uniform or normal. "normal" by default.
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
+ max_position_embeddings (int): The maximum number of position embeddings. 2048 by default.
+ use_dynamic_ntk_rope (bool): Whether to use dynamic NTK RoPE. False by default.
+ """
+ if deepnorm:
+ raise AssertionError("deepnorm will not be supported in future versions. " "Use earlier versions if necessary.")
+
+ cfg = dict(
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_kv_attention_heads=num_kv_attention_heads if num_kv_attention_heads else num_attention_heads,
+ checkpoint=checkpoint,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden,
+ vocab_size=vocab_size,
+ embed_grad_scale=embed_grad_scale,
+ parallel_output=parallel_output,
+ mlp_ratio=mlp_ratio,
+ apply_post_layer_norm=apply_post_layer_norm,
+ no_bias=no_bias,
+ residual_in_fp32=residual_in_fp32,
+ norm_type=norm_type,
+ adapt_hf=adapt_hf,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ layer_norm_epsilon=layer_norm_epsilon,
+ is_reward=is_reward,
+ dropout_selective_checkpoint=dropout_selective_checkpoint,
+ use_scaled_init=use_scaled_init,
+ use_swiglu=use_swiglu,
+ use_flash_attn=use_flash_attn,
+ embedding_init_std=embedding_init_std,
+ attn_wqkv_init_std=attn_wqkv_init_std,
+ attn_other_init_std=attn_other_init_std,
+ ffn_uplayer_init_std=ffn_uplayer_init_std,
+ ffn_other_init_std=ffn_other_init_std,
+ out_head_init_std=out_head_init_std,
+ init_type=init_type,
+ rope_base=rope_base,
+ max_position_embeddings=max_position_embeddings,
+ use_dynamic_ntk_rope=use_dynamic_ntk_rope,
+ )
+
+ return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py
new file mode 100644
index 00000000..9a30eb26
--- /dev/null
+++ b/internlm/utils/utils.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from contextlib import contextmanager
+
+
+@contextmanager
+def read_base():
+ """Context manager to mark the base config.
+
+ The pure Python-style configuration file allows you to use the import
+ syntax. However, it is important to note that you need to import the base
+ configuration file within the context of ``read_base``, and import other
+ dependencies outside of it.
+
+ You can see more usage of Python-style configuration in the `tutorial`_
+
+ .. _tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta # pylint: disable=line-too-long
+ """ # noqa: E501
+ yield
diff --git a/train.py b/train.py
index 490894a9..79e3c2a7 100644
--- a/train.py
+++ b/train.py
@@ -122,7 +122,7 @@ def main(args):
     ckpt_manager = CheckpointManager(
         ckpt_config=gpc.config.ckpt,
-        model=model,
+        model=model.model,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         train_dl=train_dl,
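
Note (not part of the patch): the packed-input contract expected by PackedFlashLlama1D.forward is easier to see with concrete shapes. The sketch below is illustrative only; the token ids and sequence lengths are made up, and only the relationship between input_ids, cu_seqlens, indexes, and max_seqlen matters.

import torch

# Two sequences of lengths 3 and 5, packed back to back into a single batch of size 1.
input_ids = torch.tensor([[11, 12, 13, 21, 22, 23, 24, 25]])       # shape (1, 8)
cu_seqlens = [torch.tensor([[0, 3, 8]], dtype=torch.int32)]        # list of one (1, num_seqs + 1) tensor
indexes = [torch.tensor([0, 1, 2, 0, 1, 2, 3, 4])]                 # position ids restart at each sequence boundary
# forward() squeezes the batch dimension off when cu_seqlens is given, and derives
# max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() == 5 for the attention kernels.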
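
Similarly, a rough sketch of the partition structure that _build_generic_model_1d iterates over, assuming partition_uniform splits 48 layers evenly across 4 pipeline ranks with num_chunks=1; the exact split is up to partition_uniform, so the values here are only indicative.

num_layers = 48
# One entry per pipeline rank; each entry is a list of (start, end) layer ranges, one per chunk.
all_parts = [[(0, 12)], [(12, 24)], [(24, 36)], [(36, 48)]]
for rank, parts in enumerate(all_parts):
    for start, end in parts:
        first, last = start == 0, end == num_layers
        print(f"rank {rank}: layers [{start}, {end}), first={first}, last={last}")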
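
Finally, a hedged sketch of how a training config might pull in the base configs through the new read_base helper, following the mmengine pure-Python config style. The import paths mirror the files added in this patch, but the override at the end is purely an example and assumes the file is loaded by a config system that understands this pattern.

from internlm.utils.utils import read_base

with read_base():
    # Base configs are imported inside the read_base() context ...
    from configs._base_.models.internlm2_7B import *  # noqa: F401,F403
    from configs._base_.default_runtime import *  # noqa: F401,F403

# ... while other dependencies are imported outside of it.
import torch  # noqa: E402,F401

model["dtype"] = "torch.bfloat16"  # example override only; `model` comes from the base config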