Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[chat]: add lora merge weights config #4766

Merged
merged 2 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 39 additions & 31 deletions applications/Chat/coati/models/lora.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import dataclasses
import math
import warnings
from typing import Optional

import loralib as lora
Expand All @@ -7,6 +9,14 @@
import torch.nn.functional as F


@dataclasses.dataclass
class LoRAManager:
merge_weights: bool = False


LORA_MANAGER = LoRAManager()


class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""

Expand All @@ -17,13 +27,11 @@ def __init__(
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
# Set this to True if the layer to replace stores weight like (fan_in, fan_out)
fan_in_fan_out: bool = False,
):
nn.Module.__init__(self)
lora.LoRALayer.__init__(
self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
)
lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
self.weight = weight
self.bias = bias

Expand Down Expand Up @@ -53,31 +61,31 @@ def train(self, mode: bool = True):
def T(w):
return w.T if self.fan_in_fan_out else w

nn.Module.train(self, mode)
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
# FIXME(csric): temporary fix
self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
self.reset_parameters()
else:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False

def eval(self):
def T(w):
return w.T if self.fan_in_fan_out else w

nn.Module.eval(self)
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
delattr(self, "lora_A")
delattr(self, "lora_B")
self.merged = True
self.training = mode
if LORA_MANAGER.merge_weights:
if mode and self.merged:
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
raise NotImplementedError("LoRA unmerge is not tested.")
# Make sure that the weights are not merged
if self.r > 0:
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
# FIXME(csric): temporary fix
self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
self.reset_parameters()
else:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
elif not mode and not self.merged:
warnings.warn("Invoke module.eval() would merge LoRA weights.")
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
delattr(self, "lora_A")
delattr(self, "lora_B")
self.merged = True

return self

def forward(self, x: torch.Tensor):
def T(w):
Expand All @@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert (
lora_rank <= linear.in_features
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
return lora_linear


Expand Down
7 changes: 7 additions & 0 deletions applications/Chat/examples/train_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ def main(args):
use_wandb=args.use_wandb,
)

if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER

# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
actor.eval()
# save model checkpoint after fitting
strategy.save_model(actor, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
Expand Down Expand Up @@ -227,6 +233,7 @@ def main(args):
parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=1e-7)
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.9)
Expand Down
8 changes: 8 additions & 0 deletions applications/Chat/examples/train_reward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ def train(args):
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)

if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER

# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
Expand Down Expand Up @@ -186,6 +193,7 @@ def train(args):
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=9e-6)
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
parser.add_argument("--log_dir", default="logs", type=str)
Expand Down
7 changes: 7 additions & 0 deletions applications/Chat/examples/train_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,12 @@ def train(args):
use_wandb=args.use_wandb,
)

if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER

# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
Expand Down Expand Up @@ -204,6 +210,7 @@ def train(args):
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default="logs", type=str)
Expand Down