Commit a69c139
Committed Dec 11, 2023 (1 parent: 4997393)

fix: add LoRA parameters to the optimizer's param groups when the optimizer is created before enabling LoRA (LowLevelZeroPlugin)

File tree: 4 files changed, +41 -12 lines
 

‎colossalai/booster/plugin/low_level_zero_plugin.py

+35 -8

@@ -1,14 +1,14 @@
 import logging
+import warnings
 import os
 from functools import partial
 from pathlib import Path
 from types import MethodType
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Dict
 
-from peft import LoraConfig, TaskType, get_peft_model
-
 import torch
 import torch.nn as nn
+from torch.nn import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
@@ -335,13 +335,44 @@ def enable_lora(
         from peft import PeftModel, get_peft_model
         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
+        warnings.warn("You have enabled LoRA training. Please check hyperparameters such as the learning rate.")
 
         if pretrained_dir is None:
             peft_model = get_peft_model(model, lora_config)
         else:
             peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
         return peft_model
 
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
+        origin_param_id = id(origin_param)
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group['params']:
+                if id(p) == origin_param_id:
+                    return group_id
+        return -1
+
+    def add_lora_para_to_optimizer(self, model, optimizer):
+        """Add LoRA parameters to the optimizer."""
+        name2param = {}
+        for name, param in model.named_parameters():
+            name2param[name] = param
+
+        optimizer_param_nums = 0
+        for param_group in optimizer.param_groups:
+            optimizer_param_nums += len(param_group['params'])
+
+        # If the optimizer was created before the model was transformed into a LoRA model, patch in the missing LoRA parameters.
+        if len(name2param) != optimizer_param_nums:
+            for name, param in name2param.items():
+                if 'lora_A' in name or 'lora_B' in name:
+                    origin_key = name.replace("lora_A.", "")
+                    origin_key = origin_key.replace("lora_B.", "")
+                    origin_key = origin_key.replace(f"{model.active_adapter}.", "")
+                    origin_param = name2param[origin_key]
+                    group_id = self.get_param_group_id(optimizer, origin_param)
+                    assert group_id != -1, "Parameter error: the origin parameter does not exist."
+                    optimizer.param_groups[group_id]['params'].append(param)
+
     def configure(
         self,
         model: nn.Module,
@@ -353,12 +384,8 @@ def configure(
         if self.lora_enabled:
             from peft import PeftModel
             assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
-
-            optim_params_nums = 0
-            for param_group in optimizer.param_groups:
-                optim_params_nums += len(param_group['params'])
-            model_params_nums = len(list(model.named_parameters()))
-            assert optim_params_nums == model_params_nums, "Optimizer should be initialized after enabling lora."
+            self.add_lora_para_to_optimizer(model, optimizer)
+
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)
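The new `add_lora_para_to_optimizer` helper maps each injected `lora_A`/`lora_B` parameter back to the base weight it decorates (by stripping the adapter fragments from the parameter name) and appends it to whichever optimizer param group already holds that base weight, located by identity via `get_param_group_id`. Below is a minimal, self-contained sketch of the same lookup using a toy module rather than a real PEFT model; `ToyLoRALinear`, `find_param_group_id`, and the two-group AdamW setup are illustrative assumptions, not part of the commit:

```python
import torch
import torch.nn as nn


def find_param_group_id(optimizer: torch.optim.Optimizer, origin_param: nn.Parameter) -> int:
    """Return the index of the param group that already holds `origin_param`, or -1."""
    for group_id, param_group in enumerate(optimizer.param_groups):
        if any(p is origin_param for p in param_group["params"]):
            return group_id
    return -1


class ToyLoRALinear(nn.Module):
    """Toy stand-in for a PEFT-wrapped linear layer: a frozen base weight plus lora_A/lora_B factors."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

    def forward(self, x):
        return x @ (self.weight + self.lora_B @ self.lora_A).t()


model = nn.Sequential(ToyLoRALinear(8, 8), ToyLoRALinear(8, 4))

# The optimizer was built from the base weights only (i.e. before the LoRA factors existed),
# with one param group per layer.
optimizer = torch.optim.AdamW(
    [{"params": [model[0].weight], "lr": 1e-4}, {"params": [model[1].weight], "lr": 1e-3}]
)

# Route each LoRA factor into the group of the base weight it belongs to, mirroring the
# name2param lookup and string stripping performed by add_lora_para_to_optimizer.
name2param = dict(model.named_parameters())
for name, param in name2param.items():
    if "lora_A" in name or "lora_B" in name:
        origin_key = name.replace("lora_A", "weight").replace("lora_B", "weight")
        group_id = find_param_group_id(optimizer, name2param[origin_key])
        assert group_id != -1
        optimizer.param_groups[group_id]["params"].append(param)

print([len(g["params"]) for g in optimizer.param_groups])  # [3, 3]
```

Appending to the existing group, rather than creating a new one, keeps the adapter weights under the learning rate and other options the user already configured for the corresponding base layer.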

‎colossalai/zero/low_level/bookkeeping/gradient_store.py

+1

@@ -82,6 +82,7 @@ def get_working_grads_by_group_id(self, group_id: int) -> List:
         """
 
         grad_list = []
+        # When using LoRA and the user sets multiple param_groups, it is possible that some param_groups have no parameters with gradients.
         if group_id not in self._grads_of_params.keys():
             return grad_list
         for param_grads in self._grads_of_params[group_id].values():
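The comment added here points at a real consequence of the plugin change: once LoRA parameters are routed into the user's existing groups, a param group made up only of frozen base weights never produces gradients, so `_grads_of_params` has no entry for its group id and the early return is taken. A small plain-PyTorch illustration of that situation (independent of `GradientStore`; the two-group SGD setup is only for demonstration):

```python
import torch
import torch.nn as nn

# A frozen base weight (as under LoRA) and a trainable LoRA-style factor,
# deliberately placed in two separate optimizer param groups by the user.
base_weight = nn.Parameter(torch.randn(4, 4), requires_grad=False)
lora_weight = nn.Parameter(torch.zeros(4, 4))
optimizer = torch.optim.SGD([{"params": [base_weight]}, {"params": [lora_weight]}], lr=1e-2)

loss = (torch.randn(2, 4) @ (base_weight + lora_weight).t()).sum()
loss.backward()

# Param group 0 received no gradients at all, so any per-group gradient bookkeeping
# (such as GradientStore._grads_of_params) has nothing recorded under that group id.
grads_per_group = [[p.grad for p in group["params"] if p.grad is not None]
                   for group in optimizer.param_groups]
print([len(g) for g in grads_per_group])  # [0, 1]
```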

‎tests/test_booster/test_plugin/test_low_level_zero_plugin.py

+2 -1

@@ -24,11 +24,11 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
     plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
     booster = Booster(plugin=plugin)
     model = model_fn()
+    optimizer = HybridAdam(model.parameters(), lr=1e-3)
 
     if lora_config is not None:
         model = booster.enable_lora(model, lora_config=lora_config)
 
-    optimizer = HybridAdam(model.parameters(), lr=1e-3)
     criterion = lambda x: x.mean()
     data = data_gen_fn()
 
@@ -48,6 +48,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)
 
     except Exception as e:
         return repr(e)
+        # raise e
 
 
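For context, this is the call order the updated test now exercises: the optimizer is built from the base model first, LoRA is enabled afterwards, and `booster.boost` (via `configure`, which now calls `add_lora_para_to_optimizer`) folds the freshly injected adapter weights into the existing param groups. A rough sketch under the same assumptions as the test, i.e. run inside an already-launched ColossalAI process as in `run_dist`; `TinyNet` and the `LoraConfig` arguments are placeholders, not part of the commit:

```python
import torch.nn as nn
from peft import LoraConfig

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam


class TinyNet(nn.Module):  # illustrative stand-in for the test's model_fn()
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


model = TinyNet()
optimizer = HybridAdam(model.parameters(), lr=1e-3)  # created BEFORE enable_lora

booster = Booster(plugin=LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=2**5))
model = booster.enable_lora(model, lora_config=LoraConfig(r=8, lora_alpha=16, target_modules=["proj"]))

# boost() wraps the model and, because lora_enabled is set, configure() appends the freshly
# injected lora_A / lora_B parameters to the optimizer's existing param groups.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=lambda x: x.mean())
```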

‎tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py

+3 -3

@@ -80,9 +80,10 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
     booster = Booster(plugin=plugin)
     new_booster = Booster(plugin=new_plugin)
     model = model_fn()
+    optimizer = HybridAdam(model.parameters(), lr=1e-3)
     new_model = deepcopy(model)
+    new_optimizer = HybridAdam(new_model.parameters(), lr=1e-3)
     model = booster.enable_lora(model, lora_config=lora_config)
-    optimizer = HybridAdam(model.parameters(), lr=1e-3)
     criterion = lambda x: x.mean()
     data = data_gen_fn()
 
@@ -107,7 +108,6 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
     booster.save_lora_as_pretrained(model, model_ckpt_path)
     booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False)
     new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config)
-    new_optimizer = HybridAdam(new_model.parameters(), lr=1e-3)
     new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
     check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
 
@@ -168,7 +168,7 @@ def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: boo
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
-    # check_low_level_zero_checkpointIO()
+    check_low_level_zero_checkpointIO()
     check_low_level_zero_lora_checkpointIO()
     torch.cuda.empty_cache()
 
