feat: add LoRA adapter layer and Mixtral LoRA training #53

Open · wants to merge 10 commits into base: rocm_dev
13 changes: 13 additions & 0 deletions examples/mixtral/README.md
@@ -126,6 +126,19 @@ docker run \

The above functionality also applies to Mixtral 8x22B; set the model config (including hidden_size/head_num/num_layers/ffn_hidden_size) according to the original [config](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/blob/main/config.json).

## LoRA from pretrained Mixtral 8x7B

By default `lora_mixtral_8x7b_distributed.sh` uses TP1PP1EP8 parallelism. To run LoRA fine-tuning from a pretrained Mixtral 8x7B checkpoint, use the following command:

```bash
bash lora_mixtral_8x7b_distributed.sh \
$CHECKPOINT_PATH \
$TOKENIZER_MODEL \
path/to/data
```

To run LoRA with a different parallelism setting, convert the model to a Megatron-LM checkpoint with the target parallelism and modify the `MODEL_PARALLEL_ARGS` arguments in `lora_mixtral_8x7b_distributed.sh`.
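
For example, a single 8-GPU node could instead be split as PP2/EP4. The block below is only an illustration of that change, assuming the checkpoint loaded from `$CHECKPOINT_PATH` has already been converted to the same layout:

```bash
# Illustrative only: TP=1, PP=2, EP=4 (1*2*4 = 8 GPUs per node). The checkpoint
# loaded from $CHECKPOINT_PATH must already be converted to this parallelism.
MODEL_PARALLEL_ARGS=(
    --tensor-model-parallel-size 1
    --pipeline-model-parallel-size 2
    --expert-model-parallel-size 4
    --use-distributed-optimizer
    --sequence-parallel
)
```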

## Acknowledgements
Contributors outside NVIDIA for the huggingface converter and example of Mixtral models in Megatron-Core:
- Peng Li <jerry.lp@alibaba-inc.com>
69 changes: 69 additions & 0 deletions examples/mixtral/lora_mixtral.py
@@ -0,0 +1,69 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

import os
import sys
from argparse import ArgumentParser

sys.path.append(
os.path.abspath(
os.path.join(
os.path.dirname(__file__),
os.path.pardir,
os.path.pardir
)
)
)

from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel
from megatron.core.transformer.lora_adapter import LoraAdapter
from megatron.training import get_args, pretrain

from pretrain_gpt import train_valid_test_datasets_provider, forward_step, model_provider


def lora_model_provider(pre_process: bool = True, post_process: bool = True) -> GPTModel:
args = get_args()
rank = args.lora_rank
alpha = args.lora_alpha
assert rank > 0 and alpha > 0, "LoRA rank and alpha have to be greater than zero"

model = model_provider(pre_process, post_process)
common_args = {
"config": model.config,
"rank": rank,
"alpha": alpha,
"dropout": args.lora_dropout,
}
model.embedding.word_embeddings = LoraAdapter(model.embedding.word_embeddings, **common_args)
for layer in model.decoder.layers:
layer.self_attention.linear_qkv = LoraAdapter(layer.self_attention.linear_qkv, **common_args)
layer.self_attention.linear_proj = LoraAdapter(layer.self_attention.linear_proj, **common_args)
layer.mlp.router = LoraAdapter(layer.mlp.router, **common_args)
for fc in layer.mlp.experts.local_experts:
fc.linear_fc1 = LoraAdapter(fc.linear_fc1, is_expert=True, **common_args)
fc.linear_fc2 = LoraAdapter(fc.linear_fc2, is_expert=True, **common_args)
model.output_layer = LoraAdapter(model.output_layer, **common_args)
return model


def add_lora_args(parser: ArgumentParser) -> ArgumentParser:
group = parser.add_argument_group(title='LoRA')
group.add_argument('--lora-rank', default=16, type=int,
help='LoRA rank')
group.add_argument('--lora-alpha', default=32.0, type=float,
help='LoRA alpha')
group.add_argument('--lora-dropout', default=0.1, type=float,
help='LoRA dropout')
return parser


if __name__ == "__main__":
train_valid_test_datasets_provider.is_distributed = True
pretrain(
train_valid_test_datasets_provider,
lora_model_provider,
ModelType.encoder_or_decoder,
forward_step,
extra_args_provider=add_lora_args,
)
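
`LoraAdapter` (see `megatron/core/transformer/lora_adapter.py` below) freezes each wrapped layer's `weight`, so after `lora_model_provider` builds the model only the `lora_a`/`lora_b` weights and the modules the adapter does not wrap (e.g. the norm layers) should still require gradients. A minimal sketch of a sanity check one could run right after the provider returns; `count_trainable` is a hypothetical helper, not part of this PR:

```python
import torch

# Hypothetical helper (not part of this PR): report how many parameters remain
# trainable after the model has been wrapped with LoraAdapter layers.
def count_trainable(model: torch.nn.Module) -> None:
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable params: {trainable:,} / {total:,} "
          f"({100.0 * trainable / total:.2f}%)")

# Example: count_trainable(lora_model_provider())
```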
135 changes: 135 additions & 0 deletions examples/mixtral/lora_mixtral_8x7b_distributed.sh
@@ -0,0 +1,135 @@
#!/bin/bash

# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

export CUDA_DEVICE_MAX_CONNECTIONS=1

export NCCL_CHECKS_DISABLE=1

export NVTE_CK_V3_ATOMIC_FP32=0
export NVTE_CK_V3_BF16_CVT=1
export NVTE_CK_V3_SPEC=1
export NVTE_CK_USES_BWD_V3=1

export TE_HIPBLASLT_TUNING_ALGO_COUNT=50
export TE_HIPBLASLT_TUNING_RUN_COUNT=10

export TORCH_NCCL_HIGH_PRIORITY=1

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6000"}
NNODES=${SLURM_NNODES:-"1"}
NODE_RANK=${RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

CHECKPOINT_PATH=$1
TOKENIZER_MODEL=$2
DATA_PATH=$3

DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NNODES
--node_rank $NODE_RANK
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)

MODEL_ARGS=(
--use-mcore-models
--disable-bias-linear
--seq-length 4096
--max-position-embeddings 32768
--num-layers 32
--hidden-size 4096
--ffn-hidden-size 14336
--num-attention-heads 32
--init-method-std 0.01
--attention-dropout 0.0
--hidden-dropout 0.0
--normalization RMSNorm
--position-embedding-type rope
--swiglu
--untie-embeddings-and-output-weights
--group-query-attention
--num-query-groups 8
--no-masked-softmax-fusion
--no-position-embedding
--rotary-base 1000000
--lora-rank 16
--lora-alpha 32
)

MOE_ARGS=(
--num-experts 8
--moe-router-topk 2
--moe-router-load-balancing-type aux_loss
--moe-aux-loss-coeff 1e-2
--moe-z-loss-coeff 1e-3
--moe-token-dispatcher-type alltoall
--moe-pad-expert-input-to-capacity
--moe-expert-capacity-factor 1.25
--overlap-param-gather
--overlap-grad-reduce
)

DATA_ARGS=(
--data-cache-path ~/data/cache
--dataloader-type cyclic
--data-path $DATA_PATH
--tokenizer-model $TOKENIZER_MODEL
--tokenizer-type Llama2Tokenizer
)

TRAINING_ARGS=(
--train-iters 5000
--micro-batch-size 2
--global-batch-size 64
--lr 1e-4
--lr-decay-style cosine
--min-lr 1.0e-5
--weight-decay 0.1
--clip-grad 1.0
--bf16
--no-gradient-accumulation-fusion
--fp8-margin 0
--fp8-format hybrid
--fp8-interval 1
--fp8-amax-history-len 1024
--fp8-amax-compute-algo max
--attention-softmax-in-fp32
)

MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--expert-model-parallel-size 8
--use-distributed-optimizer
--sequence-parallel
)

LOGGING_ARGS=(
--eval-interval 1000
--eval-iters 10
--log-interval 1
--log-throughput
--tensorboard-dir $CHECKPOINT_PATH/tensorboard
--ckpt-format torch
--no-save-optim
--save $CHECKPOINT_PATH
--save-interval 500
--exit-on-missing-checkpoint
--load $CHECKPOINT_PATH
--no-load-optim
--no-load-rng
)

mkdir -p $CHECKPOINT_PATH/logs
torchrun ${DISTRIBUTED_ARGS[@]} lora_mixtral.py \
${MODEL_ARGS[@]} \
${MOE_ARGS[@]} \
${DATA_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${LOGGING_ARGS[@]} |& tee $CHECKPOINT_PATH/logs/output_`date +"%Y%m%d_%H%M"`.log
116 changes: 116 additions & 0 deletions megatron/core/transformer/lora_adapter.py
@@ -0,0 +1,116 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

import logging
import math
from copy import deepcopy
from functools import partial
from typing import Tuple, Union

import torch

from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TELayerNormColumnParallelLinear,
TELinear,
TERowParallelLinear,
)
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import MegatronModule


LOGGER = logging.getLogger(__name__)
KAIMING_INIT_METHOD: callable = lambda x: torch.nn.init.kaiming_uniform_(x, a=math.sqrt(5))
LORA_LAYERS_DEFAULT_CONFIG = {
"bias": False,
"skip_bias_add": True,
}
COLUMN_PARALLEL_LAYERS = [
partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, parallel_mode=None, skip_weight_param_allocation=False),
> **Collaborator:** Do you mean `Linear` instead of `TELinear` here?
>
> **Author:** `te.pytorch.Linear` and especially `torch.nn.Linear` have quite different constructor signatures, whereas the `TELinear` constructor is aligned with `ColumnParallelLinear`, `TEColumnParallelLinear`, etc. and encapsulates those differences. To make the code more readable I explicitly used `TELinear`. But essentially, in this case, it is a thin wrapper around `torch.nn.Linear`.
>
> **Collaborator:** `TELinear` is not a wrapper around `torch.nn.Linear` but `te.pytorch.Linear`.
>
> In Megatron-LM, there are two alternative transformer implementations: local (using PyTorch layers) and transformer-engine (using TE layers). `ColumnParallelLinear` uses PyTorch linear layers and `TEColumnParallelLinear` uses TE linear layers. Usually, when a model is constructed with `ColumnParallelLinear`, it means that TE is not available, so it is not appropriate to use `TELinear` here.
>
> Given that we cannot use `torch.nn.Linear` directly, it seems we will also need to create a thin wrapper around `torch.nn.Linear`. And this triggers another question from me: can we use the wrapper `ColumnParallelLinear` again for the second LoRA layer here?
>
> **Collaborator:** Hmm, I actually have a further question: would this work for TP or not?
>
> It seems that for a base layer like `TEColumnParallelLinear`, we use two LoRA layers: the first is `TELinear` and the second is `TEColumnParallelLinear`. Does that mean the first layer will not be sliced across different GPUs?
>
> **Author:** Yes, indeed, I made a typo in the last sentence: `TELinear` wraps `te.pytorch.Linear`.
>
> I implemented the `Linear` layer as a wrapper around `ColumnParallelLinear`. The main difference is that the weight output size must be non-sharded; to achieve this, I copied some code from the `ColumnParallelLinear` constructor.
>
> Yes, your understanding is correct: all `Linear`/`TELinear` layers in the `LoraAdapter` are not sliced. This is a deliberate decision:
>
> - For TP, we sacrifice some memory to gain performance. A different approach would introduce approximately five additional inter-GPU calls per `LoraAdapter`.
> - For EP+PP, which, as we observed, is the most performant training configuration for MoE models, no layers in the model are sliced.
>
> **Collaborator:** Thanks for the clarification! While it might be okay to sacrifice some memory to gain performance, I am concerned about the functionality.
>
> In the case of TP, the weights of the first LoRA layer are not sliced and the weights of the second LoRA layer are sliced across the GPUs of the same TP group, while the input data is identical across those GPUs. When we pass the activation of the first LoRA layer to the second LoRA layer, how do we make sure it is sliced properly? And how do we make sure that gradient reduction and accumulation are done correctly in the backward pass? In the scheme of this PR, we are doing DP for the first LoRA layer and TP for the second LoRA layer within a TP group. The combination of the two might be error prone; we will need tests to make sure this is implemented correctly.
partial(ColumnParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_),
]
TE_COLUMN_PARALLEL_LAYERS = [
partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, parallel_mode=None, skip_weight_param_allocation=False),
partial(TEColumnParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, gather_output=False),
]
ROW_PARALLEL_LAYERS = [
partial(RowParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, input_is_parallel=True),
partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, parallel_mode=None, skip_weight_param_allocation=False),
]
TE_ROW_PARALLEL_LAYERS = [
partial(TERowParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, input_is_parallel=True),
partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, parallel_mode=None, skip_weight_param_allocation=False),
]
LORA_LAYERS_MAPPING = {
ColumnParallelLinear: COLUMN_PARALLEL_LAYERS,
TEColumnParallelLinear: TE_COLUMN_PARALLEL_LAYERS,
TELayerNormColumnParallelLinear: TE_COLUMN_PARALLEL_LAYERS,
RowParallelLinear: ROW_PARALLEL_LAYERS,
TERowParallelLinear: TE_ROW_PARALLEL_LAYERS,
}


class LoraAdapter(MegatronModule):
def __init__(self, base_layer: torch.nn.Module, *, config: TransformerConfig, rank: int, alpha: float, dropout: float, is_expert: bool = False):
super(LoraAdapter, self).__init__(config)
self.lora_alpha = alpha
self.base_layer = base_layer
self.base_layer.weight.requires_grad = False
self.lora_a = None
self.lora_b = None
self.register_load_state_dict_pre_hook(self._remap_base_layer_for_training)
self.register_load_state_dict_post_hook(self._ignore_missing_lora_keys_for_training)

base_layer_class = type(base_layer)
if base_layer_class not in LORA_LAYERS_MAPPING:
if torch.distributed.get_rank() == 0:
LOGGER.warning(f"LoRA is not supported for {base_layer_class}. Freezing weights of {base_layer_class} but skipping addition of LoRA layers")
return

layer_config = {
"config": config,
"is_expert": is_expert,
}
output_size, input_size = self.base_layer.weight.shape
if base_layer_class in [RowParallelLinear, TERowParallelLinear]:
input_size *= config.tensor_model_parallel_size
if base_layer_class in [ColumnParallelLinear, TEColumnParallelLinear, TELayerNormColumnParallelLinear]:
output_size *= config.tensor_model_parallel_size
lora_a_class, lora_b_class = LORA_LAYERS_MAPPING[base_layer_class]
self.lora_a = lora_a_class(input_size=input_size, output_size=rank, **layer_config)
self.lora_b = lora_b_class(input_size=rank, output_size=output_size, **layer_config)
self.lora_dropout = torch.nn.Dropout(p=dropout, inplace=False)

def _remap_base_layer_for_training(self, _: torch.nn.Module, state_dict: dict, prefix: str, *args) -> None:
extra_prefix = "base_layer."
keys = list(state_dict.keys())
for key in keys:
# The model is already finetuned with LoRA
if extra_prefix in key or "lora_" in key:
continue

# The model has no adapter layers
new_key = key.replace(prefix, f"{prefix}{extra_prefix}")
state_dict[new_key] = state_dict.pop(key)

def _ignore_missing_lora_keys_for_training(self, _: torch.nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys) -> None:
keys = deepcopy(incompatible_keys.missing_keys)
for key in keys:
if "lora_" in key:
incompatible_keys.missing_keys.remove(key)

def forward(self, input: torch.Tensor, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
output = self.base_layer(input, *args, **kwargs)
if self.lora_a is None:
return output

lora_a_output_parallel, _ = self.lora_a(input)
lora_b_output_parallel, _ = self.lora_b(lora_a_output_parallel)
lora_dropout_output_parallel = self.lora_dropout(lora_b_output_parallel)
lora_output_parallel = self.lora_alpha * lora_dropout_output_parallel

if type(output) is torch.Tensor:
return output + lora_output_parallel

output, bias = output
return output + lora_output_parallel, bias
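
To make the shape reasoning from the review thread above concrete: for a column-parallel base layer, `lora_a` is replicated (not sliced) while `lora_b` is sharded over the same output dimension as the base layer, so the two terms in `forward` can be summed shard by shard. A minimal sketch with plain `torch.nn.Linear` stand-ins (not the Megatron/TE classes), using the 8x7B sizes from the script above as example values:

```python
import torch

# Plain-torch sketch of LoraAdapter.forward on one TP rank of a
# column-parallel base layer: y = base(x) + alpha * dropout(B(A(x))).
tp, rank, alpha = 2, 16, 32.0
in_features, out_features = 4096, 14336      # hidden_size / ffn_hidden_size
out_per_rank = out_features // tp            # this rank's shard of the output

base = torch.nn.Linear(in_features, out_per_rank, bias=False)   # frozen shard
lora_a = torch.nn.Linear(in_features, rank, bias=False)         # replicated on every rank
lora_b = torch.nn.Linear(rank, out_per_rank, bias=False)        # sharded like the base layer
dropout = torch.nn.Dropout(p=0.1)

x = torch.randn(4, in_features)              # identical input on every TP rank
y = base(x) + alpha * dropout(lora_b(lora_a(x)))
print(y.shape)                               # torch.Size([4, 7168])
```

Under TP, each rank computes the full low-rank projection `A(x)` redundantly, which is the memory-for-performance trade-off described in the thread; with the EP+PP layout used by the default script, nothing is sliced and the sketch degenerates to `tp = 1`.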
2 changes: 1 addition & 1 deletion megatron/training/checkpointing.py
@@ -868,7 +868,7 @@ def _load_base_checkpoint(
else:
checkpoint_name = get_checkpoint_name(load_dir, iteration, release, return_base_dir=False)
try:
-            state_dict = torch.load(checkpoint_name, map_location='cpu')
+            state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
except ModuleNotFoundError:
from megatron.legacy.fp16_deprecated import loss_scaler
