@@ -1,45 +1,39 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
+Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
 """
 
-import json
 import argparse
+import json
 import os
 import resource
 from contextlib import nullcontext
-from tqdm import tqdm
 
 import torch
 import torch.distributed as dist
+from colossal_llama2.dataset.loader import (
+    DataCollatorForSupervisedDataset,
+    StatefulDistributedSampler,
+    load_tokenized_dataset,
+    setup_distributed_dataloader,
+)
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from torch.utils.tensorboard import SummaryWriter
-from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+from tqdm import tqdm
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import (
-    GeminiPlugin,
-    LowLevelZeroPlugin,
-    HybridParallelPlugin,
-)
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
-from colossal_llama2.dataset.loader import (
-    load_tokenized_dataset,
-    setup_distributed_dataloader,
-    DataCollatorForSupervisedDataset,
-    StatefulDistributedSampler,
-)
-
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters
-
 
 def get_model_numel(model: torch.nn.Module) -> int:
     return sum(p.numel() for p in model.parameters())
@@ -372,9 +366,7 @@ def main() -> None:
     # Final save.
     coordinator.print_on_master("Start saving final model checkpoint")
     booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
-    coordinator.print_on_master(
-        f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
-    )
+    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
 
     coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
 
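
Side note on the import hunk: the new ordering is what an import sorter such as isort produces, with names inside each from-import alphabetized, imports grouped into sections, and the parenthesized colossalai.booster.plugin import collapsed onto one line once it fits the configured line length. A minimal sketch of reproducing that normalization locally; the file name and the 120-character line length are assumptions for illustration, not something this diff states:

    import isort

    # Re-sort the script's imports in place. line_length is an assumed
    # project setting; adjust it to match the repository's style config.
    isort.file("train.py", line_length=120)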