forked from NVIDIA/Megatron-LM
feat: add LoRA adapter layer and Mixtral LoRA training #53
Open: mpashkovskii wants to merge 10 commits into ROCm:rocm_dev from mpashkovskii:feat/mixtral-lora
Diff: +594 −1 (the changes shown below are from 6 of the 10 commits)

Commits (10):
e86131b  feat: add LoRA adapter layer and Mixtral LoRA training
0c59133  fix: enable HuggingFace Mixtral 8x7B model conversion
6fb5b1f  Merge branch 'ROCm:rocm_dev' into feat/mixtral-lora
284b87d  fix: address comments from PR review
bbe8201  fix: wrap embeddings layer into LoraAdapter
3c9c8a4  docs: user checkpoint and tokenizer variables in the example
9e7e6e8  tests: improve initialized weights checks
5c87cbe  feat: implement Megatron-LM Linear layer
953d61c  Merge branch 'ROCm:rocm_dev' into feat/mixtral-lora
c474f0d  fix: improve training parameters
New file (69 lines): the LoRA training entry point for Mixtral, invoked as lora_mixtral.py by the launch script below.

```python
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

import os
import sys
from argparse import ArgumentParser

# Add the grandparent directory of this file to sys.path so that
# pretrain_gpt and the megatron package are importable.
sys.path.append(
    os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),
            os.path.pardir,
            os.path.pardir
        )
    )
)

from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel
from megatron.core.transformer.lora_adapter import LoraAdapter
from megatron.training import get_args, pretrain

from pretrain_gpt import train_valid_test_datasets_provider, forward_step, model_provider


def lora_model_provider(pre_process: bool = True, post_process: bool = True) -> GPTModel:
    args = get_args()
    rank = args.lora_rank
    alpha = args.lora_alpha
    assert rank > 0 and alpha > 0, "LoRA rank and alpha have to be greater than zero"

    model = model_provider(pre_process, post_process)
    common_args = {
        "config": model.config,
        "rank": rank,
        "alpha": alpha,
        "dropout": args.lora_dropout,
    }
    # Wrap the word embeddings, attention projections, MoE router, expert MLPs
    # and output layer with LoRA adapters; LoraAdapter freezes the base weights.
    model.embedding.word_embeddings = LoraAdapter(model.embedding.word_embeddings, **common_args)
    for layer in model.decoder.layers:
        layer.self_attention.linear_qkv = LoraAdapter(layer.self_attention.linear_qkv, **common_args)
        layer.self_attention.linear_proj = LoraAdapter(layer.self_attention.linear_proj, **common_args)
        layer.mlp.router = LoraAdapter(layer.mlp.router, **common_args)
        for fc in layer.mlp.experts.local_experts:
            fc.linear_fc1 = LoraAdapter(fc.linear_fc1, is_expert=True, **common_args)
            fc.linear_fc2 = LoraAdapter(fc.linear_fc2, is_expert=True, **common_args)
    model.output_layer = LoraAdapter(model.output_layer, **common_args)
    return model


def add_lora_args(parser: ArgumentParser) -> ArgumentParser:
    group = parser.add_argument_group(title='LoRA')
    group.add_argument('--lora-rank', default=16, type=int,
                       help='LoRA rank')
    group.add_argument('--lora-alpha', default=32.0, type=float,
                       help='LoRA alpha')
    group.add_argument('--lora-dropout', default=0.1, type=float,
                       help='LoRA dropout')
    return parser


if __name__ == "__main__":
    train_valid_test_datasets_provider.is_distributed = True
    pretrain(
        train_valid_test_datasets_provider,
        lora_model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        extra_args_provider=add_lora_args,
    )
```
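To get a sense of scale for what lora_model_provider adds, take the square linear_proj weight as an example and plug in the Mixtral 8x7B shapes used in the launch script below (hidden size 4096, LoRA rank 16): each wrapped layer gains roughly rank * (input_size + output_size) trainable parameters while its base weight stays frozen. A quick back-of-the-envelope check:

```python
hidden, rank = 4096, 16                   # --hidden-size and --lora-rank from the launch script below
base = hidden * hidden                    # frozen linear_proj weight: 16,777,216 parameters
lora = rank * (hidden + hidden)           # lora_a (rank x in) plus lora_b (out x rank): 131,072 parameters
print(f"trainable fraction for this layer: {lora / base:.2%}")  # ~0.78%
```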
New file (135 lines): a single-node launch script for Mixtral 8x7B LoRA training; it takes the checkpoint path, tokenizer model and data path as positional arguments.

```bash
#!/bin/bash

# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

export CUDA_DEVICE_MAX_CONNECTIONS=1

export NCCL_CHECKS_DISABLE=1

export NVTE_CK_V3_ATOMIC_FP32=0
export NVTE_CK_V3_BF16_CVT=1
export NVTE_CK_V3_SPEC=1
export NVTE_CK_USES_BWD_V3=1

export TE_HIPBLASLT_TUNING_ALGO_COUNT=50
export TE_HIPBLASLT_TUNING_RUN_COUNT=10

export TORCH_NCCL_HIGH_PRIORITY=1

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6000"}
NNODES=${SLURM_NNODES:-"1"}
NODE_RANK=${RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

CHECKPOINT_PATH=$1
TOKENIZER_MODEL=$2
DATA_PATH=$3

DISTRIBUTED_ARGS=(
    --nproc_per_node $GPUS_PER_NODE
    --nnodes $NNODES
    --node_rank $NODE_RANK
    --master_addr $MASTER_ADDR
    --master_port $MASTER_PORT
)

MODEL_ARGS=(
    --use-mcore-models
    --disable-bias-linear
    --seq-length 4096
    --max-position-embeddings 32768
    --num-layers 32
    --hidden-size 4096
    --ffn-hidden-size 14336
    --num-attention-heads 32
    --init-method-std 0.01
    --attention-dropout 0.0
    --hidden-dropout 0.0
    --normalization RMSNorm
    --position-embedding-type rope
    --swiglu
    --untie-embeddings-and-output-weights
    --group-query-attention
    --num-query-groups 8
    --no-masked-softmax-fusion
    --no-position-embedding
    --rotary-base 1000000
    --lora-rank 16
    --lora-alpha 32
)

MOE_ARGS=(
    --num-experts 8
    --moe-router-topk 2
    --moe-router-load-balancing-type aux_loss
    --moe-aux-loss-coeff 1e-2
    --moe-z-loss-coeff 1e-3
    --moe-token-dispatcher-type alltoall
    --moe-pad-expert-input-to-capacity
    --moe-expert-capacity-factor 1.25
    --overlap-param-gather
    --overlap-grad-reduce
)

DATA_ARGS=(
    --data-cache-path ~/data/cache
    --dataloader-type cyclic
    --data-path $DATA_PATH
    --tokenizer-model $TOKENIZER_MODEL
    --tokenizer-type Llama2Tokenizer
)

TRAINING_ARGS=(
    --train-iters 5000
    --micro-batch-size 2
    --global-batch-size 64
    --lr 1e-4
    --lr-decay-style cosine
    --min-lr 1.0e-5
    --weight-decay 0.1
    --clip-grad 1.0
    --bf16
    --no-gradient-accumulation-fusion
    --fp8-margin 0
    --fp8-format hybrid
    --fp8-interval 1
    --fp8-amax-history-len 1024
    --fp8-amax-compute-algo max
    --attention-softmax-in-fp32
)

MODEL_PARALLEL_ARGS=(
    --tensor-model-parallel-size 1
    --pipeline-model-parallel-size 1
    --expert-model-parallel-size 8
    --use-distributed-optimizer
    --sequence-parallel
)

LOGGING_ARGS=(
    --eval-interval 1000
    --eval-iters 10
    --log-interval 1
    --log-throughput
    --tensorboard-dir $CHECKPOINT_PATH/tensorboard
    --ckpt-format torch
    --no-save-optim
    --save $CHECKPOINT_PATH
    --save-interval 500
    --exit-on-missing-checkpoint
    --load $CHECKPOINT_PATH
    --no-load-optim
    --no-load-rng
)

mkdir -p $CHECKPOINT_PATH/logs
torchrun ${DISTRIBUTED_ARGS[@]} lora_mixtral.py \
    ${MODEL_ARGS[@]} \
    ${MOE_ARGS[@]} \
    ${DATA_ARGS[@]} \
    ${TRAINING_ARGS[@]} \
    ${MODEL_PARALLEL_ARGS[@]} \
    ${LOGGING_ARGS[@]} |& tee $CHECKPOINT_PATH/logs/output_`date +"%Y%m%d_%H%M"`.log
```
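As a sanity check on the batch settings above (assuming Megatron's usual data-parallel size of world_size / (TP * PP), with expert parallelism laid over the data-parallel ranks): a single 8-GPU node with TP=1 and PP=1 gives a data-parallel size of 8, so each optimizer step accumulates 64 / (2 * 8) = 4 micro-batches per rank.

```python
gpus, tp, pp = 8, 1, 1              # GPUS_PER_NODE and MODEL_PARALLEL_ARGS above
micro, global_batch = 2, 64         # --micro-batch-size and --global-batch-size
dp = gpus // (tp * pp)              # data-parallel size under the assumption stated above
accumulation_steps = global_batch // (micro * dp)
assert (dp, accumulation_steps) == (8, 4)
```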
New file (116 lines): megatron/core/transformer/lora_adapter.py, defining the LoraAdapter module imported by lora_mixtral.py above.

```python
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

import logging
import math
from copy import deepcopy
from functools import partial
from typing import Tuple, Union

import torch

from megatron.core.extensions.transformer_engine import (
    TEColumnParallelLinear,
    TELayerNormColumnParallelLinear,
    TELinear,
    TERowParallelLinear,
)
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import MegatronModule


LOGGER = logging.getLogger(__name__)
KAIMING_INIT_METHOD: callable = lambda x: torch.nn.init.kaiming_uniform_(x, a=math.sqrt(5))
LORA_LAYERS_DEFAULT_CONFIG = {
    "bias": False,
    "skip_bias_add": True,
}
# For each supported base layer type, the pair is (lora_a class, lora_b class):
# lora_a is initialized with Kaiming-uniform weights, lora_b with zeros.
COLUMN_PARALLEL_LAYERS = [
    partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, parallel_mode=None, skip_weight_param_allocation=False),
    partial(ColumnParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_),
]
TE_COLUMN_PARALLEL_LAYERS = [
    partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, parallel_mode=None, skip_weight_param_allocation=False),
    partial(TEColumnParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, gather_output=False),
]
ROW_PARALLEL_LAYERS = [
    partial(RowParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, input_is_parallel=True),
    partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, parallel_mode=None, skip_weight_param_allocation=False),
]
TE_ROW_PARALLEL_LAYERS = [
    partial(TERowParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, input_is_parallel=True),
    partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, parallel_mode=None, skip_weight_param_allocation=False),
]
LORA_LAYERS_MAPPING = {
    ColumnParallelLinear: COLUMN_PARALLEL_LAYERS,
    TEColumnParallelLinear: TE_COLUMN_PARALLEL_LAYERS,
    TELayerNormColumnParallelLinear: TE_COLUMN_PARALLEL_LAYERS,
    RowParallelLinear: ROW_PARALLEL_LAYERS,
    TERowParallelLinear: TE_ROW_PARALLEL_LAYERS,
}


class LoraAdapter(MegatronModule):
    def __init__(self, base_layer: torch.nn.Module, *, config: TransformerConfig, rank: int, alpha: float, dropout: float, is_expert: bool = False):
        super(LoraAdapter, self).__init__(config)
        self.lora_alpha = alpha
        self.base_layer = base_layer
        self.base_layer.weight.requires_grad = False  # freeze the pretrained weight

        self.lora_a = None
        self.lora_b = None
        self.register_load_state_dict_pre_hook(self._remap_base_layer_for_training)
        self.register_load_state_dict_post_hook(self._ignore_missing_lora_keys_for_training)

        base_layer_class = type(base_layer)
        if base_layer_class not in LORA_LAYERS_MAPPING:
            if torch.distributed.get_rank() == 0:
                LOGGER.warning(f"LoRA is not supported for {base_layer_class}. Freezing weights of {base_layer_class} but skipping addition of LoRA layers")
            return

        layer_config = {
            "config": config,
            "is_expert": is_expert,
        }
        output_size, input_size = self.base_layer.weight.shape
        if base_layer_class in [RowParallelLinear, TERowParallelLinear]:
            input_size *= config.tensor_model_parallel_size
        if base_layer_class in [ColumnParallelLinear, TEColumnParallelLinear, TELayerNormColumnParallelLinear]:
            output_size *= config.tensor_model_parallel_size
        lora_a_class, lora_b_class = LORA_LAYERS_MAPPING[base_layer_class]
        self.lora_a = lora_a_class(input_size=input_size, output_size=rank, **layer_config)
        self.lora_b = lora_b_class(input_size=rank, output_size=output_size, **layer_config)
        self.lora_dropout = torch.nn.Dropout(p=dropout, inplace=False)

    def _remap_base_layer_for_training(self, _: torch.nn.Module, state_dict: dict, prefix: str, *args) -> None:
        extra_prefix = "base_layer."
        keys = list(state_dict.keys())
        for key in keys:
            # The model is already finetuned with LoRA
            if extra_prefix in key or "lora_" in key:
                continue

            # The model has no adapter layers
            new_key = key.replace(prefix, f"{prefix}{extra_prefix}")
            state_dict[new_key] = state_dict.pop(key)

    def _ignore_missing_lora_keys_for_training(self, _: torch.nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys) -> None:
        keys = deepcopy(incompatible_keys.missing_keys)
        for key in keys:
            if "lora_" in key:
                incompatible_keys.missing_keys.remove(key)

    def forward(self, input: torch.Tensor, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        output = self.base_layer(input, *args, **kwargs)
        if self.lora_a is None:
            return output

        lora_a_output_parallel, _ = self.lora_a(input)
        lora_b_output_parallel, _ = self.lora_b(lora_a_output_parallel)
        lora_dropout_output_parallel = self.lora_dropout(lora_b_output_parallel)
        lora_output_parallel = self.lora_alpha * lora_dropout_output_parallel

        if type(output) is torch.Tensor:
            return output + lora_output_parallel

        output, bias = output
        return output + lora_output_parallel, bias
```
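For reference, the update computed in forward() above is output = base_layer(input) + lora_alpha * dropout(lora_b(lora_a(input))), with lora_a initialized Kaiming-uniform and lora_b initialized to zeros, so the adapter contributes nothing at the start of training. A minimal single-GPU sketch of the same rule in plain PyTorch (a hypothetical helper, independent of Megatron's parallel layers):

```python
import torch

class TinyLoraLinear(torch.nn.Module):
    """Reference-only sketch of the LoRA update used above; not the Megatron implementation."""
    def __init__(self, base: torch.nn.Linear, rank: int, alpha: float, dropout: float):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad = False           # freeze the pretrained weight
        self.lora_a = torch.nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.kaiming_uniform_(self.lora_a.weight, a=5 ** 0.5)
        torch.nn.init.zeros_(self.lora_b.weight)          # adapter starts as a no-op
        self.dropout = torch.nn.Dropout(dropout)
        self.alpha = alpha

    def forward(self, x):
        return self.base(x) + self.alpha * self.dropout(self.lora_b(self.lora_a(x)))

x = torch.randn(4, 32)
layer = TinyLoraLinear(torch.nn.Linear(32, 64), rank=8, alpha=32.0, dropout=0.0)
assert torch.allclose(layer(x), layer.base(x))  # zero-initialized lora_b: identical output at step 0
```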
Review discussion:

wenchenvincent (reviewer): Do you mean Linear instead of TELinear here?

mpashkovskii (author): te.pytorch.Linear and especially torch.nn.Linear have quite different constructor signatures, whereas the TELinear constructor is aligned with ColumnParallelLinear, TEColumnParallelLinear, etc. and encapsulates those differences. To make the code more readable I explicitly used TELinear. But essentially, in this case, it is a thin wrapper around torch.nn.Linear.
wenchenvincent (reviewer): TELinear is not a wrapper around torch.nn.Linear but around te.pytorch.Linear. In Megatron-LM there are two alternative transformer implementations: local (using PyTorch layers) and transformer-engine (using TE layers). ColumnParallelLinear uses PyTorch linear layers and TEColumnParallelLinear uses TE linear layers. Usually, when a model is constructed with ColumnParallelLinear, it often means that TE is not available, so it is not appropriate to use TELinear here. Given that we cannot use torch.nn.Linear directly, it seems that we will also need to create a thin wrapper around torch.nn.Linear. And this also triggers another question from me: can we use the wrapper ColumnParallelLinear again for the second LoRA layer here?
wenchenvincent (reviewer): Hmm, I actually have a further question about whether this would work for TP or not. It seems that for a base layer like TEColumnParallelLinear, we use two LoRA layers: the first layer is TELinear and the second layer is TEColumnParallelLinear. Does that mean the first layer will not be sliced across different GPUs?
mpashkovskii (author): Yes, indeed, I made a typo in the last sentence: TELinear wraps te.pytorch.Linear. I implemented the Linear layer as a wrapper around ColumnParallelLinear; the main difference is that the weight output size must be non-sharded, and to achieve this I copied some code from the ColumnParallelLinear constructor. Yes, your understanding is correct: all Linear/TELinear layers in the LoraAdapter are not sliced. This is a deliberate decision.
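The non-sharded Linear wrapper described here can be pictured roughly as follows (a hypothetical sketch with made-up class and argument names; the actual layer is added in commit 5c87cbe and does not appear in the files shown above):

```python
import torch

class NonShardedLinear(torch.nn.Module):
    # Hypothetical sketch: same call/return convention as Megatron's parallel linear layers,
    # but the weight keeps the full (output_size, input_size) shape on every rank instead of
    # being split across the tensor-parallel group.
    def __init__(self, input_size, output_size, *, config, init_method,
                 bias=False, skip_bias_add=True, is_expert=False):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(output_size, input_size))
        init_method(self.weight)

    def forward(self, input_):
        # Return (output, bias) like ColumnParallelLinear; LoRA layers in this PR use bias=False.
        return torch.nn.functional.linear(input_, self.weight), None
```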
wenchenvincent (reviewer): Thanks for the clarification! While I think it might be okay to sacrifice some memory to gain performance, I am somewhat concerned about the functionality. In the case of TP, the weights of the first LoRA layer are not sliced while the weights of the second LoRA layer are sliced across the GPUs within the same TP group, and the input data is the same across the GPUs within that TP group. When we pass the activation of the first LoRA layer to the second LoRA layer, how do we make sure it is sliced properly? And how do we make sure that gradient reduction and accumulation are done properly in the backward pass? In the scheme of this PR, we are doing DP for the first LoRA layer and TP for the second LoRA layer within a TP group. The combination of these two might be error prone, so we will need tests to make sure this is implemented correctly.
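One way to exercise the forward-pass half of this concern is a small single-process simulation of the sharding scheme (a sketch with made-up shapes, taking the comment's premise that every TP rank sees the same input; it does not cover the gradient-reduction question, which would need a real multi-GPU test):

```python
import torch

torch.manual_seed(0)
tp, hidden, rank, out = 2, 8, 4, 16
x = torch.randn(3, hidden)                 # identical input on every TP rank (premise in the comment above)
A = torch.randn(rank, hidden)              # lora_a weight: replicated, never sliced
B = torch.randn(out, rank)                 # lora_b weight: column-parallel, sliced over the output dim

reference = x @ A.t() @ B.t()              # what a TP=1 run would compute (alpha and dropout omitted)

# Simulated TP forward: each rank sees the full lora_a activation and holds one slice of lora_b,
# so each rank produces the matching slice of the LoRA output, mirroring the sliced base-layer output.
shards = [x @ A.t() @ B_shard.t() for B_shard in B.chunk(tp, dim=0)]
assert torch.allclose(reference, torch.cat(shards, dim=-1))
```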