forked from NVIDIA/Megatron-LM
feat: add LoRA adapter layer and Mixtral LoRA training
Matvei Pashkovskii committed Feb 3, 2025
1 parent fe353fd · commit 4b5b3fc
Showing 6 changed files with 451 additions and 1 deletion.
@@ -0,0 +1,71 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

import os
import sys
from argparse import ArgumentParser

# Append the directory two levels above this file to sys.path so that
# pretrain_gpt can be imported from the repository root.
sys.path.append(
    os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),
            os.path.pardir,
            os.path.pardir
        )
    )
)

from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel
from megatron.core.transformer.lora_adapter import LoraAdapter
from megatron.training import get_args, pretrain

from pretrain_gpt import train_valid_test_datasets_provider, forward_step, model_provider


def lora_model_provider(pre_process: bool = True, post_process: bool = True) -> GPTModel:
    args = get_args()
    rank = args.lora_rank
    alpha = args.lora_alpha
    assert rank > 0 and alpha > 0, "LoRA rank and alpha have to be greater than zero"
    assert not args.moe_grouped_gemm, "MoE grouped GEMM is not supported"

    model = model_provider(pre_process, post_process)
    common_args = {
        "config": model.config,
        "rank": rank,
        "alpha": alpha,
        "dropout": args.lora_dropout,
    }
    # Wrap the attention projections, the MoE router, the expert MLPs and the
    # output layer; LoraAdapter freezes each wrapped weight internally.
    for layer in model.decoder.layers:
        layer.self_attention.linear_qkv = LoraAdapter(layer.self_attention.linear_qkv, **common_args)
        layer.self_attention.linear_proj = LoraAdapter(layer.self_attention.linear_proj, **common_args)
        layer.mlp.router = LoraAdapter(layer.mlp.router, **common_args)
        for fc in layer.mlp.experts.local_experts:
            fc.linear_fc1 = LoraAdapter(fc.linear_fc1, is_expert=True, **common_args)
            fc.linear_fc2 = LoraAdapter(fc.linear_fc2, is_expert=True, **common_args)
    model.output_layer = LoraAdapter(model.output_layer, **common_args)
    return model


def add_lora_args(parser: ArgumentParser) -> ArgumentParser:
    group = parser.add_argument_group(title='LoRA')
    group.add_argument('--lora-rank', default=16, type=int,
                       help='LoRA rank')
    group.add_argument('--lora-alpha', default=32.0, type=float,
                       help='LoRA alpha')
    group.add_argument('--lora-dropout', default=0.1, type=float,
                       help='LoRA dropout')
    return parser


if __name__ == "__main__":
    train_valid_test_datasets_provider.is_distributed = True
    pretrain(
        train_valid_test_datasets_provider,
        lora_model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        extra_args_provider=add_lora_args,
    )
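As a quick check on the wiring above: LoraAdapter freezes only the weight of each wrapped layer, so besides the lora_a / lora_b matrices the unwrapped parameters (for example the embedding and normalization weights) remain trainable. The helper below is a hypothetical sketch, not part of this commit, and summarize_trainable_params is an illustrative name:

import torch


def summarize_trainable_params(model: torch.nn.Module) -> None:
    # Hypothetical helper (not part of this commit): report what still requires
    # gradients after LoRA wrapping. Expect the lora_a / lora_b weights plus any
    # parameters lora_model_provider never wraps (embeddings, norm weights),
    # since LoraAdapter only freezes the wrapped layer's own weight tensor.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable parameters: {trainable:,} of {total:,} "
          f"({100.0 * trainable / total:.2f}%)")
    for name, param in model.named_parameters():
        if param.requires_grad and "lora_" not in name:
            print("non-LoRA trainable parameter:", name)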
@@ -0,0 +1,138 @@
#!/bin/bash

# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

export CUDA_DEVICE_MAX_CONNECTIONS=1

export NCCL_CHECKS_DISABLE=1

export NVTE_CK_V3_ATOMIC_FP32=0
export NVTE_CK_V3_BF16_CVT=1
export NVTE_CK_V3_SPEC=1
export NVTE_CK_USES_BWD_V3=1

export TE_HIPBLASLT_TUNING_ALGO_COUNT=50
export TE_HIPBLASLT_TUNING_RUN_COUNT=10

export TORCH_NCCL_HIGH_PRIORITY=1

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6000"}
NNODES=${SLURM_NNODES:-"1"}
NODE_RANK=${RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

CHECKPOINT_PATH=$1
TOKENIZER_MODEL=$2
DATA_PATH=$3

DISTRIBUTED_ARGS=(
    --nproc_per_node $GPUS_PER_NODE
    --nnodes $NNODES
    --node_rank $NODE_RANK
    --master_addr $MASTER_ADDR
    --master_port $MASTER_PORT
)

MODEL_ARGS=(
    --use-mcore-models
    --disable-bias-linear
    --seq-length 4096
    --max-position-embeddings 32768
    --num-layers 32
    --hidden-size 4096
    --ffn-hidden-size 14336
    --num-attention-heads 32
    --init-method-std 0.01
    --attention-dropout 0.0
    --hidden-dropout 0.0
    --normalization RMSNorm
    --position-embedding-type rope
    --swiglu
    --untie-embeddings-and-output-weights
    --group-query-attention
    --num-query-groups 8
    --no-masked-softmax-fusion
    --no-position-embedding
    --rotary-base 1000000
    --lora-rank 16
    --lora-alpha 32
)

MOE_ARGS=(
    --num-experts 8
    --moe-router-topk 2
    --moe-router-load-balancing-type aux_loss
    --moe-aux-loss-coeff 1e-2
    --moe-z-loss-coeff 1e-3
    --moe-token-dispatcher-type alltoall
    --moe-pad-expert-input-to-capacity
    --moe-expert-capacity-factor 1.25
    --overlap-param-gather
    --overlap-grad-reduce
)

DATA_ARGS=(
    --data-cache-path ~/data/cache
    --data-path $DATA_PATH
    --tokenizer-model $TOKENIZER_MODEL
    --tokenizer-type Llama2Tokenizer
    # --split 99990,8,2
)

TRAINING_ARGS=(
    --train-iters 5000
    --micro-batch-size 2
    --global-batch-size 64
    --lr 1e-4
    # --lr-decay-iters 320000
    --lr-decay-style cosine
    # --lr-warmup-iters 500
    --min-lr 1.0e-5
    --weight-decay 0.1
    --clip-grad 1.0
    --bf16
    --no-gradient-accumulation-fusion
    --fp8-margin 0
    --fp8-format hybrid
    --fp8-interval 1
    --fp8-amax-history-len 1024
    --fp8-amax-compute-algo max
    --attention-softmax-in-fp32
)

MODEL_PARALLEL_ARGS=(
    --tensor-model-parallel-size 1
    --pipeline-model-parallel-size 1
    --expert-model-parallel-size 8
    --use-distributed-optimizer
    --sequence-parallel
)

LOGGING_ARGS=(
    --eval-interval 1000
    --eval-iters 10
    --log-interval 1
    --log-throughput
    --tensorboard-dir $CHECKPOINT_PATH/tensorboard

    # --save-interval 500
    # --no-save-optim
    # --save $CHECKPOINT_PATH

    --exit-on-missing-checkpoint
    --load $CHECKPOINT_PATH
    --no-load-optim
    --no-load-rng
)

mkdir -p $CHECKPOINT_PATH/logs
torchrun ${DISTRIBUTED_ARGS[@]} examples/mixtral/lora_mixtral.py \
    ${MODEL_ARGS[@]} \
    ${MOE_ARGS[@]} \
    ${DATA_ARGS[@]} \
    ${TRAINING_ARGS[@]} \
    ${MODEL_PARALLEL_ARGS[@]} \
    ${LOGGING_ARGS[@]} |& tee $CHECKPOINT_PATH/logs/output_`date +"%Y%m%d_%H%M"`.log
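For reference, the launcher expects three positional arguments in the order CHECKPOINT_PATH, TOKENIZER_MODEL, DATA_PATH. A hypothetical invocation is sketched below; the script path and all filesystem paths are placeholders, since the commit view does not show where the script lives:

# Hypothetical invocation; the script name and every path below are placeholders.
# Argument order: CHECKPOINT_PATH TOKENIZER_MODEL DATA_PATH
bash examples/mixtral/train_lora_mixtral.sh \
    /path/to/mixtral-mcore-checkpoint \
    /path/to/tokenizer.model \
    /path/to/dataset_prefix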
@@ -0,0 +1,108 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

import logging
import math
from copy import deepcopy
from functools import partial
from typing import Tuple

import torch

from megatron.core.extensions.transformer_engine import (
    TEColumnParallelLinear,
    TELayerNormColumnParallelLinear,
    TELinear,
    TERowParallelLinear,
)
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import MegatronModule


LOGGER = logging.getLogger(__name__)
KAIMING_INIT_METHOD: callable = lambda x: torch.nn.init.kaiming_uniform_(x, a=math.sqrt(5))
LORA_LAYERS_DEFAULT_CONFIG = {
    "bias": False,
    "skip_bias_add": True,
}
COLUMN_PARALLEL_LAYERS = [
    partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, parallel_mode=None, skip_weight_param_allocation=False),
    partial(ColumnParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_),
]
TE_COLUMN_PARALLEL_LAYERS = [
    partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, parallel_mode=None, skip_weight_param_allocation=False),
    partial(TEColumnParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, gather_output=False),
]
LORA_LAYERS_MAPPING = {
    ColumnParallelLinear: COLUMN_PARALLEL_LAYERS,
    TEColumnParallelLinear: TE_COLUMN_PARALLEL_LAYERS,
    TELayerNormColumnParallelLinear: TE_COLUMN_PARALLEL_LAYERS,
    RowParallelLinear: [
        partial(RowParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, input_is_parallel=True),
        partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, parallel_mode=None, skip_weight_param_allocation=False),
    ],
    TERowParallelLinear: [
        partial(TERowParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, input_is_parallel=True),
        partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, parallel_mode=None, skip_weight_param_allocation=False),
    ],
}


class LoraAdapter(MegatronModule):
    def __init__(self, base_layer: torch.nn.Module, *, config: TransformerConfig, rank: int, alpha: float, dropout: float, is_expert: bool = False):
        super(LoraAdapter, self).__init__(config)
        self.lora_alpha = alpha
        self.base_layer = base_layer
        # Freeze the wrapped layer's weight; only the LoRA matrices are trained.
        self.base_layer.weight.requires_grad = False
        self.lora_a = None
        self.lora_b = None
        # These hooks let a vanilla (non-LoRA) checkpoint load into the wrapped
        # module: keys are remapped under "base_layer." and missing lora_* keys
        # are not reported as errors.
        self.register_load_state_dict_pre_hook(self._remap_base_layer_for_training)
        self.register_load_state_dict_post_hook(self._ignore_missing_lora_keys_for_training)

        base_layer_class = type(base_layer)
        if base_layer_class not in LORA_LAYERS_MAPPING:
            if torch.distributed.get_rank() == 0:
                LOGGER.warning(f"LoRA is not supported for {base_layer_class}. Freezing weights of {base_layer_class} but skipping addition of LoRA layers")
            return

        layer_config = {
            "config": config,
            "is_expert": is_expert,
        }
        output_size, input_size = self.base_layer.weight.shape
        if base_layer_class in [RowParallelLinear, TERowParallelLinear]:
            input_size *= config.tensor_model_parallel_size
        if base_layer_class in [ColumnParallelLinear, TEColumnParallelLinear, TELayerNormColumnParallelLinear]:
            output_size *= config.tensor_model_parallel_size
        lora_a_class, lora_b_class = LORA_LAYERS_MAPPING[base_layer_class]
        self.lora_a = lora_a_class(input_size=input_size, output_size=rank, **layer_config)
        self.lora_b = lora_b_class(input_size=rank, output_size=output_size, **layer_config)
        self.lora_dropout = torch.nn.Dropout(p=dropout, inplace=False)

    def _remap_base_layer_for_training(self, _: torch.nn.Module, state_dict: dict, prefix: str, *args) -> None:
        extra_prefix = "base_layer."
        keys = list(state_dict.keys())
        for key in keys:
            # The model is already finetuned with LoRA
            if extra_prefix in key or "lora_" in key:
                continue

            # The model has no adapter layers
            new_key = key.replace(prefix, f"{prefix}{extra_prefix}")
            state_dict[new_key] = state_dict.pop(key)

    def _ignore_missing_lora_keys_for_training(self, _: torch.nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys) -> None:
        keys = deepcopy(incompatible_keys.missing_keys)
        for key in keys:
            if "lora_" in key:
                incompatible_keys.missing_keys.remove(key)

    def forward(self, input: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        output, bias = self.base_layer(input, *args, **kwargs)
        if self.lora_a is None:
            return output, bias

        lora_a_output_parallel, _ = self.lora_a(input)
        lora_b_output_parallel, _ = self.lora_b(lora_a_output_parallel)
        lora_dropout_output_parallel = self.lora_dropout(lora_b_output_parallel)
        return output + self.lora_alpha * lora_dropout_output_parallel, bias
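The forward pass above reduces to base(x) + alpha * dropout(B(A(x))), with lora_a Kaiming-initialized and lora_b zero-initialized, so a freshly wrapped layer reproduces its base layer exactly. The following is a minimal single-process sketch of the same composition using plain torch.nn.Linear; it only illustrates the math and is not a substitute for LoraAdapter, which also handles the tensor/expert-parallel layer classes and checkpoint remapping:

import math
import torch


class TinyLoraLinear(torch.nn.Module):
    """Single-process illustration of the LoraAdapter composition (not the Megatron class)."""

    def __init__(self, base: torch.nn.Linear, rank: int, alpha: float, dropout: float):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad = False       # freeze the pretrained weight
        self.alpha = alpha
        self.lora_a = torch.nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_b.weight)     # zero init => no change at step 0
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same ordering as LoraAdapter.forward: dropout is applied to the B output
        # and the result is scaled by alpha.
        return self.base(x) + self.alpha * self.dropout(self.lora_b(self.lora_a(x)))


if __name__ == "__main__":
    base = torch.nn.Linear(64, 32, bias=False)
    wrapped = TinyLoraLinear(base, rank=4, alpha=32.0, dropout=0.0)
    x = torch.randn(2, 64)
    # With lora_b zero-initialized, the wrapped layer matches the base layer exactly.
    assert torch.allclose(wrapped(x), base(x))

Note that the adapter output is scaled by alpha directly rather than the alpha / rank scaling used in the original LoRA paper, so the default rank 16 / alpha 32 settings weight the adapter path more heavily than the paper's convention would.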