Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LoRA and Prefix-Tuning as Modeling Options for Improved Memory Efficiency + performance (potentially) #2840

Merged
merged 8 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions model/model_eval/manual/sampling_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pydantic
import torch
from model_training.models.peft_modeling import load_peft_model
from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizer

Expand Down Expand Up @@ -115,9 +116,9 @@ def sample(
).to(device)
input_ids = inputs.input_ids
outputs = model.generate(
input_ids,
**sampling_params,
input_ids=input_ids,
pad_token_id=tokenizer.eos_token_id,
**sampling_params,
)
if skip_input_tokens:
output_tokens = outputs[0, input_ids.size(1) :]
Expand Down Expand Up @@ -232,6 +233,7 @@ def parse_args():
parser.add_argument("--max-input-len", type=int, help="max token counts for input")
parser.add_argument("--auth-token", type=str)
parser.add_argument("--num-threads", type=int, default=8)
parser.add_argument("--peft_model", type=str, default=None)

return parser.parse_args()

Expand Down Expand Up @@ -291,6 +293,10 @@ def main():
else:
raise RuntimeError("Invalid model_type specified")

if args.peft_model is not None:
tokenizer = AutoTokenizer.from_pretrained(args.peft_model)
model = load_peft_model(model, args.peft_model, tokenizer)

print("special_tokens_map:", tokenizer.special_tokens_map)
print(f"eos_token='{tokenizer.eos_token}', eos_token_id={tokenizer.eos_token_id}")

Expand Down
74 changes: 74 additions & 0 deletions model/model_training/configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ defaults:
per_digit_tokens: false
is_reward_model: false
deepspeed_config: configs/zero_config.json
peft_model: false
peft_type: "lora"

webgpt_dataset_only:
datasets:
Expand Down Expand Up @@ -354,6 +356,78 @@ llama-30b-pretrain:
num_train_epochs: 1
save_total_limit: 2

lora-llama-13b:
jordiclive marked this conversation as resolved.
Show resolved Hide resolved
dtype: fp16
log_dir: "llama_lora_log_13b"
learning_rate: 5e-5
model_name: /home/ubuntu/llama_hf/13B
output_dir: llama_model_13b_lora
weight_decay: 0.0
max_length: 2048
warmup_steps: 300
gradient_checkpointing: true
gradient_accumulation_steps: 1
per_device_train_batch_size: 24
per_device_eval_batch_size: 5
eval_steps: 500
num_train_epochs: 12
save_total_limit: 2
save_strategy: epoch
use_flash_attention: True
residual_dropout: 0.0
deepspeed_config: configs/zero_config.json
peft_model: true
peft_type: "lora"
use_custom_sampler: true

lora-llama-30b:
dtype: fp16
log_dir: "llama_lora_log_30b"
learning_rate: 5e-5
model_name: /home/ubuntu/llama_hf/30B
output_dir: llama_model_30b_lora
weight_decay: 0.0
max_length: 2048
warmup_steps: 300
gradient_checkpointing: true
gradient_accumulation_steps: 1
per_device_train_batch_size: 4
per_device_eval_batch_size: 2
eval_steps: 500
num_train_epochs: 12
save_total_limit: 2
save_strategy: epoch
use_flash_attention: True
residual_dropout: 0.0
deepspeed_config: configs/zero_config.json
peft_model: true
peft_type: "lora"
use_custom_sampler: true

lora-llama-65b:
dtype: fp16
log_dir: "llama_lora_log_65b"
learning_rate: 5e-5
model_name: /home/ubuntu/llama_hf/65B
output_dir: llama_model_65b_lora
weight_decay: 0.0
max_length: 2048
warmup_steps: 300
gradient_checkpointing: true
gradient_accumulation_steps: 1
per_device_train_batch_size: 12
per_device_eval_batch_size: 5
eval_steps: 250
num_train_epochs: 12
save_total_limit: 2
save_strategy: epoch
use_flash_attention: True
residual_dropout: 0.0
deepspeed_config: configs/zero_config_sft_65b.json
peft_model: true
peft_type: "lora"
use_custom_sampler: true

pythia-70m-deduped:
learning_rate: 8e-6
# model_name: EleutherAI/pythia-1b-deduped
Expand Down
58 changes: 58 additions & 0 deletions model/model_training/configs/zero_config_sft_65b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 2e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 2e9,
"stage3_max_reuse_distance": 2e9,
"stage3_gather_16bit_weights_on_model_save": true,
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
101 changes: 101 additions & 0 deletions model/model_training/models/peft_modeling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from dataclasses import dataclass
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from model_training.utils.utils import get_model, get_tokenizer
from peft import LoraConfig, PeftModel, PrefixTuningConfig, get_peft_model, prepare_model_for_int8_training


def load_peft_model(model, peft_model_path, tokenizer):
model.resize_token_embeddings(len(tokenizer))
model.config.eos_token_id = tokenizer.eos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model = PeftModel.from_pretrained(
model,
peft_model_path,
torch_dtype=model.dtype,
)
model.eos_token_id = tokenizer.eos_token_id
extra_embeds = hf_hub_download(peft_model_path, "extra_embeddings.pt")
embed_weights = torch.load(extra_embeds, map_location=model.device)
model.base_model.model.model.embed_tokens.weight[len(tokenizer) - embed_weights.shape[0] :, :] = embed_weights.to(
model.base_model.model.model.embed_tokens.weight.dtype
)
return model


def prepare_model_for_gradient_checkpointing(model):
r"""
Prepares the model for gradient checkpointing if necessary
"""
if not getattr(model, "is_loaded_in_8bit", False):
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:

def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
return model


def peft_model(model, peft_type="lora", int8_training=False, gradient_checkpointing=False):
if peft_type == "lora":
config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
elif peft_type == "prefix-tuning":
config = PrefixTuningConfig(
num_virtual_tokens=30, prefix_projection=True, encoder_hidden_size=1024, task_type="CAUSAL_LM"
)
else:
raise ValueError("peft_method config is lora or prefix-tuning")
model = get_peft_model(model, config)
if int8_training:
model = prepare_model_for_int8_training(model)

if gradient_checkpointing:
model = prepare_model_for_gradient_checkpointing(model)
model.print_trainable_parameters()
return model


@dataclass
class SaveLoraConfig:
dtype: torch.dtype = torch.float16
is_reward_model: bool = False
quantization: bool = False
seq2seqmodel: bool = False
freeze_layer: bool = False
residual_dropout: float = 0
use_flash_attention: bool = False
adapter_save_path: str = "adapter"
cache_dir: str = ""
model_name: str = ""
torch_ckpt_path: str = ""
peft_type: str = "lora"


def save_adapter_model_from_ckpt(save_config: SaveLoraConfig):
tokenizer = get_tokenizer(save_config)
model = get_model(save_config, tokenizer)
model = peft_model(model)
model.load_state_dict(torch.load(save_config.torch_ckpt_path))
vocab_size = tokenizer.vocab_size
num_special_tokens = len(tokenizer.additional_special_tokens)

new_embs = model.state_dict()["base_model.model.model.embed_tokens.weight"][
vocab_size : vocab_size + num_special_tokens, :
].clone()
new_embs = new_embs.to(save_config.dtype)
model.save_pretrained(save_config.adapter_save_path, torch_dtype=save_config.dtype)
tokenizer.save_pretrained(save_config.adapter_save_path)
torch.save(new_embs, Path(save_config.adapter_save_path).joinpath("extra_embeddings.pt"))
Loading