
Commit 901ab1e

[chat]: add lora merge weights config (#4766)
* feat: modify lora merge weights fn
* feat: add lora merge weights config
1 parent 493a5ef commit 901ab1e

File tree

4 files changed: +61 -31 lines changed


applications/Chat/coati/models/lora.py (+39 -31)

@@ -1,4 +1,6 @@
+import dataclasses
 import math
+import warnings
 from typing import Optional
 
 import loralib as lora
@@ -7,6 +9,14 @@
 import torch.nn.functional as F
 
 
+@dataclasses.dataclass
+class LoRAManager:
+    merge_weights: bool = False
+
+
+LORA_MANAGER = LoRAManager()
+
+
 class LoraLinear(lora.LoRALayer, nn.Module):
     """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
 
@@ -17,13 +27,11 @@ def __init__(
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
-        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
-        merge_weights: bool = True,
+        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        fan_in_fan_out: bool = False,
     ):
         nn.Module.__init__(self)
-        lora.LoRALayer.__init__(
-            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
-        )
+        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
         self.weight = weight
         self.bias = bias
 
@@ -53,31 +61,31 @@ def train(self, mode: bool = True):
         def T(w):
             return w.T if self.fan_in_fan_out else w
 
-        nn.Module.train(self, mode)
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0:
-                if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
-                    # FIXME(csric): temporary fix
-                    self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
-                    self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
-                    self.reset_parameters()
-                else:
-                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
-            self.merged = False
-
-    def eval(self):
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        nn.Module.eval(self)
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-                delattr(self, "lora_A")
-                delattr(self, "lora_B")
-            self.merged = True
+        self.training = mode
+        if LORA_MANAGER.merge_weights:
+            if mode and self.merged:
+                warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+                raise NotImplementedError("LoRA unmerge is not tested.")
+                # Make sure that the weights are not merged
+                if self.r > 0:
+                    if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
+                        # FIXME(csric): temporary fix
+                        self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
+                        self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
+                        self.reset_parameters()
+                    else:
+                        self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = False
+            elif not mode and not self.merged:
+                warnings.warn("Invoke module.eval() would merge LoRA weights.")
+                # Merge the weights and mark it
+                if self.r > 0:
+                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                    delattr(self, "lora_A")
+                    delattr(self, "lora_B")
+                self.merged = True
+
+        return self
 
     def forward(self, x: torch.Tensor):
         def T(w):
@@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
     assert (
         lora_rank <= linear.in_features
     ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
-    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
+    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
     return lora_linear
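What the change amounts to: whether LoRA weights are merged is no longer a per-layer constructor argument but the module-level LORA_MANAGER switch, and LoraLinear now merges only when the layer is put into eval mode while that switch is on (switching a merged layer back to train() currently raises NotImplementedError). Below is a minimal, self-contained sketch of the new behaviour, assuming coati is importable; the layer size and rank are arbitrary illustration values, not taken from the patch.

import torch
import torch.nn as nn

from coati.models.lora import LORA_MANAGER, _lora_linear_wrapper

# Wrap a plain Linear with LoRA matrices; 16 features and rank 4 are arbitrary.
layer = _lora_linear_wrapper(nn.Linear(16, 16), lora_rank=4)
x = torch.randn(2, 16)

# Default: LORA_MANAGER.merge_weights is False, so eval() keeps lora_A / lora_B
# and forward() adds the low-rank update on the fly.
y_unmerged = layer.eval()(x)

# With the new global switch on, the next eval() folds B @ A * scaling into
# layer.weight, deletes lora_A / lora_B and marks the layer as merged.
LORA_MANAGER.merge_weights = True
layer.eval()

assert layer.merged and not hasattr(layer, "lora_A")
# lora_B is zero-initialized, so outputs are identical before and after merging.
assert torch.allclose(y_unmerged, layer(x))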

applications/Chat/examples/train_prompts.py (+7)

@@ -192,6 +192,12 @@ def main(args):
         use_wandb=args.use_wandb,
     )
 
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        actor.eval()
     # save model checkpoint after fitting
     strategy.save_model(actor, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
@@ -227,6 +233,7 @@ def main(args):
     parser.add_argument("--ptx_batch_size", type=int, default=1)
     parser.add_argument("--experience_batch_size", type=int, default=8)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=1e-7)
     parser.add_argument("--kl_coef", type=float, default=0.1)
     parser.add_argument("--ptx_coef", type=float, default=0.9)
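A note on the flag wired up above: argparse applies type=bool to the raw command-line string, and bool() of any non-empty string (including "False") is True, so --merge_lora_weights can effectively only be disabled by passing an empty string. The snippet below is a hypothetical stricter converter shown for illustration; it is not part of this patch.

import argparse


def str2bool(value: str) -> bool:
    # Map explicit spellings instead of relying on bool(str), which is True
    # for every non-empty string.
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


parser = argparse.ArgumentParser()
# Hypothetical alternative to: add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--merge_lora_weights", type=str2bool, default=True)

print(parser.parse_args(["--merge_lora_weights", "False"]).merge_lora_weights)  # False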

applications/Chat/examples/train_reward_model.py (+8)

@@ -157,6 +157,13 @@ def train(args):
         log_dir=args.log_dir,
         use_wandb=args.use_wandb,
     )
+
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        model.eval()
     # save model checkpoint after fitting on only rank0
     strategy.save_model(model, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
@@ -186,6 +193,7 @@ def train(args):
     parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=9e-6)
     parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
     parser.add_argument("--log_dir", default="logs", type=str)

applications/Chat/examples/train_sft.py (+7)

@@ -177,6 +177,12 @@ def train(args):
         use_wandb=args.use_wandb,
     )
 
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        model.eval()
     # save model checkpoint after fitting on only rank0
     strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
     # save optimizer checkpoint on all ranks
@@ -204,6 +210,7 @@ def train(args):
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--accumulation_steps", type=int, default=8)
     parser.add_argument("--log_dir", default="logs", type=str)
