import logging
+ import warnings
import os
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Dict, Iterator, List, Optional, Tuple

- from peft import LoraConfig, TaskType, get_peft_model
-
import torch
import torch.nn as nn
+ from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@@ -335,13 +335,44 @@ def enable_lora(
        from peft import PeftModel, get_peft_model
        assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
        self.lora_enabled = True
+         warnings.warn("You have enabled LoRA training. Please check hyperparameters such as the learning rate (lr).")

        if pretrained_dir is None:
            peft_model = get_peft_model(model, lora_config)
        else:
            peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
        return peft_model

+     def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
+         origin_param_id = id(origin_param)
+         for group_id, param_group in enumerate(optimizer.param_groups):
+             for p in param_group['params']:
+                 if id(p) == origin_param_id:
+                     return group_id
+         return -1
+
+     def add_lora_para_to_optimizer(self, model, optimizer):
+         """Add LoRA parameters to the optimizer's param groups."""
+         name2param = {}
+         for name, param in model.named_parameters():
+             name2param[name] = param
+
+         optimizer_param_nums = 0
+         for param_group in optimizer.param_groups:
+             optimizer_param_nums += len(param_group['params'])
+
+         # Detect an optimizer that was created before the model was transformed into a LoRA model.
+         if len(name2param) != optimizer_param_nums:
+             for name, param in name2param.items():
+                 if 'lora_A' in name or 'lora_B' in name:
+                     origin_key = name.replace("lora_A.", "")
+                     origin_key = origin_key.replace("lora_B.", "")
+                     origin_key = origin_key.replace(f"{model.active_adapter}.", "")
+                     origin_param = name2param[origin_key]
+                     group_id = self.get_param_group_id(optimizer, origin_param)
+                     assert group_id != -1, "Parameter error: the original parameter does not exist."
+                     optimizer.param_groups[group_id]['params'].append(param)
+
    def configure(
        self,
        model: nn.Module,
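The key-stripping logic in add_lora_para_to_optimizer relies on peft's usual parameter naming, where each LoRA matrix lives next to the base weight it adapts. A minimal sketch of that mapping, using a hypothetical module path and the default adapter name "default" (both are illustrative assumptions, not taken from this diff):

    # Hypothetical peft-style parameter name for a LoRA matrix (assumption for illustration).
    lora_name = "base_model.model.decoder.layers.0.self_attn.q_proj.lora_A.default.weight"

    # Strip the LoRA-specific pieces, as the helper above does, to recover the base weight's key.
    origin_key = lora_name.replace("lora_A.", "").replace("lora_B.", "").replace("default.", "")
    assert origin_key == "base_model.model.decoder.layers.0.self_attn.q_proj.weight"

    # The LoRA parameter is then appended to the optimizer param group that already
    # holds the parameter stored under origin_key.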
@@ -353,12 +384,8 @@ def configure(
        if self.lora_enabled:
            from peft import PeftModel
            assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
-
-             optim_params_nums = 0
-             for param_group in optimizer.param_groups:
-                 optim_params_nums += len(param_group['params'])
-             model_params_nums = len(list(model.named_parameters()))
-             assert optim_params_nums == model_params_nums, "Optimizer should be initialized after enabling lora."
+             self.add_lora_para_to_optimizer(model, optimizer)
+

        if not isinstance(model, ModelWrapper):
            model = LowLevelZeroModel(model, self.precision)
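With this change, configure() patches an existing optimizer instead of requiring it to be built after LoRA is enabled: any lora_A/lora_B parameters missing from the optimizer are appended to the param group of their base weights. Below is a minimal end-to-end sketch under that reading; the Booster/LowLevelZeroPlugin call signatures and the build_model() helper are assumptions for illustration rather than part of this diff, and distributed initialization is omitted:

    import torch
    from peft import LoraConfig

    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    model = build_model()  # hypothetical helper returning an nn.Module
    # The optimizer may now be created *before* LoRA is enabled.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    plugin = LowLevelZeroPlugin()
    booster = Booster(plugin=plugin)

    # Wrap the model as a PeftModel via the plugin's enable_lora().
    lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
    model = booster.enable_lora(model, lora_config=lora_config)

    # During boost(), configure() calls add_lora_para_to_optimizer(), which appends the
    # newly injected lora_A/lora_B parameters to the matching optimizer param groups.
    model, optimizer, *_ = booster.boost(model, optimizer)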