Add LoRA and Prefix-Tuning as Modeling Options for Improved Memory Efficiency + performance (potentially) (#2840)

This PR adds LoRA and prefix-tuning as modelling options (training and
sampling code).

Both have shown strong performance and can outperform full fine-tuning. They can
also protect against catastrophic forgetting, which is important for chatbots.
Because the whole base language model stays frozen, the learned adapters can be
distributed freely and independently of the base model.

They also allow much more memory-efficient training, since no optimizer states
are needed for the frozen base model.
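
For illustration, here is a minimal sketch (not code from this PR; the base model name is a placeholder) of wrapping a frozen model with the same LoRA settings used in this PR, showing how small the trainable fraction is:

```python
# Hedged sketch (assumed setup, not from this PR): wrap a frozen causal LM with a
# LoRA adapter and report how few parameters the optimizer actually has to track.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # placeholder base model
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # typically well under 1% of all parameters
```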

Benefits of LoRA:
- The 30B model can be run with only DeepSpeed ZeRO stage 2.
- The 65B model can be run without heavy CPU offloading.
- Less overfitting; performance is maintained across datasets.
- Only the adapter component needs to be shared/pushed to the Hub (a small file of a few MB), as illustrated by the loading sketch below.

- See the [Andrej Karpathy (OpenAI)](https://twitter.com/karpathy/status/1649127655122550784?s=20) comment.
- See the purported [Google leak](https://www.semianalysis.com/i/119223672/lora-is-an-incredibly-powerful-technique-we-should-probably-be-paying-more-attention-to).
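
Because the base weights never change, sharing a fine-tune amounts to publishing only the adapter. A hedged sketch of loading such a shared adapter with PEFT (the repository names are placeholders):

```python
# Hedged sketch (placeholder repo names): the full base model is loaded once and the
# small adapter is fetched separately, so only the adapter has to be pushed to the Hub.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("huggyllama/llama-13b")  # full base weights
model = PeftModel.from_pretrained(base, "your-org/llama-13b-lora-adapter")  # few-MB adapter
```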


Implementation Details:
- Explicitly pass input_ids as a keyword argument in the sampling code so that it
works correctly with PEFT's `generate`.
- Manually enable gradients on the inputs so that gradient checkpointing works;
with frozen embeddings, no gradients would otherwise attach to the checkpointed
activations.
- Include code that saves the special-token embeddings from `pytorch_model.bin`.
Although these tokens are randomly initialized, they must be stored and shipped
as an extra module alongside the adapter, since the PEFT parameters learn to use
them. Making them trainable is an option, but it is unlikely to make a
significant difference.
- Integrate prefix-tuning, a technique that stacks learned keys and values in
front of the computed ones, by incorporating custom LLaMA modeling code (see the
sketch after this list). Training speed was lower than LoRA in my initial tests
and performance was similar or worse.
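
The custom LLaMA modeling change is not reproduced in the diff below; conceptually, prefix-tuning concatenates learned prefix key/value tensors in front of the keys and values computed from the input, as in this simplified sketch (tensor names and shapes are illustrative, not the actual modeling code):

```python
import torch


def attention_with_prefix(query, key, value, prefix_key, prefix_value):
    # Simplified illustration of prefix-tuning inside attention: learned prefix
    # keys/values of shape (batch, heads, prefix_len, head_dim) are stacked in
    # front of the keys/values computed from the input tokens, so every token
    # can additionally attend to the learned prefix.
    key = torch.cat([prefix_key, key], dim=2)
    value = torch.cat([prefix_value, value], dim=2)
    scores = query @ key.transpose(-2, -1) / key.shape[-1] ** 0.5
    weights = torch.softmax(scores, dim=-1)
    return weights @ value
```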

---------

Co-authored-by: jordiclive <jordiclive19@imperial.ac.uk>
jordiclive and jordiclive authored May 26, 2023
1 parent 0d610a9 commit e059d86
Showing 8 changed files with 1,220 additions and 11 deletions.
10 changes: 8 additions & 2 deletions model/model_eval/manual/sampling_report.py
@@ -10,6 +10,7 @@

 import pydantic
 import torch
+from model_training.models.peft_modeling import load_peft_model
 from tqdm import tqdm
 from transformers import AutoTokenizer, PreTrainedTokenizer

@@ -132,9 +133,9 @@ def sample(
         ).to(device)
         input_ids = inputs.input_ids
         outputs = model.generate(
-            input_ids,
-            **sampling_params,
+            input_ids=input_ids,
             pad_token_id=tokenizer.eos_token_id,
+            **sampling_params,
         )
         if skip_input_tokens:
             output_tokens = outputs[0, input_ids.size(1) :]
@@ -255,6 +256,7 @@ def parse_args():
     parser.add_argument("--max-input-len", type=int, help="max token counts for input")
     parser.add_argument("--auth-token", type=str)
     parser.add_argument("--num-threads", type=int, default=8)
+    parser.add_argument("--peft_model", type=str, default=None)

     return parser.parse_args()

@@ -314,6 +316,10 @@ def main():
     else:
         raise RuntimeError("Invalid model_type specified")

+    if args.peft_model is not None:
+        tokenizer = AutoTokenizer.from_pretrained(args.peft_model)
+        model = load_peft_model(model, args.peft_model, tokenizer)
+
     print("special_tokens_map:", tokenizer.special_tokens_map)
     print(f"eos_token='{tokenizer.eos_token}', eos_token_id={tokenizer.eos_token_id}")

74 changes: 74 additions & 0 deletions model/model_training/configs/config.yaml
@@ -84,6 +84,8 @@ defaults:
   per_digit_tokens: false
   is_reward_model: false
   deepspeed_config: configs/zero_config.json
+  peft_model: false
+  peft_type: "lora"

   use_system_tag:
     use_system_tag: True
@@ -442,6 +444,78 @@ llama-30b-pretrain:
   num_train_epochs: 1
   save_total_limit: 2

+lora-llama-13b:
+  dtype: fp16
+  log_dir: "llama_lora_log_13b"
+  learning_rate: 5e-5
+  model_name: /home/ubuntu/llama_hf/13B
+  output_dir: llama_model_13b_lora
+  weight_decay: 0.0
+  max_length: 2048
+  warmup_steps: 300
+  gradient_checkpointing: true
+  gradient_accumulation_steps: 1
+  per_device_train_batch_size: 24
+  per_device_eval_batch_size: 5
+  eval_steps: 500
+  num_train_epochs: 12
+  save_total_limit: 2
+  save_strategy: epoch
+  use_flash_attention: True
+  residual_dropout: 0.0
+  deepspeed_config: configs/zero_config.json
+  peft_model: true
+  peft_type: "lora"
+  use_custom_sampler: true
+
+lora-llama-30b:
+  dtype: fp16
+  log_dir: "llama_lora_log_30b"
+  learning_rate: 5e-5
+  model_name: /home/ubuntu/llama_hf/30B
+  output_dir: llama_model_30b_lora
+  weight_decay: 0.0
+  max_length: 2048
+  warmup_steps: 300
+  gradient_checkpointing: true
+  gradient_accumulation_steps: 1
+  per_device_train_batch_size: 4
+  per_device_eval_batch_size: 2
+  eval_steps: 500
+  num_train_epochs: 12
+  save_total_limit: 2
+  save_strategy: epoch
+  use_flash_attention: True
+  residual_dropout: 0.0
+  deepspeed_config: configs/zero_config.json
+  peft_model: true
+  peft_type: "lora"
+  use_custom_sampler: true
+
+lora-llama-65b:
+  dtype: fp16
+  log_dir: "llama_lora_log_65b"
+  learning_rate: 5e-5
+  model_name: /home/ubuntu/llama_hf/65B
+  output_dir: llama_model_65b_lora
+  weight_decay: 0.0
+  max_length: 2048
+  warmup_steps: 300
+  gradient_checkpointing: true
+  gradient_accumulation_steps: 1
+  per_device_train_batch_size: 12
+  per_device_eval_batch_size: 5
+  eval_steps: 250
+  num_train_epochs: 12
+  save_total_limit: 2
+  save_strategy: epoch
+  use_flash_attention: True
+  residual_dropout: 0.0
+  deepspeed_config: configs/zero_config_sft_65b.json
+  peft_model: true
+  peft_type: "lora"
+  use_custom_sampler: true
+
 pythia-70m-deduped:
   learning_rate: 8e-6
   # model_name: EleutherAI/pythia-1b-deduped
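The trainer wiring that consumes these new `peft_model`/`peft_type` keys lives in changed files not shown in this excerpt; a hedged sketch of what that call site would look like (the `training_conf` object and the placement are assumptions, the helper is the one added below in `model_training/models/peft_modeling.py`):

```python
# Hypothetical call site (names assumed): wrap the base model with PEFT when the
# config enables it, delegating to the helper added by this PR.
from model_training.models.peft_modeling import peft_model

if training_conf.peft_model:
    model = peft_model(
        model,
        peft_type=training_conf.peft_type,  # "lora" or "prefix-tuning"
        gradient_checkpointing=training_conf.gradient_checkpointing,
    )
```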
58 changes: 58 additions & 0 deletions model/model_training/configs/zero_config_sft_65b.json
@@ -0,0 +1,58 @@
{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "warmup_type": "linear",
      "total_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 2e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 2e9,
    "stage3_max_reuse_distance": 2e9,
    "stage3_gather_16bit_weights_on_model_save": true,
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    }
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
101 changes: 101 additions & 0 deletions model/model_training/models/peft_modeling.py
@@ -0,0 +1,101 @@
from dataclasses import dataclass
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from model_training.utils.utils import get_model, get_tokenizer
from peft import LoraConfig, PeftModel, PrefixTuningConfig, get_peft_model, prepare_model_for_int8_training


def load_peft_model(model, peft_model_path, tokenizer):
    model.resize_token_embeddings(len(tokenizer))
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model = PeftModel.from_pretrained(
        model,
        peft_model_path,
        torch_dtype=model.dtype,
    )
    model.eos_token_id = tokenizer.eos_token_id
    extra_embeds = hf_hub_download(peft_model_path, "extra_embeddings.pt")
    embed_weights = torch.load(extra_embeds, map_location=model.device)
    model.base_model.model.model.embed_tokens.weight[len(tokenizer) - embed_weights.shape[0] :, :] = embed_weights.to(
        model.base_model.model.model.embed_tokens.weight.dtype
    )
    return model


def prepare_model_for_gradient_checkpointing(model):
    r"""
    Prepares the model for gradient checkpointing if necessary
    """
    if not getattr(model, "is_loaded_in_8bit", False):
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
    return model


def peft_model(model, peft_type="lora", int8_training=False, gradient_checkpointing=False):
    if peft_type == "lora":
        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
    elif peft_type == "prefix-tuning":
        config = PrefixTuningConfig(
            num_virtual_tokens=30, prefix_projection=True, encoder_hidden_size=1024, task_type="CAUSAL_LM"
        )
    else:
        raise ValueError("peft_method config is lora or prefix-tuning")
    model = get_peft_model(model, config)
    if int8_training:
        model = prepare_model_for_int8_training(model)

    if gradient_checkpointing:
        model = prepare_model_for_gradient_checkpointing(model)
    model.print_trainable_parameters()
    return model


@dataclass
class SaveLoraConfig:
    dtype: torch.dtype = torch.float16
    is_reward_model: bool = False
    quantization: bool = False
    seq2seqmodel: bool = False
    freeze_layer: bool = False
    residual_dropout: float = 0
    use_flash_attention: bool = False
    adapter_save_path: str = "adapter"
    cache_dir: str = ""
    model_name: str = ""
    torch_ckpt_path: str = ""
    peft_type: str = "lora"


def save_adapter_model_from_ckpt(save_config: SaveLoraConfig):
    tokenizer = get_tokenizer(save_config)
    model = get_model(save_config, tokenizer)
    model = peft_model(model)
    model.load_state_dict(torch.load(save_config.torch_ckpt_path))
    vocab_size = tokenizer.vocab_size
    num_special_tokens = len(tokenizer.additional_special_tokens)

    new_embs = model.state_dict()["base_model.model.model.embed_tokens.weight"][
        vocab_size : vocab_size + num_special_tokens, :
    ].clone()
    new_embs = new_embs.to(save_config.dtype)
    model.save_pretrained(save_config.adapter_save_path, torch_dtype=save_config.dtype)
    tokenizer.save_pretrained(save_config.adapter_save_path)
    torch.save(new_embs, Path(save_config.adapter_save_path).joinpath("extra_embeddings.pt"))
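
For reference, converting a full training checkpoint into a shareable adapter with the helpers above would look roughly like this hedged sketch (the paths and model directory are placeholders):

```python
# Hedged usage sketch (placeholder paths): export only the LoRA adapter plus the
# special-token embeddings from a full pytorch_model.bin checkpoint.
from model_training.models.peft_modeling import SaveLoraConfig, save_adapter_model_from_ckpt

save_config = SaveLoraConfig(
    model_name="/home/ubuntu/llama_hf/30B",           # base model used for training
    torch_ckpt_path="checkpoints/pytorch_model.bin",  # full checkpoint produced by the trainer
    adapter_save_path="llama-30b-lora-adapter",       # output: adapter + extra_embeddings.pt
)
save_adapter_model_from_ckpt(save_config)
```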