Merge pull request #1 from cccntu/lora
Add minLoRA
cccntu authored Feb 27, 2023
2 parents ae3a8d5 + 9a3993f commit e18d9f2
Showing 3 changed files with 71 additions and 1 deletion.
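Taken together, the three files wire minLoRA into nanoGPT: the config picks which module types get LoRA parametrizations and at what rank, train.py attaches them and optimizes only the adapter weights, and sample.py restores and merges them for inference. The sketch below exercises the same minlora calls on a toy module so the whole flow is visible in one place; the Sequential model and its shapes are illustrative and not part of the commit.

import torch
from functools import partial
import minlora
from minlora import LoRAParametrization

toy = torch.nn.Sequential(torch.nn.Linear(16, 16))
lora_config = {
    torch.nn.Linear: {
        "weight": partial(LoRAParametrization.from_linear, rank=4, lora_alpha=64),
    },
}
minlora.add_lora(toy, lora_config)                                           # parametrize the configured weights
optimizer = torch.optim.AdamW(list(minlora.get_lora_params(toy)), lr=1e-3)   # train only the LoRA params
lora_state = minlora.get_lora_state_dict(toy)                                # small dict of lora_A / lora_B tensors
minlora.merge_lora(toy)                                                      # fold the update into the base weights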
21 changes: 21 additions & 0 deletions config/finetune_shakespeare.py
@@ -1,4 +1,8 @@
import time
from functools import partial

import torch
from minlora import LoRAParametrization

out_dir = 'out-shakespeare'
eval_interval = 5
@@ -9,6 +13,8 @@

dataset = 'shakespeare'
init_from = 'gpt2-xl' # this is the largest GPT-2 model
init_from = 'gpt2-large' # use a smaller model for faster training
# xl doesn't fit on a 24GB GPU, but with LoRA it does

# only save checkpoints if the validation loss improves
always_save_checkpoint = False
@@ -23,3 +29,18 @@
# finetune at constant LR
learning_rate = 3e-5
decay_lr = False


use_lora = True
learning_rate = 1e-3 # use a higher LR for LoRA
lora_dropout_p = 0.0
rank = 4
lora_alpha = 64
lora_config = {
    torch.nn.Embedding: {
        "weight": partial(LoRAParametrization.from_embedding, rank=rank, lora_alpha=lora_alpha),
    },
    torch.nn.Linear: {
        "weight": partial(LoRAParametrization.from_linear, rank=rank, lora_alpha=lora_alpha),
    },
}
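For a sense of scale: with the standard LoRA factorization (A of shape rank x fan_in, B of shape fan_out x rank, which minLoRA's from_linear is assumed to follow), a weight of shape (fan_out, fan_in) gains rank * (fan_in + fan_out) trainable parameters. A rough helper; the GPT-2-large width of n_embd = 1280 is used purely as an illustration:

def lora_params_per_weight(fan_out, fan_in, rank=4):
    # A: (rank, fan_in) and B: (fan_out, rank) are the only trained tensors
    return rank * fan_in + fan_out * rank

# e.g. the fused qkv projection in GPT-2-large: 20,480 LoRA params vs 4,915,200 in the full weight
print(lora_params_per_weight(3 * 1280, 1280))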
12 changes: 12 additions & 0 deletions sample.py
@@ -7,6 +7,7 @@
import torch
import tiktoken
from model import GPTConfig, GPT
import minlora

# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
@@ -38,12 +39,23 @@
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    if use_lora:
        minlora.add_lora(model, lora_config)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    if use_lora:
        # the full state dict includes the LoRA state dict
        # so actually we don't need to load it separately again
        model.load_state_dict(checkpoint['lora'], strict=False)
        print('Loaded LoRA state dict')
        # sanity check
        #model.apply(minlora.apply_to_lora(lambda m: print((m.lora_A.sum(), m.lora_B.sum()))))
        # merge for zero-overhead inference
        minlora.merge_lora(model)
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
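Merging is what makes LoRA free at inference time: the low-rank update is folded into each base weight once, after which sampling runs exactly like an unmodified GPT. A hedged sketch of what the merge computes per weight, assuming minLoRA uses the usual lora_alpha / rank scaling; the shapes are illustrative:

import torch

rank, lora_alpha, fan_out, fan_in = 4, 64, 8, 8
W = torch.randn(fan_out, fan_in)   # frozen base weight
A = torch.randn(rank, fan_in)      # trained low-rank factor
B = torch.zeros(fan_out, rank)     # trained; starting at zero keeps the initial model unchanged
W_merged = W + (lora_alpha / rank) * (B @ A)   # the weight the model effectively uses after merge_lora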
39 changes: 38 additions & 1 deletion train.py
@@ -20,6 +20,7 @@
import time
import math
import pickle
import inspect
from contextlib import nullcontext

import numpy as np
@@ -29,6 +30,7 @@

from model import GPTConfig, GPT

import minlora
# -----------------------------------------------------------------------------
# default config values designed to train a gpt2 (124M) on OpenWebText
# I/O
@@ -180,13 +182,46 @@ def get_batch(split):
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size # so that the checkpoint will have the right value
if use_lora:
    minlora.add_lora(model, lora_config=lora_config)
    minlora.tie_weights(linear=model.lm_head, embedding=model.transformer.wte)
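    # (nanoGPT ties lm_head.weight to the token embedding wte.weight, so their LoRA parametrizations are tied to match)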
model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
def configure_optimizers_lora(self, weight_decay, learning_rate, betas, device_type):
    # we apply weight decay to all lora params
    optim_groups = [
        # note: .get_lora_params() returns a generator
        # we need to wrap it in a list so we can consume it twice
        {"params": list(minlora.get_lora_params(self)), "weight_decay": weight_decay},
        # you can also add biases for fine-tuning,
        # but I want to make sure lora alone works
        # {"params": minlora.get_bias_params(self), "weight_decay": 0.0}, # bias params don't get weight decay
    ]

    def parameter_count(optim_groups):
        n = sum(p.numel() for group in optim_groups for p in group["params"])
        if n < 1e6:
            return f"{n/1e3:.1f}k"
        else:
            return f"{n/1e6:.1f}M"

    print(f"optimizing {parameter_count(optim_groups)} parameters")

    # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
    use_fused = (device_type == "cuda") and ("fused" in inspect.signature(torch.optim.AdamW).parameters)
    print(f"using fused AdamW: {use_fused}")
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)

    return optimizer
if use_lora:
    optimizer = configure_optimizers_lora(model, weight_decay, learning_rate, (beta1, beta2), device_type)
else:
    optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
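The commented-out group in configure_optimizers_lora above hints at how bias fine-tuning could be enabled alongside LoRA; if it were turned on, the optimizer setup would look roughly like this (a hedged sketch, with biases kept out of weight decay as the comment suggests):

optim_groups = [
    {"params": list(minlora.get_lora_params(model)), "weight_decay": weight_decay},
    {"params": list(minlora.get_bias_params(model)), "weight_decay": 0.0},  # biases: no weight decay
]
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(beta1, beta2))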

@@ -271,6 +306,8 @@ def get_lr(it):
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                if use_lora:
                    checkpoint['lora'] = minlora.get_lora_state_dict(raw_model)
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
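Because checkpoint['lora'] holds only the adapter tensors, it is tiny next to the full model state dict and could just as well be written out on its own; a hedged sketch, with lora_only.pt as an illustrative filename:

lora_state = minlora.get_lora_state_dict(raw_model)
torch.save(lora_state, os.path.join(out_dir, 'lora_only.pt'))
# to reuse it later: add_lora a freshly loaded base model, then
# model.load_state_dict(torch.load(os.path.join(out_dir, 'lora_only.pt')), strict=False)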
