feat: add LoRA adapter layer and Mixtral LoRA training
Matvei Pashkovskii committed Feb 3, 2025
1 parent fe353fd commit 4b5b3fc
Showing 6 changed files with 451 additions and 1 deletion.
16 changes: 16 additions & 0 deletions examples/mixtral/README.md
@@ -126,6 +126,22 @@ docker run \

The above functionality also applies to Mixtral 8x22B; set the model config (including hidden_size/head_num/num_layers/ffn_hidden_size) according to the original [config](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/blob/main/config.json).
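For reference, a sketch of the corresponding 8x22B model arguments is shown below (values as read from the linked config.json; verify against the config before use, and keep the other `MODEL_ARGS` entries as in the 8x7B script):

```bash
# Sketch of Mixtral 8x22B model-size arguments; double-check against
# https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/blob/main/config.json
MODEL_ARGS=(
    --num-layers 56
    --hidden-size 6144
    --ffn-hidden-size 16384
    --num-attention-heads 48
    --group-query-attention
    --num-query-groups 8
    --max-position-embeddings 65536
    --rotary-base 1000000
)
```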

## LoRA from pretrained Mixtral 8x7B

By default `lora_mixtral_8x7b_distributed.sh` uses TP1PP1EP8 parallelism. To run LoRA fine-tuning from a pretrained Mixtral 8x7B checkpoint, use the following command:

```bash
bash lora_mixtral_8x7b_distributed.sh \
/workspace/checkpoints/mixtral-mcore-TP1PP1EP8 \
/workspace/checkpoints/mixtral-hf/tokenizer.model \
path/to/data
```


For example, with local checkpoint, tokenizer, and dataset paths: `bash lora_mixtral_8x7b_distributed.sh ~/models/megatron-lm/mistralai/Mixtral-8x7B-v0.1/ep8 ~/models/models--mistralai--Mixtral-8x7B-v0.1/snapshots/ffe1a706bacbd5abddc5ff99432ee38f7e0662fb/tokenizer.model ~/data/instructions/mistralai/Mixtral-8x7B-v0.1/dataset_input_document`

To run LoRA with a different parallelism setting, convert the model to a Megatron-LM checkpoint with the target parallelism and modify the `MODEL_PARALLEL_ARGS` in `lora_mixtral_8x7b_distributed.sh` accordingly.
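As a hedged sketch of that conversion (the exact loader/saver plugin names and flags depend on your Megatron-LM version, and the paths are placeholders):

```bash
# Re-convert the Hugging Face checkpoint with a different target parallelism,
# e.g. TP2 PP1 EP4; plugin and flag names may differ between Megatron-LM versions.
python tools/checkpoint/convert.py \
    --model-type GPT \
    --loader mixtral_hf \
    --saver mcore \
    --target-tensor-parallel-size 2 \
    --target-pipeline-parallel-size 1 \
    --target-expert-parallel-size 4 \
    --load-dir /workspace/checkpoints/mixtral-hf \
    --save-dir /workspace/checkpoints/mixtral-mcore-TP2PP1EP4 \
    --tokenizer-model /workspace/checkpoints/mixtral-hf/tokenizer.model
```

Then pass the new checkpoint directory as the first argument to `lora_mixtral_8x7b_distributed.sh` and set `--tensor-model-parallel-size`, `--pipeline-model-parallel-size`, and `--expert-model-parallel-size` in `MODEL_PARALLEL_ARGS` to match.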

## Acknowledgements
Contributors outside NVIDIA for the huggingface converter and example of Mixtral models in Megatron-Core:
- Peng Li <jerry.lp@alibaba-inc.com>
71 changes: 71 additions & 0 deletions examples/mixtral/lora_mixtral.py
@@ -0,0 +1,71 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.



import os
import sys
from argparse import ArgumentParser

sys.path.append(
os.path.abspath(
os.path.join(
os.path.dirname(__file__),
os.path.pardir,
os.path.pardir
)
)
)

from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel
from megatron.core.transformer.lora_adapter import LoraAdapter
from megatron.training import get_args, pretrain

from pretrain_gpt import train_valid_test_datasets_provider, forward_step, model_provider


def lora_model_provider(pre_process: bool = True, post_process: bool = True) -> GPTModel:
args = get_args()
rank = args.lora_rank
alpha = args.lora_alpha
    assert rank > 0 and alpha > 0, "LoRA rank and alpha have to be greater than zero"
assert not args.moe_grouped_gemm, "MoE grouped GEMM is not supported"

model = model_provider(pre_process, post_process)
common_args = {
"config": model.config,
"rank": rank,
"alpha": alpha,
"dropout": args.lora_dropout,
}
    # Wrap the attention projections, MoE router, and expert MLPs of each layer with
    # LoRA adapters; LoraAdapter freezes the wrapped base weights and adds trainable low-rank terms.
    for layer in model.decoder.layers:
layer.self_attention.linear_qkv = LoraAdapter(layer.self_attention.linear_qkv, **common_args)
layer.self_attention.linear_proj = LoraAdapter(layer.self_attention.linear_proj, **common_args)
layer.mlp.router = LoraAdapter(layer.mlp.router, **common_args)
for fc in layer.mlp.experts.local_experts:
fc.linear_fc1 = LoraAdapter(fc.linear_fc1, is_expert=True, **common_args)
fc.linear_fc2 = LoraAdapter(fc.linear_fc2, is_expert=True, **common_args)
model.output_layer = LoraAdapter(model.output_layer, **common_args)
return model


def add_lora_args(parser: ArgumentParser) -> ArgumentParser:
group = parser.add_argument_group(title='LoRA')
group.add_argument('--lora-rank', default=16, type=int,
help='LoRA rank')
group.add_argument('--lora-alpha', default=32.0, type=float,
help='LoRA alpha')
group.add_argument('--lora-dropout', default=0.1, type=float,
help='LoRA dropout')
return parser


if __name__ == "__main__":
train_valid_test_datasets_provider.is_distributed = True
pretrain(
train_valid_test_datasets_provider,
lora_model_provider,
ModelType.encoder_or_decoder,
forward_step,
extra_args_provider=add_lora_args,
)
138 changes: 138 additions & 0 deletions examples/mixtral/lora_mixtral_8x7b_distributed.sh
@@ -0,0 +1,138 @@
#!/bin/bash

# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

export CUDA_DEVICE_MAX_CONNECTIONS=1

export NCCL_CHECKS_DISABLE=1

export NVTE_CK_V3_ATOMIC_FP32=0
export NVTE_CK_V3_BF16_CVT=1
export NVTE_CK_V3_SPEC=1
export NVTE_CK_USES_BWD_V3=1

export TE_HIPBLASLT_TUNING_ALGO_COUNT=50
export TE_HIPBLASLT_TUNING_RUN_COUNT=10

export TORCH_NCCL_HIGH_PRIORITY=1

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6000"}
NNODES=${SLURM_NNODES:-"1"}
NODE_RANK=${RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

CHECKPOINT_PATH=$1
TOKENIZER_MODEL=$2
DATA_PATH=$3

DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NNODES
--node_rank $NODE_RANK
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)

MODEL_ARGS=(
--use-mcore-models
--disable-bias-linear
--seq-length 4096
--max-position-embeddings 32768
--num-layers 32
--hidden-size 4096
--ffn-hidden-size 14336
--num-attention-heads 32
--init-method-std 0.01
--attention-dropout 0.0
--hidden-dropout 0.0
--normalization RMSNorm
--position-embedding-type rope
--swiglu
--untie-embeddings-and-output-weights
--group-query-attention
--num-query-groups 8
--no-masked-softmax-fusion
--no-position-embedding
--rotary-base 1000000
--lora-rank 16
--lora-alpha 32
)

MOE_ARGS=(
--num-experts 8
--moe-router-topk 2
--moe-router-load-balancing-type aux_loss
--moe-aux-loss-coeff 1e-2
--moe-z-loss-coeff 1e-3
--moe-token-dispatcher-type alltoall
--moe-pad-expert-input-to-capacity
--moe-expert-capacity-factor 1.25
--overlap-param-gather
--overlap-grad-reduce
)

DATA_ARGS=(
--data-cache-path ~/data/cache
--data-path $DATA_PATH
--tokenizer-model $TOKENIZER_MODEL
--tokenizer-type Llama2Tokenizer
# --split 99990,8,2
)

TRAINING_ARGS=(
--train-iters 5000
--micro-batch-size 2
--global-batch-size 64
--lr 1e-4
# --lr-decay-iters 320000
--lr-decay-style cosine
# --lr-warmup-iters 500
--min-lr 1.0e-5
--weight-decay 0.1
--clip-grad 1.0
--bf16
--no-gradient-accumulation-fusion
--fp8-margin 0
--fp8-format hybrid
--fp8-interval 1
--fp8-amax-history-len 1024
--fp8-amax-compute-algo max
--attention-softmax-in-fp32
)

MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--expert-model-parallel-size 8
--use-distributed-optimizer
--sequence-parallel
)

LOGGING_ARGS=(
--eval-interval 1000
--eval-iters 10
--log-interval 1
--log-throughput
--tensorboard-dir $CHECKPOINT_PATH/tensorboard

# --save-interval 500
# --no-save-optim
# --save $CHECKPOINT_PATH

--exit-on-missing-checkpoint
--load $CHECKPOINT_PATH
    --no-load-optim
--no-load-rng
)

mkdir -p $CHECKPOINT_PATH/logs
torchrun ${DISTRIBUTED_ARGS[@]} examples/mixtral/lora_mixtral.py \
${MODEL_ARGS[@]} \
${MOE_ARGS[@]} \
${DATA_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${LOGGING_ARGS[@]} |& tee $CHECKPOINT_PATH/logs/output_`date +"%Y%m%d_%H%M"`.log
108 changes: 108 additions & 0 deletions megatron/core/transformer/lora_adapter.py
@@ -0,0 +1,108 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

import logging
import math
from copy import deepcopy
from functools import partial
from typing import Tuple

import torch

from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TELayerNormColumnParallelLinear,
TELinear,
TERowParallelLinear,
)
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import MegatronModule


LOGGER = logging.getLogger(__name__)
KAIMING_INIT_METHOD: callable = lambda x: torch.nn.init.kaiming_uniform_(x, a=math.sqrt(5))
LORA_LAYERS_DEFAULT_CONFIG = {
"bias": False,
"skip_bias_add": True,
}
COLUMN_PARALLEL_LAYERS = [
partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, parallel_mode=None, skip_weight_param_allocation=False),
partial(ColumnParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_),
]
TE_COLUMN_PARALLEL_LAYERS = [
partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, parallel_mode=None, skip_weight_param_allocation=False),
partial(TEColumnParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, gather_output=False),
]
LORA_LAYERS_MAPPING = {
ColumnParallelLinear: COLUMN_PARALLEL_LAYERS,
TEColumnParallelLinear: TE_COLUMN_PARALLEL_LAYERS,
TELayerNormColumnParallelLinear: TE_COLUMN_PARALLEL_LAYERS,
RowParallelLinear: [
partial(RowParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, input_is_parallel=True),
partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, parallel_mode=None, skip_weight_param_allocation=False),
],
TERowParallelLinear: [
partial(TERowParallelLinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, input_is_parallel=True),
partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=torch.nn.init.zeros_, parallel_mode=None, skip_weight_param_allocation=False),
],
}


class LoraAdapter(MegatronModule):
def __init__(self, base_layer: torch.nn.Module, *, config: TransformerConfig, rank: int, alpha: float, dropout: float, is_expert: bool = False):
super(LoraAdapter, self).__init__(config)
self.lora_alpha = alpha
self.base_layer = base_layer
self.base_layer.weight.requires_grad = False
self.lora_a = None
self.lora_b = None
self.register_load_state_dict_pre_hook(self._remap_base_layer_for_training)
self.register_load_state_dict_post_hook(self._ignore_missing_lora_keys_for_training)

base_layer_class = type(base_layer)
if base_layer_class not in LORA_LAYERS_MAPPING:
if torch.distributed.get_rank() == 0:
LOGGER.warning(f"LoRA is not supported for {base_layer_class}. Freezing weights of {base_layer_class} but skipping addition of LoRA layers")
return

layer_config = {
"config": config,
"is_expert": is_expert,
}
output_size, input_size = self.base_layer.weight.shape
if base_layer_class in [RowParallelLinear, TERowParallelLinear]:
input_size *= config.tensor_model_parallel_size
if base_layer_class in [ColumnParallelLinear, TEColumnParallelLinear, TELayerNormColumnParallelLinear]:
output_size *= config.tensor_model_parallel_size
lora_a_class, lora_b_class = LORA_LAYERS_MAPPING[base_layer_class]
self.lora_a = lora_a_class(input_size=input_size, output_size=rank, **layer_config)
self.lora_b = lora_b_class(input_size=rank, output_size=output_size, **layer_config)
self.lora_dropout = torch.nn.Dropout(p=dropout, inplace=False)

def _remap_base_layer_for_training(self, _: torch.nn.Module, state_dict: dict, prefix: str, *args) -> None:
extra_prefix = "base_layer."
keys = list(state_dict.keys())
for key in keys:
# The model is already finetuned with LoRA
if extra_prefix in key or "lora_" in key:
continue

# The model has no adapter layers
new_key = key.replace(prefix, f"{prefix}{extra_prefix}")
state_dict[new_key] = state_dict.pop(key)

def _ignore_missing_lora_keys_for_training(self, _: torch.nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys) -> None:
keys = deepcopy(incompatible_keys.missing_keys)
for key in keys:
if "lora_" in key:
incompatible_keys.missing_keys.remove(key)

    def forward(self, input: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        # Frozen base projection.
        output, bias = self.base_layer(input, *args, **kwargs)
        if self.lora_a is None:
            return output, bias

        # Trainable low-rank correction: output + alpha * dropout(B(A(input))).
        lora_a_output_parallel, _ = self.lora_a(input)
        lora_b_output_parallel, _ = self.lora_b(lora_a_output_parallel)
        lora_dropout_output_parallel = self.lora_dropout(lora_b_output_parallel)
        return output + self.lora_alpha * lora_dropout_output_parallel, bias
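
For reference, the adapter as implemented above computes

$$y = W_0 x + \alpha \, \mathrm{Dropout}(B A x),$$

where $W_0$ is the frozen base weight, $A \in \mathbb{R}^{r \times d_\text{in}}$ is Kaiming-uniform initialized, and $B \in \mathbb{R}^{d_\text{out} \times r}$ is zero-initialized, so training starts from the base layer's behavior. Note that the output is scaled by $\alpha$ directly (not $\alpha / r$ as in the original LoRA paper) and that dropout is applied to the low-rank output rather than to the input.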
2 changes: 1 addition & 1 deletion megatron/training/checkpointing.py
@@ -868,7 +868,7 @@ def _load_base_checkpoint(
else:
checkpoint_name = get_checkpoint_name(load_dir, iteration, release, return_base_dir=False)
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
except ModuleNotFoundError:
from megatron.legacy.fp16_deprecated import loss_scaler

