@@ -1,12 +1,22 @@
+import dataclasses
import math
+import warnings
from typing import Optional

import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F


+@dataclasses.dataclass
+class LoRAManager:
+    merge_weights: bool = False
+
+
+LORA_MANAGER = LoRAManager()
+
+
class LoraLinear(lora.LoRALayer, nn.Module):
    """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""

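The new `LoRAManager` dataclass and its module-level `LORA_MANAGER` instance act as a process-wide switch for weight merging, replacing the per-layer `merge_weights` constructor argument that is removed further down. A minimal sketch of how calling code might drive that flag; the `merged_lora_weights` context manager is hypothetical and not part of this diff:

```python
import dataclasses
from contextlib import contextmanager


@dataclasses.dataclass
class LoRAManager:
    # Process-wide switch: when True, LoraLinear.train()/eval() merge or unmerge LoRA weights.
    merge_weights: bool = False


LORA_MANAGER = LoRAManager()


@contextmanager
def merged_lora_weights():
    # Hypothetical helper (not in this diff): enable merging temporarily,
    # then restore the previous setting on exit.
    previous = LORA_MANAGER.merge_weights
    LORA_MANAGER.merge_weights = True
    try:
        yield
    finally:
        LORA_MANAGER.merge_weights = previous
```
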
@@ -17,13 +27,11 @@ def __init__(
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
-        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
-        merge_weights: bool = True,
+        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        fan_in_fan_out: bool = False,
    ):
        nn.Module.__init__(self)
-        lora.LoRALayer.__init__(
-            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
-        )
+        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
        self.weight = weight
        self.bias = bias

@@ -53,31 +61,31 @@ def train(self, mode: bool = True):
        def T(w):
            return w.T if self.fan_in_fan_out else w

-        nn.Module.train(self, mode)
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0:
-                if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
-                    # FIXME(csric): temporary fix
-                    self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
-                    self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
-                    self.reset_parameters()
-                else:
-                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
-            self.merged = False
-
-    def eval(self):
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        nn.Module.eval(self)
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-                delattr(self, "lora_A")
-                delattr(self, "lora_B")
-            self.merged = True
+        self.training = mode
+        if LORA_MANAGER.merge_weights:
+            if mode and self.merged:
+                warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+                raise NotImplementedError("LoRA unmerge is not tested.")
+                # Make sure that the weights are not merged
+                if self.r > 0:
+                    if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
+                        # FIXME(csric): temporary fix
+                        self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
+                        self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
+                        self.reset_parameters()
+                    else:
+                        self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = False
+            elif not mode and not self.merged:
+                warnings.warn("Invoke module.eval() would merge LoRA weights.")
+                # Merge the weights and mark it
+                if self.r > 0:
+                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                    delattr(self, "lora_A")
+                    delattr(self, "lora_B")
+                self.merged = True
+
+        return self

    def forward(self, x: torch.Tensor):
        def T(w):
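The merge performed in the new `train()` relies on the fact that folding the low-rank update into the base weight is output-preserving: `x @ (W + scaling * B @ A).T` equals `x @ W.T + scaling * (x @ A.T) @ B.T`. A standalone sketch with plain tensors (shapes chosen arbitrarily; dropout and `fan_in_fan_out` omitted for brevity) to verify this:

```python
import torch

torch.manual_seed(0)
in_features, out_features, r, scaling = 8, 4, 2, 0.5

W = torch.randn(out_features, in_features)  # base weight, shape (out_features, in_features)
A = torch.randn(r, in_features)             # lora_A, shape (r, in_features)
B = torch.randn(out_features, r)            # lora_B, shape (out_features, r)
x = torch.randn(3, in_features)             # a small batch of inputs

# Unmerged path: base linear plus the low-rank update applied separately (training-time behavior).
unmerged = x @ W.T + (x @ A.T @ B.T) * scaling

# Merged path: fold scaling * B @ A into the weight once, then do a plain linear (eval-time behavior).
merged = x @ (W + (B @ A) * scaling).T

print(torch.allclose(unmerged, merged, atol=1e-6))  # True
```
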
@@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
    assert (
        lora_rank <= linear.in_features
    ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
-    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
+    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
    return lora_linear

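With `merge_weights` gone from the constructor, merging is controlled entirely through `LORA_MANAGER` and the `train()` override (plain `nn.Module.eval()` now routes into it via `train(False)`). A usage sketch, assuming these names can be imported from the module this diff edits; the import path below is illustrative only:

```python
import torch.nn as nn

# Illustrative import path; the actual module name is not shown in this diff.
from lora import LORA_MANAGER, _lora_linear_wrapper

base = nn.Linear(128, 64)
lora_linear = _lora_linear_wrapper(base, lora_rank=8)  # merging is left to the global manager

LORA_MANAGER.merge_weights = True  # opt in to merging process-wide

lora_linear.eval()    # merges scaling * lora_B @ lora_A into .weight and deletes lora_A / lora_B
# lora_linear.train() # would unmerge, but this path raises NotImplementedError in the new code
```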