
Commit 52a2dde

[Feature] qlora support (#5586)
* [feature] qlora support
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* qlora follow-up commit
* migrate quantization folder to colossalai/
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* minor fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: cabc128 · commit: 52a2dde

Some content is hidden: large commits have part of their diff collapsed by default, so not every changed file is shown in full below.

51 files changed (+1049 -597 lines)

LICENSE (+17)

@@ -527,3 +527,20 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
+
+
+---------------- LICENSE FOR Hugging Face accelerate ----------------
+
+Copyright 2021 The HuggingFace Team
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.

applications/Chat/benchmarks/benchmark_opt_lora_dummy.py (+4 -2)

@@ -76,9 +76,11 @@ def main(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5)
+        strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
     elif args.strategy == "colossalai_gemini_cpu":
-        strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+        strategy = GeminiStrategy(
+            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
+        )
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     elif args.strategy == "colossalai_zero2_cpu":

applications/Chat/coati/models/base/actor.py (-1)

@@ -30,4 +30,3 @@ def forward(
         """Returns model output."""
         output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
         return output
-

applications/Chat/coati/ray/utils.py (+3 -1)

@@ -75,7 +75,9 @@ def get_strategy_from_args(strategy: str):
     elif strategy == "colossalai_zero2":
         strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     elif strategy == "colossalai_gemini_cpu":
-        strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+        strategy_ = GeminiStrategy(
+            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
+        )
     elif strategy == "colossalai_zero2_cpu":
         strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:

applications/Chat/coati/trainer/strategies/ddp.py (+2 -1)

@@ -101,16 +101,17 @@ def save_pretrained(
 
         model_path = os.path.join(path, "pytorch_model.bin")
         self.save_model(model, model_path, shard=shard)
+
         def _replace_keys(model_path: str, replace_fn: Callable):
             state_dict = torch.load(model_path, map_location="cpu")
             state_dict = {replace_fn(k): v for k, v in state_dict.items()}
             torch.save(state_dict, model_path)
+
         # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
         # HACK: rename keys of pytorch_model.bin
         if dist.get_rank() == 0:
             _replace_keys(model_path, lambda k: k.replace("model.", "", 1))
 
-
     def get_model_state_dict_shard(self, model: nn.Module, **config):
         # TODO: implement sharding on naive strategy
         model = self.unwrap_model(model)

applications/Chat/examples/community/peft/train_peft_prompts.py (+3 -1)

@@ -24,7 +24,9 @@ def main(args):
     if args.strategy == "ddp":
         strategy = DDPStrategy()
     elif args.strategy == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+        strategy = GeminiStrategy(
+            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
+        )
     elif args.strategy == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
     else:

applications/Colossal-LLaMA-2/README.md (+2 -2)

@@ -130,8 +130,8 @@ from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download
 model_dir = snapshot_download('colossalai/Colossal-LLaMA-2-7b-base', revision='v1.0.1')
 tokenizer = AutoTokenizer.from_pretrained(model_dir, device_map="auto", trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True).eval()
-generation_kwargs = {"max_new_tokens": 256,
-                     "top_p": 0.95,
+generation_kwargs = {"max_new_tokens": 256,
+                     "top_p": 0.95,
                      "temperature": 0.3
                     }
 input = '离离原上草,'

applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py (+5 -5)

@@ -1,20 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import numpy as np
 import os
 import random
 from dataclasses import dataclass
-from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union
 
+import numpy as np
 import torch
-from datasets import dataset_dict, load_from_disk
+import torch.nn.functional as F
 from datasets import Dataset as HFDataset
+from datasets import dataset_dict, load_from_disk
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
+from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler
 from transformers.tokenization_utils import PreTrainedTokenizer
-import torch.nn.functional as F
 
 DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
 PathType = Union[str, os.PathLike]

applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py (+3 -8)

@@ -7,9 +7,9 @@
 import random
 import warnings
 from copy import deepcopy
-from datasets import dataset_dict
-from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
 
+from datasets import dataset_dict
 from torch.utils.data import ConcatDataset, Dataset, IterableDataset
 from transformers.models.llama.tokenization_llama import LlamaTokenizer
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -169,12 +169,7 @@ def __iter__(self) -> Iterable[Dict[str, List[int]]]:
                 spliced_labels.extend(seq_labels)
             # For residual spliced data point at the end of the data set
             if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
-                examples.append(
-                    {
-                        self.input_ids_field: spliced_input_ids,
-                        self.labels_field: spliced_labels
-                    }
-                )
+                examples.append({self.input_ids_field: spliced_input_ids, self.labels_field: spliced_labels})
             if self.shuffle:
                 random.shuffle(examples)
             for spliced_data_point in examples:

applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py (+1 -2)

@@ -8,11 +8,10 @@
 
 import numpy as np
 import torch
-from transformers import LlamaTokenizer, LlamaForCausalLM
+from transformers import LlamaForCausalLM, LlamaTokenizer
 
 from colossalai.logging import get_dist_logger
 
-
 logger = get_dist_logger()
 

applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py (+2 -2)

@@ -6,12 +6,12 @@
 """
 
 import argparse
-import os
 import json
+import os
 from typing import List, Union
 
-from transformers.models.llama.tokenization_llama import LlamaTokenizer
 from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
 
 from colossalai.logging import get_dist_logger
 

applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py (+1 -1)

@@ -10,8 +10,8 @@
 from typing import Any, Dict, Tuple, Union
 
 import torch
-from torch.optim.optimizer import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.optimizer import Optimizer
 
 from colossalai.booster import Booster
 from colossalai.cluster import DistCoordinator

applications/Colossal-LLaMA-2/docs/example.md (+1 -1)

@@ -242,4 +242,4 @@ To comprehensively assess the performance of the Colossal-LLaMA-2-7B-base model,
 ## Conclusion
 In general, the Colossal-LLaMA-2-7B-base model not only enhances its understanding of English but also exhibits significant improvements in its comprehension of Chinese. It boasts a broad spectrum of general knowledge, encompassing various fields such as food, sports, technology, literature, games, and more. Regarding text generation tasks, the Colossal-LLaMA-2-7B-base model excels in writing performance; however, its ability to generate specific formats like code, emails, tables, etc., needs enhancement due to the scarcity of relevant training data during our training phase. When compared to the Qwen-7b-base model, the Colossal-LLaMA-2-7B-base model outperforms it in answering most English questions and some Chinese questions, as demonstrated in the examples above.
 
-Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement in the future as part of our ongoing enhancements.
+Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement in the future as part of our ongoing enhancements.

(file name hidden in the rendered commit view)

@@ -1,2 +1,2 @@
 hostname1
-hostname2
+hostname2

applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py (+5 -5)

@@ -11,14 +11,14 @@
 import time
 from multiprocessing import cpu_count
 
+from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
+    ClosedToConstantLengthSplicedDataset,
+    supervised_tokenize,
+)
 from datasets import dataset_dict, load_dataset
 from transformers.models.llama.tokenization_llama import LlamaTokenizer
 
 from colossalai.logging import get_dist_logger
-from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
-    supervised_tokenize,
-    ClosedToConstantLengthSplicedDataset,
-)
 
 logger = get_dist_logger()
 
@@ -149,5 +149,5 @@ def main():
     spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

applications/Colossal-LLaMA-2/requirements.txt (-1)

@@ -12,4 +12,3 @@ flash-attn>=2.0.0,<=2.0.5
 tqdm
 sentencepiece==0.1.99
 protobuf<=3.20.0
-

applications/Colossal-LLaMA-2/train.py (+15 -23)

@@ -1,45 +1,39 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
+Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
 """
 
-import json
 import argparse
+import json
 import os
 import resource
 from contextlib import nullcontext
-from tqdm import tqdm
 
 import torch
 import torch.distributed as dist
+from colossal_llama2.dataset.loader import (
+    DataCollatorForSupervisedDataset,
+    StatefulDistributedSampler,
+    load_tokenized_dataset,
+    setup_distributed_dataloader,
+)
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from torch.utils.tensorboard import SummaryWriter
-from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+from tqdm import tqdm
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import (
-    GeminiPlugin,
-    LowLevelZeroPlugin,
-    HybridParallelPlugin,
-)
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
-from colossal_llama2.dataset.loader import (
-    load_tokenized_dataset,
-    setup_distributed_dataloader,
-    DataCollatorForSupervisedDataset,
-    StatefulDistributedSampler,
-)
-
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters
-
 
 def get_model_numel(model: torch.nn.Module) -> int:
     return sum(p.numel() for p in model.parameters())
@@ -372,9 +366,7 @@ def main() -> None:
     # Final save.
     coordinator.print_on_master("Start saving final model checkpoint")
     booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
-    coordinator.print_on_master(
-        f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
-    )
+    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
 
     coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")

(file name hidden in the rendered commit view) (+1 -1)

@@ -1 +1 @@
-0.0.1
+0.0.1

colossalai/booster/booster.py (+21 -2)

@@ -19,6 +19,7 @@
 import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.quantization import BnbQuantizationConfig
 
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -230,7 +231,12 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
         return self.plugin.no_sync(model, optimizer)
 
     def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
+        self,
+        model: nn.Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: "peft.LoraConfig" = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
+        quantize=False,
     ) -> nn.Module:
         """
         Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
@@ -259,7 +265,20 @@ def enable_lora(
             assert (
                 pretrained_dir is not None
             ), "Please provide pretrained directory path if not passing in lora configuration."
-        return self.plugin.enable_lora(model, pretrained_dir, lora_config)
+        if quantize is True:
+            if bnb_quantization_config is not None:
+                warnings.warn(
+                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+                )
+            else:
+                bnb_quantization_config = BnbQuantizationConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                )
+
+        return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
 
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.
